diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index 07cc2a2..b365ce9 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -59,10 +59,9 @@ class BrainToTextDecoder_Trainer: ) # Initialize Accelerator for TPU/multi-device support - use_xla = bool(xm.get_xla_supported_devices()) - mixed_precision_mode = 'no' - if not use_xla and args.get('use_amp', True): - mixed_precision_mode = 'bf16' + self.use_xla = bool(xm.get_xla_supported_devices()) + self.amp_requested = args.get('use_amp', True) + mixed_precision_mode = 'bf16' if self.amp_requested else 'no' self.accelerator = Accelerator( mixed_precision=mixed_precision_mode, @@ -131,8 +130,8 @@ class BrainToTextDecoder_Trainer: self.logger.info(f'Accelerator state: {self.accelerator.state}') if self.accelerator.num_processes > 1: self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes') - if mixed_precision_mode == 'no' and use_xla and args.get('use_amp', True): - self.logger.info('AMP requested but disabled on XLA to avoid dtype mismatches; running in float32 on TPU.') + if self.use_xla and self.amp_requested: + self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.') # Set seed if provided (using Accelerator's set_seed for proper distributed seeding) if self.args['seed'] != -1: @@ -150,6 +149,12 @@ class BrainToTextDecoder_Trainer: patch_stride = self.args['model']['patch_stride'], ) + if self.use_xla and self.amp_requested: + self.model = self.model.to(torch.bfloat16) + self.logger.info('Converted model parameters to bfloat16 for TPU training.') + + self.model_dtype = next(self.model.parameters()).dtype + # Temporarily disable torch.compile for compatibility with new model architecture # TODO: Re-enable torch.compile once model is stable # self.logger.info("Using torch.compile") @@ -306,6 +311,8 @@ class BrainToTextDecoder_Trainer: self.val_loader, ) + self.model_dtype = next(self.model.parameters()).dtype + self.logger.info("Prepared model and dataloaders with Accelerator") if self.adv_enabled: self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}") @@ -541,11 +548,14 @@ class BrainToTextDecoder_Trainer: # This is done in both training and validation if self.transform_args['smooth_data']: features = gauss_smooth( - inputs = features, + inputs = features, device = self.device, smooth_kernel_std = self.transform_args['smooth_kernel_std'], smooth_kernel_size= self.transform_args['smooth_kernel_size'], ) + + if hasattr(self, 'model_dtype'): + features = features.to(self.model_dtype) return features, n_time_steps @@ -603,9 +613,8 @@ class BrainToTextDecoder_Trainer: # Get phoneme predictions using inference mode during training # (We use inference mode for simplicity - only clean logits are used for CTC loss) # Ensure features tensor matches model parameter dtype for TPU compatibility - model_param = next(self.model.parameters()) if self.model is not None else None - if model_param is not None and features.dtype != model_param.dtype: - features = features.to(model_param.dtype) + if features.dtype != self.model_dtype: + features = features.to(self.model_dtype) # Forward pass: enable full adversarial mode if configured and past warmup use_full = self.adv_enabled and (i >= self.adv_warmup_steps) @@ -617,8 +626,9 @@ class BrainToTextDecoder_Trainer: # Calculate CTC Loss if use_full: # Clean CTC loss + clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2) clean_loss = self.ctc_loss( - torch.permute(clean_logits.log_softmax(2), [1, 0, 2]), + clean_log_probs, labels, adjusted_lens, phone_seq_lens @@ -626,8 +636,9 @@ class BrainToTextDecoder_Trainer: clean_loss = torch.mean(clean_loss) # Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗) + noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2) noisy_loss = self.ctc_loss( - torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]), + noisy_log_probs, labels, adjusted_lens, phone_seq_lens @@ -637,16 +648,17 @@ class BrainToTextDecoder_Trainer: # Optional noise energy regularization noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype) if self.adv_noise_l2_weight > 0.0: - noise_l2 = torch.mean(noise_output.pow(2)) + noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype) loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 else: + log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2) loss = self.ctc_loss( - log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]), - targets = labels, - input_lengths = adjusted_lens, - target_lengths = phone_seq_lens - ) + log_probs=log_probs, + targets=labels, + input_lengths=adjusted_lens, + target_lengths=phone_seq_lens + ) loss = torch.mean(loss) # take mean loss over batches # Use Accelerator's backward for distributed training @@ -818,9 +830,10 @@ class BrainToTextDecoder_Trainer: features = features.to(model_param.dtype) logits = self.model(features, day_indicies, None, False, 'inference') - + + val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2) loss = self.ctc_loss( - torch.permute(logits.log_softmax(2), [1, 0, 2]), + val_log_probs, labels, adjusted_lens, phone_seq_lens, @@ -872,11 +885,16 @@ class BrainToTextDecoder_Trainer: metrics['trial_nums'].append(batch['trial_nums'].numpy()) metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy()) - avg_PER = total_edit_distance / total_seq_length + if isinstance(total_seq_length, torch.Tensor): + total_length_value = float(total_seq_length.item()) + else: + total_length_value = float(total_seq_length) + + avg_PER = total_edit_distance / max(total_length_value, 1e-6) metrics['day_PERs'] = day_per - metrics['avg_PER'] = avg_PER.item() - metrics['avg_loss'] = np.mean(metrics['losses']) + metrics['avg_PER'] = avg_PER + metrics['avg_loss'] = float(np.mean(metrics['losses'])) return metrics @@ -892,9 +910,8 @@ class BrainToTextDecoder_Trainer: features, n_time_steps = self.transform_data(features, n_time_steps, 'val') # Ensure features tensor matches model parameter dtype for TPU compatibility - model_param = next(self.model.parameters()) if self.model is not None else None - if model_param is not None and features.dtype != model_param.dtype: - features = features.to(model_param.dtype) + if features.dtype != self.model_dtype: + features = features.to(self.model_dtype) # Get phoneme predictions logits = self.model(features, day_indicies, None, False, mode) @@ -921,9 +938,8 @@ class BrainToTextDecoder_Trainer: adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) # Ensure features tensor matches model parameter dtype for TPU compatibility - model_param = next(self.model.parameters()) if self.model is not None else None - if model_param is not None and features.dtype != model_param.dtype: - features = features.to(model_param.dtype) + if features.dtype != self.model_dtype: + features = features.to(self.model_dtype) # Get phoneme predictions logits = self.model(features, day_indicies, None, False, mode)