This commit is contained in:
Zchen
2025-10-15 00:44:08 +08:00
parent 5dcbf28c96
commit 11ee6ebc51

View File

@@ -59,10 +59,9 @@ class BrainToTextDecoder_Trainer:
)
# Initialize Accelerator for TPU/multi-device support
use_xla = bool(xm.get_xla_supported_devices())
mixed_precision_mode = 'no'
if not use_xla and args.get('use_amp', True):
mixed_precision_mode = 'bf16'
self.use_xla = bool(xm.get_xla_supported_devices())
self.amp_requested = args.get('use_amp', True)
mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
self.accelerator = Accelerator(
mixed_precision=mixed_precision_mode,
@@ -131,8 +130,8 @@ class BrainToTextDecoder_Trainer:
self.logger.info(f'Accelerator state: {self.accelerator.state}')
if self.accelerator.num_processes > 1:
self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True):
self.logger.info('AMP requested but disabled on XLA to avoid dtype mismatches; running in float32 on TPU.')
if self.use_xla and self.amp_requested:
self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
# Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
if self.args['seed'] != -1:
@@ -150,6 +149,12 @@ class BrainToTextDecoder_Trainer:
patch_stride = self.args['model']['patch_stride'],
)
if self.use_xla and self.amp_requested:
self.model = self.model.to(torch.bfloat16)
self.logger.info('Converted model parameters to bfloat16 for TPU training.')
self.model_dtype = next(self.model.parameters()).dtype
# Temporarily disable torch.compile for compatibility with new model architecture
# TODO: Re-enable torch.compile once model is stable
# self.logger.info("Using torch.compile")
@@ -306,6 +311,8 @@ class BrainToTextDecoder_Trainer:
self.val_loader,
)
self.model_dtype = next(self.model.parameters()).dtype
self.logger.info("Prepared model and dataloaders with Accelerator")
if self.adv_enabled:
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
@@ -547,6 +554,9 @@ class BrainToTextDecoder_Trainer:
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
)
if hasattr(self, 'model_dtype'):
features = features.to(self.model_dtype)
return features, n_time_steps
@@ -603,9 +613,8 @@ class BrainToTextDecoder_Trainer:
# Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
@@ -617,8 +626,9 @@ class BrainToTextDecoder_Trainer:
# Calculate CTC Loss
if use_full:
# Clean CTC loss
clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
clean_loss = self.ctc_loss(
torch.permute(clean_logits.log_softmax(2), [1, 0, 2]),
clean_log_probs,
labels,
adjusted_lens,
phone_seq_lens
@@ -626,8 +636,9 @@ class BrainToTextDecoder_Trainer:
clean_loss = torch.mean(clean_loss)
# Noisy-branch CTC loss (makes the noisy branch more recognizable, but through the GRL it becomes adversarial for the NoiseModel)
noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
noisy_loss = self.ctc_loss(
torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]),
noisy_log_probs,
labels,
adjusted_lens,
phone_seq_lens
@@ -637,16 +648,17 @@ class BrainToTextDecoder_Trainer:
# Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.pow(2))
noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
targets = labels,
input_lengths = adjusted_lens,
target_lengths = phone_seq_lens
)
log_probs=log_probs,
targets=labels,
input_lengths=adjusted_lens,
target_lengths=phone_seq_lens
)
loss = torch.mean(loss) # take mean loss over batches
# Use Accelerator's backward for distributed training
@@ -819,8 +831,9 @@ class BrainToTextDecoder_Trainer:
logits = self.model(features, day_indicies, None, False, 'inference')
val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
torch.permute(logits.log_softmax(2), [1, 0, 2]),
val_log_probs,
labels,
adjusted_lens,
phone_seq_lens,
@@ -872,11 +885,16 @@ class BrainToTextDecoder_Trainer:
metrics['trial_nums'].append(batch['trial_nums'].numpy())
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
avg_PER = total_edit_distance / total_seq_length
if isinstance(total_seq_length, torch.Tensor):
total_length_value = float(total_seq_length.item())
else:
total_length_value = float(total_seq_length)
avg_PER = total_edit_distance / max(total_length_value, 1e-6)
metrics['day_PERs'] = day_per
metrics['avg_PER'] = avg_PER.item()
metrics['avg_loss'] = np.mean(metrics['losses'])
metrics['avg_PER'] = avg_PER
metrics['avg_loss'] = float(np.mean(metrics['losses']))
return metrics
@@ -892,9 +910,8 @@ class BrainToTextDecoder_Trainer:
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
@@ -921,9 +938,8 @@ class BrainToTextDecoder_Trainer:
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)