diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index feea704..07cc2a2 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -59,8 +59,13 @@ class BrainToTextDecoder_Trainer:
         )
 
         # Initialize Accelerator for TPU/multi-device support
+        use_xla = bool(xm.get_xla_supported_devices())
+        mixed_precision_mode = 'no'
+        if not use_xla and args.get('use_amp', True):
+            mixed_precision_mode = 'bf16'
+
         self.accelerator = Accelerator(
-            mixed_precision='bf16' if args.get('use_amp', True) else 'no',
+            mixed_precision=mixed_precision_mode,
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None,  # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
@@ -126,6 +131,8 @@ class BrainToTextDecoder_Trainer:
         self.logger.info(f'Accelerator state: {self.accelerator.state}')
         if self.accelerator.num_processes > 1:
             self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
+        if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True):
+            self.logger.info('AMP requested but disabled on XLA to avoid dtype mismatches; running in float32 on TPU.')
 
         # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
         if self.args['seed'] != -1: