tpu without bf16

commit 9025267400
parent 603bb12220
Author: Zchen
Date:   2025-10-15 00:25:39 +08:00


@@ -11,6 +11,7 @@ import logging
 import sys
 import json
 import pickle
+from contextlib import nullcontext

 from dataset import BrainToTextDataset, train_test_split_indicies
 from data_augmentations import gauss_smooth
@@ -302,6 +303,12 @@ class BrainToTextDecoder_Trainer:
         if self.adv_enabled:
             self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")

+    def autocast_context(self):
+        """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
+        if self.device.type == 'xla':
+            return nullcontext()
+        return self.accelerator.autocast()
+
     def create_optimizer(self):
         '''
         Create the optimizer with special param groups
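For context outside the diff: the new helper swaps Accelerate's autocast for a no-op context manager whenever the active device is an XLA (TPU) device, leaving precision to the XLA backend instead of letting autocast cast activations to bf16 against fp32 parameters. Below is a minimal standalone sketch of the pattern; the MixedPrecisionRunner and forward_step names are illustrative and not from this repository, while Accelerator.autocast() and contextlib.nullcontext are real APIs.

from contextlib import nullcontext

import torch
from accelerate import Accelerator


class MixedPrecisionRunner:
    """Illustrative stand-in for the trainer's device handling."""

    def __init__(self, device: torch.device):
        self.device = device
        self.accelerator = Accelerator()

    def autocast_context(self):
        # On XLA/TPU, return a no-op context: autocast-style bf16 casting can
        # collide with fp32 parameters under XLA and raise dtype mismatches.
        if self.device.type == 'xla':
            return nullcontext()
        # Elsewhere (CUDA/CPU), defer to Accelerate's configured mixed precision.
        return self.accelerator.autocast()


def forward_step(runner, model, x):
    # Same call shape on every backend; only the active context differs.
    with runner.autocast_context():
        return model(x)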
@@ -578,7 +585,7 @@ class BrainToTextDecoder_Trainer:
             day_indicies = batch['day_indicies']

             # Use Accelerator's autocast (mixed precision handled by Accelerator init)
-            with self.accelerator.autocast():
+            with self.autocast_context():

                 # Apply augmentations to the data
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
@@ -792,7 +799,7 @@ class BrainToTextDecoder_Trainer:
         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                 # Ensure proper dtype handling for TPU mixed precision
@@ -873,7 +880,7 @@ class BrainToTextDecoder_Trainer:
         self.model.eval()

         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():

                 # Apply data transformations (no augmentation for inference)
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
@@ -899,7 +906,7 @@ class BrainToTextDecoder_Trainer:
             n_time_steps = batch['n_time_steps']

             with torch.no_grad():
-                with self.accelerator.autocast():
+                with self.autocast_context():

                     # Apply data transformations (no augmentation for inference)
                     features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
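A quick way to confirm the guard fires on TPU (this assumes a TPU host with torch_xla installed): torch_xla reports its device with type 'xla', which is exactly what autocast_context() tests.

import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(device.type)  # 'xla' -> autocast_context() returns nullcontext()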