diff --git a/TPU_ISSUES_RECORD.md b/TPU_ISSUES_RECORD.md
index 992ac7b..248d333 100644
--- a/TPU_ISSUES_RECORD.md
+++ b/TPU_ISSUES_RECORD.md
@@ -350,6 +350,42 @@ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / sel
 
 **Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations.
 
+## New Issue: Features Tensor DType Mismatch (2025-10-12 17:00)
+
+### Error Description
+```
+Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168].
+```
+
+### Root Cause Analysis
+After fixing the `adjusted_lens` dtype issue, a new mismatch surfaced in the `features` tensor of shape `[32, 7168]` (batch_size=32, neural_dim × patch_size = 512 × 14 = 7168). Under `accelerator.autocast()` with `bf16` mixed precision, input tensors are automatically cast to bfloat16, but after the hardcoded dtype specifications were removed the model parameters stayed in float32, producing a mismatch at the model input.
+
+### Problem Code
+```python
+# Inside accelerator.autocast() context:
+# features becomes bf16 automatically by autocast
+logits = self.model(features, day_indicies, None, False, 'inference')
+# Model expects f32 parameters but receives bf16 input → mismatch
+```
+
+### Solution
+Add an explicit dtype conversion before every model call so the input dtype stays consistent with the model parameters:
+
+```python
+# Ensure features tensor matches model parameter dtype for TPU compatibility
+if self.accelerator.mixed_precision == 'bf16':
+    # In mixed precision mode, ensure features match the expected precision
+    features = features.to(torch.float32)
+```
+
+### Fixed Locations
+- `rnn_trainer.py:582-584` - Training loop model call
+- `rnn_trainer.py:760-763` - Validation loop model call
+- `rnn_trainer.py:839-842` - Inference method model call
+- `rnn_trainer.py:863-866` - Inference batch method model call
+
+**Key Insight**: Mixed precision autocast converts inputs but not necessarily model parameters. When hardcoded dtypes are removed, an explicit conversion keeps the autocast inputs compatible with the model parameters.
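+
+As a cross-check, here is a minimal standalone sketch of the same pattern outside the trainer. The `Accelerator` setup, the `nn.Linear(7168, 256)` layer, and the random input are illustrative assumptions rather than project code; only the cast-before-forward pattern mirrors the fix above.
+
+```python
+import torch
+import torch.nn as nn
+from accelerate import Accelerator
+
+accelerator = Accelerator(mixed_precision='bf16')
+model = accelerator.prepare(nn.Linear(7168, 256))   # hypothetical layer; parameters stay float32
+features = torch.randn(32, 7168)                    # (batch_size, neural_dim * patch_size)
+
+with accelerator.autocast():
+    # Per the record above, on TPU/XLA the bf16 autocast input ended up mismatched
+    # against float32 parameters (INVALID_ARGUMENT), so the input is aligned first.
+    if accelerator.mixed_precision == 'bf16':
+        features = features.to(torch.float32)        # align input dtype with the f32 parameters
+    logits = model(features)
+```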
+
 ## Lessons Learned
 - **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
 - **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 82e3476..c41cb89 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -578,6 +578,11 @@ class BrainToTextDecoder_Trainer:
 
             # Get phoneme predictions using inference mode during training
             # (We use inference mode for simplicity - only clean logits are used for CTC loss)
+            # Ensure features tensor matches model parameter dtype for TPU compatibility
+            if self.accelerator.mixed_precision == 'bf16':
+                # In mixed precision mode, ensure features match the expected precision
+                features = features.to(torch.float32)
+
             logits = self.model(features, day_indicies, None, False, 'inference')
 
             # Calculate CTC Loss
@@ -752,6 +757,11 @@
                 # Ensure proper dtype handling for TPU mixed precision
                 adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
+                # Ensure features tensor matches model parameter dtype for TPU compatibility
+                if self.accelerator.mixed_precision == 'bf16':
+                    # In mixed precision mode, ensure features match the expected precision
+                    features = features.to(torch.float32)
+
                 logits = self.model(features, day_indicies, None, False, 'inference')
 
                 loss = self.ctc_loss(
@@ -826,6 +836,11 @@
             # Apply data transformations (no augmentation for inference)
             features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
+            # Ensure features tensor matches model parameter dtype for TPU compatibility
+            if self.accelerator.mixed_precision == 'bf16':
+                # In mixed precision mode, ensure features match the expected precision
+                features = features.to(torch.float32)
+
             # Get phoneme predictions
             logits = self.model(features, day_indicies, None, False, mode)
@@ -850,6 +865,11 @@
                 # Calculate adjusted sequence lengths for CTC with proper dtype handling
                 adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
+                # Ensure features tensor matches model parameter dtype for TPU compatibility
+                if self.accelerator.mixed_precision == 'bf16':
+                    # In mixed precision mode, ensure features match the expected precision
+                    features = features.to(torch.float32)
+
                 # Get phoneme predictions
                 logits = self.model(features, day_indicies, None, False, mode)
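
As a quick way to sanity-check the pattern applied in the hunks above, a hypothetical helper (not part of this patch) can compare the input dtype against the model's parameter dtype before each forward call, so a mismatch fails fast with a readable Python error instead of an XLA `INVALID_ARGUMENT` at compile time:

```python
import torch
import torch.nn as nn

def check_input_dtype(model: nn.Module, tensor: torch.Tensor, name: str = "features") -> None:
    """Raise a readable error if an input tensor's dtype differs from the model's parameters."""
    param_dtype = next(model.parameters()).dtype
    if tensor.dtype != param_dtype:
        raise TypeError(
            f"{name} dtype {tensor.dtype} does not match model parameter dtype {param_dtype}"
        )

# Hypothetical usage right before a forward call such as
#   logits = self.model(features, day_indicies, None, False, 'inference')
# would be:
#   check_input_dtype(self.model, features)
```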