fixed: tf call cuda
CLAUDE.md
The deprecated APIs still work but generate warnings. For production code:

- Test thoroughly as synchronization behavior may differ slightly
- Legacy code will continue to function until removed in future versions

## TensorFlow TPU Implementation

The original PyTorch implementation has been converted to TensorFlow for optimal performance on TPU v5e-8 environments, particularly for the Brain-to-Text '25 Competition on Kaggle.

### Key TensorFlow Components (`model_training_nnn_tpu/`)

#### Core Files

- **`rnn_model_tf.py`**: TensorFlow implementation of the TripleGRUDecoder architecture
  - `NoiseModel`: 2-layer GRU for noise estimation with day-specific layers
  - `CleanSpeechModel`: 3-layer GRU for clean speech recognition with day-specific layers
  - `NoisySpeechModel`: 2-layer GRU for noisy speech recognition (no day layers)
  - `TripleGRUDecoder`: Main adversarial architecture combining all three models
  - `CTCLoss`: Custom CTC loss implementation for TPU compatibility
  - `create_tpu_strategy()`: Enhanced TPU connection function with robust environment detection
- **`trainer_tf.py`**: TensorFlow training pipeline with distributed TPU support
- **`dataset_tf.py`**: TensorFlow data loading with an augmentation pipeline optimized for TPU
- **`train_model_tf.py`**: Main training script entry point
- **`evaluate_model_tf.py`**: Evaluation pipeline for model performance analysis

### TPU v5e-8 Specific Optimizations

#### 1. Enhanced TPU Connection

The `create_tpu_strategy()` function provides robust TPU detection across different environments:

```python
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    # Multi-environment TPU detection
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address

    # Enhanced error handling and debugging output
    # Fallback to the default strategy if the TPU connection fails
    ...
```

**Environment Variables Detected**:

- `COLAB_TPU_ADDR`: Google Colab TPU environment
- `TPU_NAME`: Generic TPU name specification
- `TPU_WORKER_ID`: Kaggle TPU environment indicator

**Troubleshooting TPU Connection Issues**:

- Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
- Solution: The function automatically detects and uses the appropriate TPU address for the environment, falling back to the default strategy if the connection fails (see the sketch below)
- Debugging: All TPU-related environment variables are printed during initialization

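A minimal sketch of the connect-then-fall-back flow described above, using standard TensorFlow APIs; the helper name `connect_or_fallback` and the bare `tpu_address` argument are illustrative, not the exact code in `rnn_model_tf.py`:

```python
import tensorflow as tf

def connect_or_fallback(tpu_address=None):
    """Illustrative: try to initialize a TPU; fall back to the default strategy."""
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
        print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")
        return strategy
    except Exception as e:
        print(f"TPU initialization failed ({e}); falling back to the default strategy")
        return tf.distribute.get_strategy()
```
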
#### 2. Mixed Precision Training

Configured for optimal TPU v5e-8 performance:

```python
def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
```

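As a quick check (generic Keras behavior, not specific to this repo), the global policy can be inspected after calling `configure_mixed_precision()`; under `mixed_bfloat16`, computations run in bfloat16 while variables stay in float32:

```python
from tensorflow import keras

configure_mixed_precision()
policy = keras.mixed_precision.global_policy()
print(policy.compute_dtype)   # 'bfloat16' -> matmuls hit the TPU matrix units
print(policy.variable_dtype)  # 'float32'  -> weights keep full precision
```
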
#### 3. XLA-Optimized Operations

- **Static Tensor Operations**: Using `tf.stack()` and `tf.gather()` instead of dynamic indexing
- **Efficient Matrix Operations**: `tf.linalg.matmul()` for batch matrix multiplication
- **TPU-Friendly GRU Layers**: Disabled recurrent dropout for better TPU performance
- **Patch Processing**: TensorFlow equivalent of PyTorch's unfold using `tf.image.extract_patches()` (see the sketch below)

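A self-contained sketch of the patch-extraction idea; the tensor sizes, window length, and stride are illustrative only and not necessarily the values used in `dataset_tf.py`:

```python
import tensorflow as tf

# Illustrative input: [batch, time, features] neural features
x = tf.random.normal([4, 100, 512])
patch_size, patch_stride = 14, 4  # assumed window/stride, analogous to x.unfold(1, 14, 4) in PyTorch

# Treat time as the row axis of an image so extract_patches can slide a window over it
x_4d = x[:, :, :, tf.newaxis]                       # [batch, time, features, 1]
patches = tf.image.extract_patches(
    images=x_4d,
    sizes=[1, patch_size, 1, 1],
    strides=[1, patch_stride, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID',
)                                                   # [batch, n_patches, features, patch_size]

# Rearrange to [batch, n_patches, patch_size * features], matching an unfold-then-flatten layout
patches = tf.transpose(patches, [0, 1, 3, 2])       # [batch, n_patches, patch_size, features]
patches = tf.reshape(patches, [x.shape[0], -1, patch_size * x.shape[-1]])
```
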
### Key Architecture Differences from PyTorch

#### 1. Gradient Reversal Layer (GRL)

```python
@tf.custom_gradient
def gradient_reverse(x, lambd=1.0):
    """Gradient Reversal Layer for TensorFlow"""
    def grad(dy):
        return -lambd * dy  # Only return gradient w.r.t. x
    return tf.identity(x), grad
```

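A quick sanity check of the reversal behavior (a hypothetical test snippet, not part of the repo): the forward pass is the identity, while the gradient comes back negated and scaled by `lambd`:

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(gradient_reverse(x, 0.5))
print(tape.gradient(y, x))  # expected: [-0.5 -0.5 -0.5]
```
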
#### 2. CTC Loss Implementation

Custom sparse tensor conversion for TPU compatibility:

```python
def dense_to_sparse(dense_tensor, sequence_lengths):
    """Convert a dense, zero-padded label tensor to a sparse tensor for CTC"""
    mask = tf.not_equal(dense_tensor, 0)
    indices = tf.where(mask)
    values = tf.gather_nd(dense_tensor, indices)
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
```

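For context, a sketch of how such sparse labels would typically be fed to `tf.nn.ctc_loss`; the shapes, the zero-padding convention, and `blank_index=0` are assumptions for illustration, not necessarily what `CTCLoss` in `rnn_model_tf.py` does:

```python
import tensorflow as tf

batch, max_time, max_label_len, n_classes = 2, 50, 12, 41   # illustrative sizes
logits = tf.random.normal([batch, max_time, n_classes])
labels = tf.random.uniform([batch, max_label_len], minval=1, maxval=n_classes, dtype=tf.int32)

sparse_labels = dense_to_sparse(labels, None)                # zero-padded labels -> SparseTensor
loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,
    label_length=None,                                       # not needed with sparse labels
    logit_length=tf.fill([batch], max_time),
    logits_time_major=False,
    blank_index=0,                                           # assumes class 0 is the CTC blank
)
print(loss.shape)  # (2,) -> one loss value per sequence
```
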
#### 3. Day-Specific Layers

Using `add_weight()` for TPU-compatible variable management:

```python
for i in range(n_days):
    weight = self.add_weight(
        name=f'day_weight_{i}',
        shape=(neural_dim, neural_dim),
        initializer=tf.keras.initializers.Identity(),
        trainable=True
    )
```

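A minimal sketch of applying such per-day transforms in an XLA-friendly way, stacking the weights once and selecting with `tf.gather` instead of Python indexing; the names `day_weights` and `day_idx` and the sizes are illustrative, not the exact ones in `rnn_model_tf.py`:

```python
import tensorflow as tf

n_days, neural_dim = 4, 8
# One (neural_dim x neural_dim) transform per recording day, identity-initialized as above
day_weights = tf.stack([tf.eye(neural_dim) for _ in range(n_days)])  # [n_days, D, D]

x = tf.random.normal([3, 20, neural_dim])   # [batch, time, D] neural features
day_idx = tf.constant([0, 2, 1])            # which day each batch element was recorded on

w = tf.gather(day_weights, day_idx)         # [batch, D, D], no dynamic Python indexing
x_transformed = tf.linalg.matmul(x, w)      # batched matmul: one transform per sample
```
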
### Training on TPU v5e-8

#### Basic Training Command

```bash
# In Kaggle TPU v5e-8 environment
python train_model_tf.py
```

#### Expected Output

```
🔍 Detecting TPU environment...
📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
✅ TPU initialized successfully!
🎉 Number of TPU cores: 8
Training on 8 TPU cores  # Should show 8 cores, not 1
```

### Performance Benefits

1. **Multi-Core Utilization**: Properly configured TPU strategy utilizes all 8 TPU v5e-8 cores
2. **Mixed Precision**: bfloat16 precision optimized for TPU matrix units
3. **XLA Compilation**: Static operations enable efficient XLA graph compilation
4. **Memory Efficiency**: Optimized for TPU memory constraints and batch processing

### Common Issues and Solutions

#### Issue: "Training on 1 TPU cores" instead of 8

**Cause**: The TPU connection fell back to the default strategy
**Solution**: Use the enhanced `create_tpu_strategy()` function with environment detection
**Check**: Verify that the TPU environment variables are properly set and that the strategy reports 8 replicas (see the snippet below)

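A one-line check of the active strategy's replica count (standard TensorFlow API, assuming `create_tpu_strategy()` has already been imported or defined):

```python
strategy = create_tpu_strategy()
print(f"Training on {strategy.num_replicas_in_sync} TPU cores")  # expect 8 on TPU v5e-8, not 1
```
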
#### Issue: CTC Loss dtype errors

**Cause**: Mixed precision dtype mismatches
**Solution**: Explicit dtype casting in `CTCLoss.call()` (see the sketch below)

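A minimal sketch of the casting idea, assuming a `tf.keras.losses.Loss` subclass with dense, zero-padded labels; the actual `CTCLoss.call()` signature and label handling in `rnn_model_tf.py` may differ:

```python
import tensorflow as tf

class CTCLoss(tf.keras.losses.Loss):
    """Illustrative only: cast inputs explicitly before the CTC computation."""

    def call(self, y_true, y_pred):
        # Under a mixed_bfloat16 policy the logits may arrive as bfloat16;
        # cast back to float32 before the numerically sensitive CTC op.
        logits = tf.cast(y_pred, tf.float32)                          # [batch, time, n_classes]
        labels = tf.cast(y_true, tf.int32)                            # [batch, max_label_len], 0 = padding
        logit_length = tf.fill([tf.shape(logits)[0]], tf.shape(logits)[1])
        label_length = tf.reduce_sum(tf.cast(labels != 0, tf.int32), axis=-1)
        return tf.nn.ctc_loss(
            labels=labels,
            logits=logits,
            label_length=label_length,
            logit_length=logit_length,
            logits_time_major=False,
            blank_index=0,
        )
```
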
#### Issue: Gradient Reversal Layer errors

**Cause**: Incorrect gradient return format
**Solution**: Return only the gradient w.r.t. the input tensor, not the lambda parameter

## Competition Context

This codebase serves as a baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.

```diff
@@ -763,6 +763,14 @@ def create_tpu_strategy():
     print("🔍 Detecting TPU environment...")
 
+    # Disable GPU to avoid CUDA conflicts in TPU environment
+    try:
+        print("🚫 Disabling GPU to prevent CUDA conflicts...")
+        tf.config.set_visible_devices([], 'GPU')
+        print("✅ GPU disabled successfully")
+    except Exception as e:
+        print(f"⚠️ Warning: Could not disable GPU: {e}")
+
     # Check for various TPU environment variables
     tpu_address = None
     tpu_name = None
```

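A quick, generic way to confirm the effect of this change (not part of the commit): after hiding the GPUs, TensorFlow should report no visible CUDA devices:

```python
import tensorflow as tf

tf.config.set_visible_devices([], 'GPU')
print(tf.config.get_visible_devices('GPU'))  # [] -> TensorFlow will not initialize CUDA
```
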