tpu
@@ -350,6 +350,42 @@ adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / sel

**Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations.

## New Issue: Features Tensor DType Mismatch (2025-10-12 17:00)

### Error Description

```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168].
```

### Root Cause Analysis

After fixing the `adjusted_lens` dtype issue, a new mismatch emerged for the `features` tensor of shape `[32, 7168]` (batch_size = 32, neural_dim × patch_size = 512 × 14 = 7168). Under `accelerator.autocast()` with mixed precision `bf16`, input tensors are automatically cast to bfloat16, but the model parameters remained in float32 after the hardcoded dtype specifications were removed, producing a dtype mismatch at the model input.
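In eager PyTorch this split is silent, but the XLA compiler rejects it with the INVALID_ARGUMENT error above. A minimal CPU sketch of the autocast behavior (a hypothetical `Linear` layer standing in for the model, not the project's code):

```python
import torch

# Hypothetical layer: autocast downcasts the op's inputs/outputs to bfloat16
# while the layer's parameters stay float32.
linear = torch.nn.Linear(4, 4)   # parameters created in float32
x = torch.randn(2, 4)            # float32 input

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    y = linear(x)                # matmul runs in bfloat16 under autocast

print(linear.weight.dtype)  # torch.float32 - parameters untouched
print(y.dtype)              # torch.bfloat16 - activations downcast
```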

### Problem Code

```python
# Inside the accelerator.autocast() context:
# features has been cast to bf16 automatically by autocast
logits = self.model(features, day_indicies, None, False, 'inference')
# The model expects f32 parameters but receives bf16 input → mismatch
```

### Solution

Add an explicit dtype conversion before every model call to keep input and parameter dtypes consistent:

```python
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
    # In mixed precision mode, ensure features match the expected precision
    features = features.to(torch.float32)
```

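A quick sanity check of the idea (a stand-in f32 `Linear` model, not the project's RNN): a bf16 tensor, as autocast would produce, is upcast before the call so input and parameter dtypes agree.

```python
import torch

# Hypothetical stand-in model with f32 parameters
model = torch.nn.Linear(8, 4)
features = torch.randn(2, 8, dtype=torch.bfloat16)   # what autocast hands us

features = features.to(torch.float32)                # the fix: upcast first
assert features.dtype == next(model.parameters()).dtype
logits = model(features)
print(logits.dtype)  # torch.float32 - input and parameters now agree
```

Note that `.to(torch.float32)` is a no-op when the tensor is already float32, so the guard is safe outside mixed precision as well.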
### Fixed Locations

- `rnn_trainer.py:582-584` - Training loop model call
- `rnn_trainer.py:760-763` - Validation loop model call
- `rnn_trainer.py:839-842` - Inference method model call
- `rnn_trainer.py:863-866` - Inference batch method model call

**Key Insight**: Mixed precision autocast converts inputs but not necessarily model parameters. When hardcoded dtypes are removed, an explicit conversion ensures compatibility between autocast inputs and model parameters.

## Lessons Learned

- **Root Cause**: The TPU XLA compiler requires strict dtype consistency across all tensors
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - the dtype must be specified explicitly

@@ -578,6 +578,11 @@ class BrainToTextDecoder_Trainer:
 
         # Get phoneme predictions using inference mode during training
         # (We use inference mode for simplicity - only clean logits are used for CTC loss)
+        # Ensure features tensor matches model parameter dtype for TPU compatibility
+        if self.accelerator.mixed_precision == 'bf16':
+            # In mixed precision mode, ensure features match the expected precision
+            features = features.to(torch.float32)
+
         logits = self.model(features, day_indicies, None, False, 'inference')
 
         # Calculate CTC Loss
@@ -752,6 +757,11 @@ class BrainToTextDecoder_Trainer:
         # Ensure proper dtype handling for TPU mixed precision
         adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
+        # Ensure features tensor matches model parameter dtype for TPU compatibility
+        if self.accelerator.mixed_precision == 'bf16':
+            # In mixed precision mode, ensure features match the expected precision
+            features = features.to(torch.float32)
+
         logits = self.model(features, day_indicies, None, False, 'inference')
 
         loss = self.ctc_loss(
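The `adjusted_lens` expression in these hunks is the usual strided-patch count, (T − patch_size) / patch_stride + 1, truncated to int32 for CTC. A quick check with patch_size = 14 (from the 512 × 14 = 7168 feature layout described above) and an assumed, illustrative patch_stride of 4:

```python
import torch

# patch_size = 14 matches the 512 x 14 = 7168 feature layout described above;
# patch_stride = 4 is illustrative, not taken from the project config.
patch_size, patch_stride = 14, 4
n_time_steps = torch.tensor([100, 57, 14])

# (T - patch_size) / patch_stride + 1, truncated to int32: patches per sequence
adjusted_lens = ((n_time_steps.float() - patch_size) / patch_stride + 1).to(torch.int32)
print(adjusted_lens)  # tensor([22, 11, 1], dtype=torch.int32)
```

Starting from `n_time_steps.float()` keeps the division in floating point before the final cast, which is why the int32 conversion must be explicit at the end.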
@@ -826,6 +836,11 @@ class BrainToTextDecoder_Trainer:
         # Apply data transformations (no augmentation for inference)
         features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
+        # Ensure features tensor matches model parameter dtype for TPU compatibility
+        if self.accelerator.mixed_precision == 'bf16':
+            # In mixed precision mode, ensure features match the expected precision
+            features = features.to(torch.float32)
+
         # Get phoneme predictions
         logits = self.model(features, day_indicies, None, False, mode)
 
@@ -850,6 +865,11 @@ class BrainToTextDecoder_Trainer:
         # Calculate adjusted sequence lengths for CTC with proper dtype handling
         adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
 
+        # Ensure features tensor matches model parameter dtype for TPU compatibility
+        if self.accelerator.mixed_precision == 'bf16':
+            # In mixed precision mode, ensure features match the expected precision
+            features = features.to(torch.float32)
+
         # Get phoneme predictions
         logits = self.model(features, day_indicies, None, False, mode)
 