tpu without bf16
@@ -11,6 +11,7 @@ import logging
 import sys
 import json
 import pickle
+from contextlib import nullcontext

 from dataset import BrainToTextDataset, train_test_split_indicies
 from data_augmentations import gauss_smooth
@@ -302,6 +303,12 @@ class BrainToTextDecoder_Trainer:
         if self.adv_enabled:
             self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")

+    def autocast_context(self):
+        """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
+        if self.device.type == 'xla':
+            return nullcontext()
+        return self.accelerator.autocast()
+
     def create_optimizer(self):
         '''
         Create the optimizer with special param groups
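For context: `nullcontext()` from the standard library is a no-op context manager, so on XLA devices the wrapped code simply runs in the model's default dtype instead of under Accelerate's autocast. A minimal standalone sketch of the same dispatch (the free function and its parameters are illustrative, not part of the commit):

    from contextlib import nullcontext
    import torch

    def autocast_context(device: torch.device, accelerator):
        # On XLA (TPU) devices, skip autocast and run in the default dtype.
        if device.type == 'xla':
            return nullcontext()  # no-op context manager
        # Elsewhere, defer to whatever mixed precision the Accelerator was built with.
        return accelerator.autocast()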
@@ -578,7 +585,7 @@ class BrainToTextDecoder_Trainer:
             day_indicies = batch['day_indicies']

             # Use Accelerator's autocast (mixed precision handled by Accelerator init)
-            with self.accelerator.autocast():
+            with self.autocast_context():

                 # Apply augmentations to the data
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
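For reference, `Accelerator.autocast()` applies whichever mixed precision was configured when the `Accelerator` was constructed, so the non-XLA training path is unchanged by this commit. A hedged sketch of that path (assuming the Hugging Face `accelerate` library; the model and `'bf16'` choice are stand-ins):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision='bf16')       # or 'fp16' / 'no'
    model = accelerator.prepare(torch.nn.Linear(512, 41))   # stand-in model

    features = torch.randn(8, 512, device=accelerator.device)
    with accelerator.autocast():  # matmuls run in bf16; a no-op when 'no'
        logits = model(features)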
@@ -792,7 +799,7 @@ class BrainToTextDecoder_Trainer:

         with torch.no_grad():

-            with self.accelerator.autocast():
+            with self.autocast_context():
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                 # Ensure proper dtype handling for TPU mixed precision
@@ -873,7 +880,7 @@ class BrainToTextDecoder_Trainer:
         self.model.eval()

         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 # Apply data transformations (no augmentation for inference)
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

@@ -899,7 +906,7 @@ class BrainToTextDecoder_Trainer:
         n_time_steps = batch['n_time_steps']

         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 # Apply data transformations (no augmentation for inference)
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

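The remaining hunks make validation and inference consistent with the training step: every former `self.accelerator.autocast()` site now routes through the helper, so a TPU run appears to execute in full precision unless bf16 is enabled by other means, matching the commit title. A quick hedged check of which branch a given host would take (assumes `torch_xla` is installed on TPU machines):

    import torch

    try:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()  # device type is 'xla' on a TPU host
    except ImportError:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 'xla' selects nullcontext(); anything else selects accelerator.autocast()
    print(device.type)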