diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 6cba21b..feea704 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -11,6 +11,7 @@ import logging
 import sys
 import json
 import pickle
+from contextlib import nullcontext
 
 from dataset import BrainToTextDataset, train_test_split_indicies
 from data_augmentations import gauss_smooth
@@ -302,6 +303,12 @@ class BrainToTextDecoder_Trainer:
         if self.adv_enabled:
             self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
 
+    def autocast_context(self):
+        """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
+        if self.device.type == 'xla':
+            return nullcontext()
+        return self.accelerator.autocast()
+
     def create_optimizer(self):
         '''
         Create the optimizer with special param groups
@@ -578,7 +585,7 @@ class BrainToTextDecoder_Trainer:
         day_indicies = batch['day_indicies']
 
         # Use Accelerator's autocast (mixed precision handled by Accelerator init)
-        with self.accelerator.autocast():
+        with self.autocast_context():
            # Apply augmentations to the data
            features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
 
@@ -792,7 +799,7 @@ class BrainToTextDecoder_Trainer:
 
            with torch.no_grad():
-                with self.accelerator.autocast():
+                with self.autocast_context():
                    features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
                    # Ensure proper dtype handling for TPU mixed precision
@@ -873,7 +880,7 @@ class BrainToTextDecoder_Trainer:
        self.model.eval()
 
        with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                # Apply data transformations (no augmentation for inference)
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
@@ -899,7 +906,7 @@ class BrainToTextDecoder_Trainer:
            n_time_steps = batch['n_time_steps']
 
            with torch.no_grad():
-                with self.accelerator.autocast():
+                with self.autocast_context():
                    # Apply data transformations (no augmentation for inference)
                    features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
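For reference, the dispatch pattern this patch introduces can be exercised outside the trainer. The sketch below is a hypothetical free-function version (the explicit `device`/`accelerator` parameters are illustrative, not the trainer's actual signature), assuming only `torch` and the `accelerate` package the trainer already depends on:

```python
# Minimal sketch of the autocast dispatch, assuming accelerate is installed.
# `autocast_context` here is a free-function stand-in for the trainer method.
from contextlib import nullcontext

import torch
from accelerate import Accelerator


def autocast_context(device: torch.device, accelerator: Accelerator):
    """Skip Accelerate's autocast on XLA devices.

    torch_xla applies its own low-precision handling, so wrapping XLA ops
    in an extra autocast region can produce dtype mismatches; returning a
    nullcontext keeps the `with` call sites uniform across backends.
    """
    if device.type == 'xla':
        return nullcontext()  # no-op: let torch_xla manage precision
    return accelerator.autocast()


if __name__ == '__main__':
    accelerator = Accelerator()
    x = torch.randn(8, 8, device=accelerator.device)
    with autocast_context(accelerator.device, accelerator):
        y = x @ x  # autocast on CUDA/CPU; plain eager execution on XLA
    print(y.dtype)
```

Because `nullcontext()` and `accelerator.autocast()` are both context managers, every call site in the diff stays a one-line change: `with self.accelerator.autocast():` becomes `with self.autocast_context():` with no other control-flow edits.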