commit 5dcbf28c96
parent 9025267400
Author: Zchen
Date:   2025-10-15 00:30:56 +08:00


@@ -59,8 +59,13 @@ class BrainToTextDecoder_Trainer:
             )
         # Initialize Accelerator for TPU/multi-device support
+        use_xla = bool(xm.get_xla_supported_devices())
+        mixed_precision_mode = 'no'
+        if not use_xla and args.get('use_amp', True):
+            mixed_precision_mode = 'bf16'
         self.accelerator = Accelerator(
-            mixed_precision='bf16' if args.get('use_amp', True) else 'no',
+            mixed_precision=mixed_precision_mode,
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None,  # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
@@ -126,6 +131,8 @@ class BrainToTextDecoder_Trainer:
         self.logger.info(f'Accelerator state: {self.accelerator.state}')
         if self.accelerator.num_processes > 1:
             self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
+        if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True):
+            self.logger.info('AMP requested but disabled on XLA to avoid dtype mismatches; running in float32 on TPU.')
         # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
         if self.args['seed'] != -1:
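
For context, here is a minimal, self-contained sketch of the guard this commit introduces: bf16 mixed precision stays on for ordinary CPU/GPU runs but is forced off whenever XLA devices are present, and the silent fallback is surfaced in a message. The try/except import and the standalone args dict are illustrative assumptions rather than the trainer's actual setup; the xm.get_xla_supported_devices() and Accelerator calls mirror the diff above.

    from accelerate import Accelerator

    try:
        import torch_xla.core.xla_model as xm
        use_xla = bool(xm.get_xla_supported_devices())  # non-empty list -> TPU/XLA present
    except ImportError:
        use_xla = False  # no torch_xla installed: ordinary CPU/GPU run

    args = {'use_amp': True, 'gradient_accumulation_steps': 1}  # illustrative config

    # Enable bf16 autocast only off-XLA; on XLA, run in float32 instead.
    mixed_precision_mode = 'no'
    if not use_xla and args.get('use_amp', True):
        mixed_precision_mode = 'bf16'

    accelerator = Accelerator(
        mixed_precision=mixed_precision_mode,
        gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    )

    # Make the fallback visible, as the commit's new log line does.
    if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True):
        print('AMP requested but disabled on XLA; running in float32 on TPU.')

The point of resolving the mode before constructing the Accelerator is that the decision is made once, up front, and the rest of the trainer never needs to special-case TPU; the log line then records that AMP was requested but dropped, so the float32 fallback on XLA is not silent.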