tpu
This commit is contained in:
@@ -319,6 +319,37 @@ if xm.get_xla_supported_devices():
|
|||||||
|
|
||||||
**预期改进**: XLA图编译时间从5-15分钟缩短到2-8分钟
|
**预期改进**: XLA图编译时间从5-15分钟缩短到2-8分钟
|
||||||
|
|
||||||
|
## New Issue: DType Mismatch in adjusted_lens Calculation (2025-10-12 16:45)
|
||||||
|
|
||||||
|
### Error Description
|
||||||
|
```
|
||||||
|
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 1 shape: f32[21504], argument shape: bf16[21504].
|
||||||
|
```
|
||||||
|
|
||||||
|
### Root Cause
|
||||||
|
The `adjusted_lens` calculation caused dtype mismatches during TPU mixed-precision (bf16) training: when `n_time_steps` is processed under `accelerator.autocast()` it becomes bfloat16, while the arithmetic operations on it produced float32 results, so the f32 and bf16 tensors no longer matched.
|
||||||
|
|
||||||
|
### Problem Code
|
||||||
|
```python
|
||||||
|
# Before (causes f32/bf16 mismatch):
|
||||||
|
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Solution
|
||||||
|
Convert `n_time_steps` to float explicitly (`.float()`) before the arithmetic, then cast the result to `torch.int32`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# After (explicit dtype control):
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fixed Locations
|
||||||
|
- `rnn_trainer.py:577` - Training loop
|
||||||
|
- `rnn_trainer.py:753` - Validation loop
|
||||||
|
- `rnn_trainer.py:851` - Inference batch function
|
||||||
|
|
||||||
|
**Key Insight**: Mixed precision training requires explicit dtype management for ALL tensor operations, even intermediate calculations.
|
||||||
|
|
||||||
## Lessons Learned
|
## Lessons Learned
|
||||||
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
|
- **Root Cause**: TPU XLA compiler requires strict dtype consistency across all tensors
|
||||||
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype
|
- **Key Insight**: `torch.eye()` and `torch.zeros()` default to f32 - must explicitly specify dtype
|
||||||
|
@@ -25,8 +25,9 @@ class NoiseModel(nn.Module):
|
|||||||
|
|
||||||
# Day-specific input layers
|
# Day-specific input layers
|
||||||
self.day_layer_activation = nn.Softsign()
|
self.day_layer_activation = nn.Softsign()
|
||||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
# Let Accelerator handle dtype automatically for TPU compatibility
|
||||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||||
self.day_layer_dropout = nn.Dropout(input_dropout)
|
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||||
|
|
||||||
# Calculate input size after patching
|
# Calculate input size after patching
|
||||||
@@ -51,8 +52,8 @@ class NoiseModel(nn.Module):
|
|||||||
if "weight_ih" in name:
|
if "weight_ih" in name:
|
||||||
nn.init.xavier_uniform_(param)
|
nn.init.xavier_uniform_(param)
|
||||||
|
|
||||||
# Learnable initial hidden state
|
# Learnable initial hidden state - let Accelerator handle dtype
|
||||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16)))
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
|
||||||
|
|
||||||
def forward(self, x, day_idx, states=None):
|
def forward(self, x, day_idx, states=None):
|
||||||
# Apply day-specific transformation
|
# Apply day-specific transformation
|
||||||
@@ -110,8 +111,9 @@ class CleanSpeechModel(nn.Module):
|
|||||||
|
|
||||||
# Day-specific input layers
|
# Day-specific input layers
|
||||||
self.day_layer_activation = nn.Softsign()
|
self.day_layer_activation = nn.Softsign()
|
||||||
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
# Let Accelerator handle dtype automatically for TPU compatibility
|
||||||
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
|
self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
|
||||||
|
self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
|
||||||
self.day_layer_dropout = nn.Dropout(input_dropout)
|
self.day_layer_dropout = nn.Dropout(input_dropout)
|
||||||
|
|
||||||
# Calculate input size after patching
|
# Calculate input size after patching
|
||||||
@@ -141,7 +143,7 @@ class CleanSpeechModel(nn.Module):
|
|||||||
nn.init.xavier_uniform_(self.out.weight)
|
nn.init.xavier_uniform_(self.out.weight)
|
||||||
|
|
||||||
# Learnable initial hidden state
|
# Learnable initial hidden state
|
||||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||||
|
|
||||||
def forward(self, x, day_idx, states=None, return_state=False):
|
def forward(self, x, day_idx, states=None, return_state=False):
|
||||||
# Apply day-specific transformation
|
# Apply day-specific transformation
|
||||||
@@ -229,7 +231,7 @@ class NoisySpeechModel(nn.Module):
|
|||||||
nn.init.xavier_uniform_(self.out.weight)
|
nn.init.xavier_uniform_(self.out.weight)
|
||||||
|
|
||||||
# Learnable initial hidden state
|
# Learnable initial hidden state
|
||||||
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))
|
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
|
||||||
|
|
||||||
def forward(self, x, states=None, return_state=False):
|
def forward(self, x, states=None, return_state=False):
|
||||||
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
|
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
|
||||||
|
@@ -573,7 +573,8 @@ class BrainToTextDecoder_Trainer:
|
|||||||
# Apply augmentations to the data
|
# Apply augmentations to the data
|
||||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
|
||||||
|
|
||||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
# Ensure proper dtype handling for TPU mixed precision
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
# Get phoneme predictions using inference mode during training
|
# Get phoneme predictions using inference mode during training
|
||||||
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
|
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
|
||||||
@@ -748,7 +749,8 @@ class BrainToTextDecoder_Trainer:
|
|||||||
with self.accelerator.autocast():
|
with self.accelerator.autocast():
|
||||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||||
|
|
||||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
# Ensure proper dtype handling for TPU mixed precision
|
||||||
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||||
|
|
||||||
@@ -845,8 +847,8 @@ class BrainToTextDecoder_Trainer:
|
|||||||
# Apply data transformations (no augmentation for inference)
|
# Apply data transformations (no augmentation for inference)
|
||||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||||
|
|
||||||
# Calculate adjusted sequence lengths for CTC
|
# Calculate adjusted sequence lengths for CTC with proper dtype handling
|
||||||
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||||
|
|
||||||
# Get phoneme predictions
|
# Get phoneme predictions
|
||||||
logits = self.model(features, day_indicies, None, False, mode)
|
logits = self.model(features, day_indicies, None, False, mode)
|
||||||
|
Reference in New Issue
Block a user