tpu without bf16
@@ -11,6 +11,7 @@ import logging
 import sys
 import json
 import pickle
+from contextlib import nullcontext
 
 from dataset import BrainToTextDataset, train_test_split_indicies
 from data_augmentations import gauss_smooth
@@ -302,6 +303,12 @@ class BrainToTextDecoder_Trainer:
         if self.adv_enabled:
             self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
 
+    def autocast_context(self):
+        """Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
+        if self.device.type == 'xla':
+            return nullcontext()
+        return self.accelerator.autocast()
+
     def create_optimizer(self):
         '''
         Create the optimizer with special param groups
@@ -578,7 +585,7 @@ class BrainToTextDecoder_Trainer:
         day_indicies = batch['day_indicies']
 
         # Use Accelerator's autocast (mixed precision handled by Accelerator init)
-        with self.accelerator.autocast():
+        with self.autocast_context():
 
             # Apply augmentations to the data
             features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
@@ -792,7 +799,7 @@ class BrainToTextDecoder_Trainer:
 
         with torch.no_grad():
 
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
                 # Ensure proper dtype handling for TPU mixed precision
@@ -873,7 +880,7 @@ class BrainToTextDecoder_Trainer:
         self.model.eval()
 
         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 # Apply data transformations (no augmentation for inference)
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
@@ -899,7 +906,7 @@ class BrainToTextDecoder_Trainer:
         n_time_steps = batch['n_time_steps']
 
         with torch.no_grad():
-            with self.accelerator.autocast():
+            with self.autocast_context():
                 # Apply data transformations (no augmentation for inference)
                 features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
 
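
For reference, a minimal standalone sketch of the selection logic this commit introduces. The free-standing function and its `device`/`accelerator` arguments are stand-ins for the trainer's attributes, not the trainer's actual API; in the diff above the same logic lives in the `autocast_context` method.

    from contextlib import nullcontext

    import torch
    from accelerate import Accelerator


    def autocast_context(device: torch.device, accelerator: Accelerator):
        """Pick the autocast context for the forward pass.

        On XLA (TPU) devices, return a no-op context so the model runs in its
        native dtype and avoids autocast-induced dtype mismatches; elsewhere,
        defer to the mixed-precision policy configured on the Accelerator.
        """
        if device.type == "xla":
            # No-op context manager: nothing is autocast on TPU.
            return nullcontext()
        # CUDA/CPU: use whatever mixed precision the Accelerator was built with.
        return accelerator.autocast()


    # Call sites mirror the hunks above, e.g.:
    #     with self.autocast_context():
    #         features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

With the no-op context on XLA, the forward pass is no longer forced into bf16 by autocast (hence the commit title); precision on TPU is left to however the model and XLA runtime are otherwise configured.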