This commit is contained in:
Zchen
2025-10-12 22:59:45 +08:00
parent 5c941d9efa
commit 6cfc568f9a
2 changed files with 56 additions and 0 deletions

View File

@@ -578,6 +578,11 @@ class BrainToTextDecoder_Trainer:
# Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
logits = self.model(features, day_indicies, None, False, 'inference')
# Calculate CTC Loss
@@ -752,6 +757,11 @@ class BrainToTextDecoder_Trainer:
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
logits = self.model(features, day_indicies, None, False, 'inference')
loss = self.ctc_loss(
@@ -826,6 +836,11 @@ class BrainToTextDecoder_Trainer:
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
@@ -850,6 +865,11 @@ class BrainToTextDecoder_Trainer:
# Calculate adjusted sequence lengths for CTC with proper dtype handling
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)