tpu: select mixed precision based on XLA availability (float32 on TPU, bf16 elsewhere)
@@ -59,8 +59,13 @@ class BrainToTextDecoder_Trainer:
         )
 
         # Initialize Accelerator for TPU/multi-device support
+        use_xla = bool(xm.get_xla_supported_devices())
+        mixed_precision_mode = 'no'
+        if not use_xla and args.get('use_amp', True):
+            mixed_precision_mode = 'bf16'
+
         self.accelerator = Accelerator(
-            mixed_precision='bf16' if args.get('use_amp', True) else 'no',
+            mixed_precision=mixed_precision_mode,
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None,  # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
@@ -126,6 +131,8 @@ class BrainToTextDecoder_Trainer:
         self.logger.info(f'Accelerator state: {self.accelerator.state}')
         if self.accelerator.num_processes > 1:
             self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
+        if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True):
+            self.logger.info('AMP requested but disabled on XLA to avoid dtype mismatches; running in float32 on TPU.')
 
         # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
         if self.args['seed'] != -1:
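For context, a minimal standalone sketch of the precision-selection logic this diff introduces. It assumes `accelerate` is installed and `torch_xla` may or may not be; the `args` dict shown is a placeholder, not the trainer's real config.

# Sketch only: detect XLA devices, then choose mixed precision for Accelerator.
from accelerate import Accelerator

try:
    import torch_xla.core.xla_model as xm
    use_xla = bool(xm.get_xla_supported_devices())  # non-empty list -> XLA/TPU present
except ImportError:
    use_xla = False  # torch_xla not installed, so no XLA devices

# Placeholder config standing in for the trainer's `args` dict.
args = {'use_amp': True, 'gradient_accumulation_steps': 1, 'output_dir': './output'}

# bf16 autocast only off-XLA; on TPU the model stays in float32 to avoid dtype mismatches.
mixed_precision_mode = 'bf16' if (not use_xla and args.get('use_amp', True)) else 'no'

accelerator = Accelerator(
    mixed_precision=mixed_precision_mode,
    gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    log_with=None,
    project_dir=args.get('output_dir', './output'),
)
print(accelerator.state)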