diff --git a/CLAUDE.md b/CLAUDE.md
index aae2b9f..28dae7a 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -449,5 +449,145 @@ The deprecated APIs still work but generate warnings. For production code:
 - Test thoroughly as synchronization behavior may differ slightly
 - Legacy code will continue to function until removed in future versions
 
+## TensorFlow TPU Implementation
+
+The original PyTorch implementation has been converted to TensorFlow so that it runs efficiently on TPU v5e-8 hardware, particularly for the Brain-to-Text '25 Competition on Kaggle.
+
+### Key TensorFlow Components (`model_training_nnn_tpu/`)
+
+#### Core Files
+- **`rnn_model_tf.py`**: TensorFlow implementation of the TripleGRUDecoder architecture
+  - `NoiseModel`: 2-layer GRU for noise estimation, with day-specific layers
+  - `CleanSpeechModel`: 3-layer GRU for clean speech recognition, with day-specific layers
+  - `NoisySpeechModel`: 2-layer GRU for noisy speech recognition (no day-specific layers)
+  - `TripleGRUDecoder`: Main adversarial architecture combining all three models
+  - `CTCLoss`: Custom CTC loss implementation for TPU compatibility
+  - `create_tpu_strategy()`: Enhanced TPU connection function with robust environment detection
+- **`trainer_tf.py`**: TensorFlow training pipeline with distributed TPU support
+- **`dataset_tf.py`**: TensorFlow data loading and augmentation pipeline optimized for TPU
+- **`train_model_tf.py`**: Main training script entry point
+- **`evaluate_model_tf.py`**: Evaluation pipeline for model performance analysis
+
+### TPU v5e-8 Specific Optimizations
+
+#### 1. Enhanced TPU Connection
+The `create_tpu_strategy()` function provides robust TPU detection across different environments:
+
+```python
+def create_tpu_strategy():
+    """Create TPU strategy for distributed training on TPU v5e-8"""
+    # Multi-environment TPU detection
+    if 'COLAB_TPU_ADDR' in os.environ:
+        tpu_address = os.environ['COLAB_TPU_ADDR']
+    elif 'TPU_NAME' in os.environ:
+        tpu_name = os.environ['TPU_NAME']
+    elif 'TPU_WORKER_ID' in os.environ:
+        # Kaggle TPU environment
+        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address
+
+    # Enhanced error handling and debugging output
+    # Fallback to the default strategy if the TPU connection fails
+```
+
+**Environment Variables Detected**:
+- `COLAB_TPU_ADDR`: Google Colab TPU environment
+- `TPU_NAME`: Generic TPU name specification
+- `TPU_WORKER_ID`: Kaggle TPU environment indicator
+
+**Troubleshooting TPU Connection Issues**:
+- Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
+- Solution: The function automatically detects and uses the appropriate TPU address for the current environment
+- Debugging: All TPU-related environment variables are printed during initialization
+
+#### 2. Mixed Precision Training
+Configured for optimal TPU v5e-8 performance:
+```python
+def configure_mixed_precision():
+    """Configure mixed precision for optimal TPU v5e-8 performance"""
+    policy = keras.mixed_precision.Policy('mixed_bfloat16')
+    keras.mixed_precision.set_global_policy(policy)
+```
+
+#### 3. XLA-Optimized Operations
+- **Static Tensor Operations**: `tf.stack()` and `tf.gather()` instead of dynamic indexing
+- **Efficient Matrix Operations**: `tf.linalg.matmul()` for batch matrix multiplication
+- **TPU-Friendly GRU Layers**: Recurrent dropout disabled for better TPU performance
+- **Patch Processing**: TensorFlow equivalent of PyTorch's `unfold` using `tf.image.extract_patches()` (see the sketch below)
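+
+As a rough illustration of the patch-processing point above, the following minimal sketch shows one way `tf.image.extract_patches()` can reproduce PyTorch's `unfold` along the time axis. The helper name `extract_time_patches`, the `patch_size`/`patch_stride` arguments, and the example shapes are illustrative assumptions, not code taken from `rnn_model_tf.py`:
+
+```python
+import tensorflow as tf
+
+def extract_time_patches(x, patch_size, patch_stride):
+    """Concatenate `patch_size` consecutive frames every `patch_stride` steps.
+
+    x: [batch, time, features] -> [batch, num_patches, patch_size * features],
+    roughly matching PyTorch's x.unfold(1, patch_size, patch_stride) followed
+    by flattening each patch. (Illustrative sketch, not repository code.)
+    """
+    features = x.shape[-1]  # feature count must be statically known
+    # extract_patches expects a 4-D "image": [batch, rows, cols, depth]
+    x4d = tf.expand_dims(x, axis=-1)  # [batch, time, features, 1]
+    patches = tf.image.extract_patches(
+        images=x4d,
+        sizes=[1, patch_size, features, 1],
+        strides=[1, patch_stride, features, 1],
+        rates=[1, 1, 1, 1],
+        padding='VALID',
+    )  # -> [batch, num_patches, 1, patch_size * features]
+    return tf.squeeze(patches, axis=2)
+
+# Example: 7 timesteps, 512 channels, 4-frame patches every 2 steps -> 2 patches
+x = tf.random.normal([8, 7, 512])
+print(extract_time_patches(x, patch_size=4, patch_stride=2).shape)  # (8, 2, 2048)
+```
+
+Because the patch size, stride, and feature count are fixed when the graph is built, the operation stays fully static, which is what makes this formulation XLA- and TPU-friendly.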
+
+### Key Architecture Differences from PyTorch
+
+#### 1. Gradient Reversal Layer (GRL)
+```python
+@tf.custom_gradient
+def gradient_reverse(x, lambd=1.0):
+    """Gradient Reversal Layer for TensorFlow"""
+    def grad(dy):
+        return -lambd * dy  # Only return the gradient w.r.t. x
+    return tf.identity(x), grad
+```
+
+#### 2. CTC Loss Implementation
+Custom sparse tensor conversion for TPU compatibility:
+```python
+def dense_to_sparse(dense_tensor, sequence_lengths):
+    """Convert a dense label tensor to a sparse tensor for CTC"""
+    mask = tf.not_equal(dense_tensor, 0)
+    indices = tf.where(mask)
+    values = tf.gather_nd(dense_tensor, indices)
+    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
+    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
+```
+
+#### 3. Day-Specific Layers
+Using `add_weight()` for TPU-compatible variable management:
+```python
+for i in range(n_days):
+    weight = self.add_weight(
+        name=f'day_weight_{i}',
+        shape=(neural_dim, neural_dim),
+        initializer=tf.keras.initializers.Identity(),
+        trainable=True
+    )
+```
+
+### Training on TPU v5e-8
+
+#### Basic Training Command
+```bash
+# In a Kaggle TPU v5e-8 environment
+python train_model_tf.py
+```
+
+#### Expected Output
+```
+🔍 Detecting TPU environment...
+📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
+✅ TPU initialized successfully!
+🎉 Number of TPU cores: 8
+Training on 8 TPU cores  # Should report 8 cores, not 1
+```
+
+### Performance Benefits
+
+1. **Multi-Core Utilization**: A properly configured TPU strategy uses all 8 TPU v5e-8 cores
+2. **Mixed Precision**: bfloat16 precision optimized for TPU matrix units
+3. **XLA Compilation**: Static operations enable efficient XLA graph compilation
+4. **Memory Efficiency**: Optimized for TPU memory constraints and batch processing
+
+### Common Issues and Solutions
+
+#### Issue: "Training on 1 TPU cores" instead of 8
+**Cause**: The TPU connection fell back to the default strategy
+**Solution**: Use the enhanced `create_tpu_strategy()` function with environment detection
+**Check**: Verify that the TPU environment variables are properly set
+
+#### Issue: CTC loss dtype errors
+**Cause**: Mixed precision dtype mismatches
+**Solution**: Explicit dtype casting in `CTCLoss.call()`
+
+#### Issue: Gradient Reversal Layer errors
+**Cause**: Incorrect gradient return format
+**Solution**: Return only the gradient w.r.t. the input tensor, not the lambda parameter
+
 ## Competition Context
-This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.
\ No newline at end of file
+This codebase serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.
\ No newline at end of file
diff --git a/model_training_nnn_tpu/rnn_model_tf.py b/model_training_nnn_tpu/rnn_model_tf.py
index 273f4d9..7f29924 100644
--- a/model_training_nnn_tpu/rnn_model_tf.py
+++ b/model_training_nnn_tpu/rnn_model_tf.py
@@ -763,6 +763,14 @@ def create_tpu_strategy():
 
     print("🔍 Detecting TPU environment...")
 
+    # Disable GPU to avoid CUDA conflicts in TPU environment
+    try:
+        print("🚫 Disabling GPU to prevent CUDA conflicts...")
+        tf.config.set_visible_devices([], 'GPU')
+        print("✅ GPU disabled successfully")
+    except Exception as e:
+        print(f"⚠️ Warning: Could not disable GPU: {e}")
+
     # Check for various TPU environment variables
     tpu_address = None
     tpu_name = None