This commit is contained in:
Zchen
2025-10-14 23:35:42 +08:00
parent cd52ba51ba
commit 4b6d680283
2 changed files with 118 additions and 23 deletions

113
CLAUDE.md
View File

@@ -165,19 +165,24 @@ x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency
``` ```
#### 5. Mixed Precision Dtype Consistency #### 5. Mixed Precision Dtype Consistency (Comprehensive Fix)
**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations **Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline
**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations **Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`
**Root Cause Analysis**: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers:
1. Initial bmm operations in day-specific transformations
2. Adversarial training residual connections between models
3. Patch processing operations (unfold, permute, reshape)
4. Gradient Reversal Layer (GRL) operations
5. Hidden state initialization in adversarial training helper methods
**Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow:
```python ```python
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training # Fix 1: Basic bmm operations with dtype consistency
# Fix 1: Add dtype conversions for all bmm operands
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
# Fix 2: Ensure dtype consistency in adversarial training residual connections # Fix 2: Patch processing with explicit dtype preservation
denoised_input = x_processed - noise_output.to(x_processed.dtype)
# Fix 3: Preserve dtype through patch processing operations
if self.patch_size > 0: if self.patch_size > 0:
original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility original_dtype = x.dtype # Preserve original dtype for XLA/TPU compatibility
x = x.unsqueeze(1) x = x.unsqueeze(1)
@@ -188,8 +193,37 @@ if self.patch_size > 0:
x = x_unfold.reshape(batch_size, x_unfold.size(1), -1) x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
# Ensure dtype consistency after patch processing operations # Ensure dtype consistency after patch processing operations
x = x.to(original_dtype) x = x.to(original_dtype)
# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output
# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (preserves input dtype but explicit check)
noisy_input = noisy_input.to(x_processed.dtype)
# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# In _noisy_forward_with_processed_input:
if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
``` ```
**Key Implementation Details**:
- **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward)
- **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision
- **Residual Connections**: All tensor arithmetic operations ensure matching dtypes
- **Helper Methods**: Hidden state initialization matches processed input dtype
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
#### 3. Hidden State Initialization #### 3. Hidden State Initialization
**Problem**: Dynamic batch size allocation causes XLA recompilation **Problem**: Dynamic batch size allocation causes XLA recompilation
**Solution**: Use static shapes and avoid x.shape[0] in tensor creation **Solution**: Use static shapes and avoid x.shape[0] in tensor creation
@@ -223,12 +257,25 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return
### Files Modified for XLA Optimization ### Files Modified for XLA Optimization
- **`model_training_nnn/rnn_model.py`**: All three models optimized - **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
- `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency - **`GradientReversalFn`**: Added adversarial training gradient reversal layer
- `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency - **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
- `NoisySpeechModel.forward()`: Hidden state optimization - **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
- `TripleGRUDecoder.forward()`: Complex return values → tuple returns + adversarial residual connection dtype fix - **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
- `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency - **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
- **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation
- **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
- **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
**Specific Dtype Consistency Fixes Applied**:
1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions
4. **Gradient Reversal**: Dtype consistency after GRL operations
5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions
6. **Data Flow**: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline
**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training
### Benefits of XLA Optimizations ### Benefits of XLA Optimizations
@@ -252,5 +299,41 @@ Created test scripts to verify model consistency:
- Backward compatibility with existing training scripts is maintained - Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency - TPU training should now show improved compilation times and memory efficiency
### Troubleshooting Dtype Issues in Mixed Precision Training
**Common Error Pattern**:
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
```
**Diagnosis Steps**:
1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing
- `7168 = 512 * 14`: Patch processing operation with patch_size=14
- `512`: Basic neural features
- Other patterns may indicate different operations
2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline
- Input → NoiseModel → residual connection → CleanSpeechModel
- Input → NoiseModel → GRL → NoisySpeechModel
3. **Verify Dtype Consistency**: Ensure all operations maintain input dtype
- Use `.to(x.dtype)` for all operand tensors
- Preserve dtype through complex operations (unfold, permute, reshape)
- Match hidden state dtype to input tensor dtype
**Quick Fix Template**:
```python
# For any tensor operation between tensors a and b:
result = operation(a, b.to(a.dtype))
# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)
# For hidden state initialization:
states = states.to(input_tensor.dtype)
```
## Competition Context ## Competition Context
This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding. This codebase also serves as baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.

View File

@@ -407,9 +407,11 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)''' '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
batch_size = x_processed.size(0) batch_size = x_processed.size(0)
# XLA-friendly hidden state initialization # XLA-friendly hidden state initialization with dtype consistency
if states is None: if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous() states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# GRU forward pass (skip preprocessing since input is already processed) # GRU forward pass (skip preprocessing since input is already processed)
output, hidden_states = self.clean_speech_model.gru(x_processed, states) output, hidden_states = self.clean_speech_model.gru(x_processed, states)
@@ -422,9 +424,11 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for NoisySpeechModel with already processed input''' '''Forward pass for NoisySpeechModel with already processed input'''
batch_size = x_processed.size(0) batch_size = x_processed.size(0)
# XLA-friendly hidden state initialization # XLA-friendly hidden state initialization with dtype consistency
if states is None: if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous() states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway) # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
output, hidden_states = self.noisy_speech_model.gru(x_processed, states) output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -455,9 +459,11 @@ class TripleGRUDecoder(nn.Module):
# Apply the same preprocessing that the models use internally # Apply the same preprocessing that the models use internally
x_processed = self._apply_preprocessing(x, day_idx) x_processed = self._apply_preprocessing(x, day_idx)
# Ensure dtype consistency between processed input and noise output
noise_output = noise_output.to(x_processed.dtype)
# 3. Clean speech model processes denoised signal # 3. Clean speech model processes denoised signal
# Ensure dtype consistency for mixed precision training in residual connection denoised_input = x_processed - noise_output # Residual connection in processed space
denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space
# Clean speech model will apply its own preprocessing, so we pass the denoised processed data # Clean speech model will apply its own preprocessing, so we pass the denoised processed data
# But we need to reverse the preprocessing first, then let clean model do its own # But we need to reverse the preprocessing first, then let clean model do its own
# Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing # Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
@@ -467,6 +473,9 @@ class TripleGRUDecoder(nn.Module):
# 4. Noisy speech model processes noise signal directly (no day layers needed) # 4. Noisy speech model processes noise signal directly (no day layers needed)
# Optionally apply Gradient Reversal to enforce adversarial training on noise output # Optionally apply Gradient Reversal to enforce adversarial training on noise output
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
# Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility
# Use x_processed.dtype as reference since it's the main data flow dtype
noisy_input = noisy_input.to(x_processed.dtype)
noisy_logits = self._noisy_forward_with_processed_input(noisy_input, noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
states['noisy'] if states else None) states['noisy'] if states else None)
@@ -485,9 +494,9 @@ class TripleGRUDecoder(nn.Module):
# 2. For residual connection, we need x in the same space as noise_output # 2. For residual connection, we need x in the same space as noise_output
x_processed = self._apply_preprocessing(x, day_idx) x_processed = self._apply_preprocessing(x, day_idx)
# 3. Process denoised signal # Ensure dtype consistency for mixed precision residual connection
# Ensure dtype consistency for mixed precision training in residual connection noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output.to(x_processed.dtype) denoised_input = x_processed - noise_output
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None) states['clean'] if states else None)
@@ -505,7 +514,10 @@ class TripleGRUDecoder(nn.Module):
clean_grad (tensor) - gradients from clean speech model output layer clean_grad (tensor) - gradients from clean speech model output layer
noisy_grad (tensor) - gradients from noisy speech model output layer noisy_grad (tensor) - gradients from noisy speech model output layer
learning_rate (float) - learning rate for gradient update if grl_lambda and grl_lambda != 0.0:
noisy_input = gradient_reverse(noise_output, grl_lambda)
else:
noisy_input = noise_output
''' '''
# Combine gradients: negative from clean model, positive from noisy model # Combine gradients: negative from clean model, positive from noisy model
combined_grad = -clean_grad + noisy_grad combined_grad = -clean_grad + noisy_grad