tpu
@@ -19,7 +19,7 @@ import torchaudio.functional as F # for edit distance
 from omegaconf import OmegaConf
 
 # Import Accelerate for TPU support
-from accelerate import Accelerator
+from accelerate import Accelerator, DataLoaderConfiguration
 from accelerate.utils import set_seed
 
 torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
@@ -40,12 +40,18 @@ class BrainToTextDecoder_Trainer:
         args : dictionary of training arguments
         '''
 
+        # Configure DataLoader behavior for TPU compatibility
+        dataloader_config = DataLoaderConfiguration(
+            even_batches=False  # Required for batch_size=None DataLoaders on TPU
+        )
+
         # Initialize Accelerator for TPU/multi-device support
         self.accelerator = Accelerator(
             mixed_precision='bf16' if args.get('use_amp', True) else 'no',
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None,  # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
+            dataloader_config=dataloader_config,
         )
 
 
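For context on how this configuration is consumed, here is a minimal, self-contained sketch of the standard Accelerate training-loop pattern the diff is setting up. The model, optimizer, and data are hypothetical stand-ins, not code from this commit; the prepare(), accumulate(), and backward() calls are the library's documented API.

import torch
from accelerate import Accelerator, DataLoaderConfiguration

# Same configuration as in the diff: even_batches=False stops Accelerate from
# padding/duplicating samples to equalize batches across processes, which is
# unsupported when a DataLoader is built with batch_size=None (batching done
# inside the dataset itself).
dataloader_config = DataLoaderConfiguration(even_batches=False)
accelerator = Accelerator(mixed_precision='bf16', dataloader_config=dataloader_config)

# Hypothetical stand-ins for the real decoder model and data.
model = torch.nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(8, 512), torch.randint(0, 41, (8,)))
loader = torch.utils.data.DataLoader(dataset, batch_size=4)

# prepare() moves everything to the active device (TPU/GPU/CPU) and wraps the
# dataloader for multi-process sharding.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for features, labels in loader:
    with accelerator.accumulate(model):  # honors gradient_accumulation_steps
        loss = torch.nn.functional.cross_entropy(model(features), labels)
        accelerator.backward(loss)       # replaces loss.backward()
        optimizer.step()
        optimizer.zero_grad()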