tpu
This commit is contained in:
@@ -578,6 +578,11 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
# Get phoneme predictions using inference mode during training
|
||||
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
|
||||
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||
if self.accelerator.mixed_precision == 'bf16':
|
||||
# Under bf16 mixed precision, explicitly cast the input features to float32 before the forward pass
|
||||
features = features.to(torch.float32)
|
||||
|
||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||
|
||||
# Calculate CTC Loss
|
||||
@@ -752,6 +757,11 @@ class BrainToTextDecoder_Trainer:
|
||||
# Ensure proper dtype handling for TPU mixed precision
|
||||
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
|
||||
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||
if self.accelerator.mixed_precision == 'bf16':
|
||||
# Under bf16 mixed precision, explicitly cast the input features to float32 before the forward pass
|
||||
features = features.to(torch.float32)
|
||||
|
||||
logits = self.model(features, day_indicies, None, False, 'inference')
|
||||
|
||||
loss = self.ctc_loss(
|
||||
@@ -826,6 +836,11 @@ class BrainToTextDecoder_Trainer:
|
||||
# Apply data transformations (no augmentation for inference)
|
||||
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
|
||||
|
||||
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||
if self.accelerator.mixed_precision == 'bf16':
|
||||
# Under bf16 mixed precision, explicitly cast the input features to float32 before the forward pass
|
||||
features = features.to(torch.float32)
|
||||
|
||||
# Get phoneme predictions
|
||||
logits = self.model(features, day_indicies, None, False, mode)
|
||||
|
||||
@@ -850,6 +865,11 @@ class BrainToTextDecoder_Trainer:
|
||||
# Calculate adjusted sequence lengths for CTC with proper dtype handling
|
||||
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
|
||||
|
||||
# Ensure features tensor matches model parameter dtype for TPU compatibility
|
||||
if self.accelerator.mixed_precision == 'bf16':
|
||||
# Under bf16 mixed precision, explicitly cast the input features to float32 before the forward pass
|
||||
features = features.to(torch.float32)
|
||||
|
||||
# Get phoneme predictions
|
||||
logits = self.model(features, day_indicies, None, False, mode)
|
||||
|
||||
|
Reference in New Issue
Block a user