This commit is contained in:
Zchen
2025-10-14 23:54:53 +08:00
parent 4b6d680283
commit aef96f5646
2 changed files with 64 additions and 24 deletions

View File

@@ -589,9 +589,9 @@ class BrainToTextDecoder_Trainer:
# Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
# Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
@@ -621,7 +621,7 @@ class BrainToTextDecoder_Trainer:
noisy_loss = torch.mean(noisy_loss)
# Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device)
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.pow(2))
@@ -799,9 +799,9 @@ class BrainToTextDecoder_Trainer:
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
logits = self.model(features, day_indicies, None, False, 'inference')
@@ -878,9 +878,9 @@ class BrainToTextDecoder_Trainer:
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
@@ -907,9 +907,9 @@ class BrainToTextDecoder_Trainer:
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16':
# In mixed precision mode, ensure features match the expected precision
features = features.to(torch.float32)
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)