diff --git a/CLAUDE.md b/CLAUDE.md index 1995ecd..98728aa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -162,7 +162,17 @@ day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases # After (XLA-optimized): -x = torch.bmm(x, day_weights) + day_biases # bmm is highly optimized in XLA +x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtype consistency +``` + +#### 5. Mixed Precision Dtype Consistency +**Problem**: Mixed precision training causes dtype mismatches in bmm operations +**Solution**: Ensure all operands match input tensor dtype + +```python +# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training +# Fix: Add dtype conversions for all bmm operands +x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) ``` #### 3. Hidden State Initialization @@ -199,11 +209,11 @@ return clean_logits, noisy_logits, noise_output # Simple tuple return ### Files Modified for XLA Optimization - **`model_training_nnn/rnn_model.py`**: All three models optimized - - `NoiseModel.forward()`: Dynamic indexing → static gather operations - - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + - `NoiseModel.forward()`: Dynamic indexing → static gather operations + dtype consistency + - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops + dtype consistency - `NoisySpeechModel.forward()`: Hidden state optimization - `TripleGRUDecoder.forward()`: Complex return values → tuple returns - - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations + dtype consistency ### Benefits of XLA Optimizations diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index 3ae19fb..625a22d 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -1,6 +1,24 @@ import torch from torch import nn +class GradientReversalFn(torch.autograd.Function): + """ + Gradient Reversal Layer (GRL) + Forward: identity + Backward: multiply incoming gradient by -lambda + """ + @staticmethod + def forward(ctx, x, lambd: float): + ctx.lambd = lambd + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output): + return -ctx.lambd * grad_output, None + +def gradient_reverse(x, lambd: float = 1.0): + return GradientReversalFn.apply(x, lambd) + class NoiseModel(nn.Module): ''' Noise Model: 2-layer GRU that learns to estimate noise in the neural data @@ -361,7 +379,8 @@ class TripleGRUDecoder(nn.Module): day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1) # Use bmm (batch matrix multiply) which is highly optimized in XLA - x_processed = torch.bmm(x, day_weights) + day_biases + # Ensure dtype consistency for mixed precision training + x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) x_processed = self.clean_speech_model.day_layer_activation(x_processed) # Apply patch processing if enabled @@ -405,7 +424,7 @@ class TripleGRUDecoder(nn.Module): logits = self.noisy_speech_model.out(output) return logits - def forward(self, x, day_idx, states=None, return_state=False, mode='inference'): + def forward(self, x, day_idx, states=None, return_state=False, mode='inference', grl_lambda: float = 0.0): ''' Three-model adversarial forward pass @@ -413,6 +432,7 @@ class TripleGRUDecoder(nn.Module): day_idx (tensor) - tensor of day indices for each example in the batch states (dict) - dictionary with 'noise', 'clean', 'noisy' states or None mode (str) - 'full' for training (all three models), 'inference' for inference (noise + clean only) + grl_lambda (float) - when > 0 and mode='full', applies Gradient Reversal to the noise branch input ''' if mode == 'full': @@ -435,7 +455,9 @@ class TripleGRUDecoder(nn.Module): states['clean'] if states else None) # 4. Noisy speech model processes noise signal directly (no day layers needed) - noisy_logits = self._noisy_forward_with_processed_input(noise_output, + # Optionally apply Gradient Reversal to enforce adversarial training on noise output + noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output + noisy_logits = self._noisy_forward_with_processed_input(noisy_input, states['noisy'] if states else None) # XLA-friendly return - use tuple instead of dict for better compilation diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index c41cb89..a5217fd 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -86,6 +86,14 @@ class BrainToTextDecoder_Trainer: self.transform_args = self.args['dataset']['data_transforms'] + # Adversarial training config (safe defaults if not provided) + adv_cfg = self.args.get('adversarial', {}) + self.adv_enabled = adv_cfg.get('enabled', False) + self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength + self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC + self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output + self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps + # Create output directory if args['mode'] == 'train': os.makedirs(self.args['output_dir'], exist_ok=True) @@ -291,6 +299,8 @@ class BrainToTextDecoder_Trainer: ) self.logger.info("Prepared model and dataloaders with Accelerator") + if self.adv_enabled: + self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}") def create_optimizer(self): ''' @@ -583,17 +593,47 @@ class BrainToTextDecoder_Trainer: # In mixed precision mode, ensure features match the expected precision features = features.to(torch.float32) - logits = self.model(features, day_indicies, None, False, 'inference') + # Forward pass: enable full adversarial mode if configured and past warmup + use_full = self.adv_enabled and (i >= self.adv_warmup_steps) + if use_full: + clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda) + else: + logits = self.model(features, day_indicies, None, False, 'inference') # Calculate CTC Loss - loss = self.ctc_loss( - log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]), - targets = labels, - input_lengths = adjusted_lens, - target_lengths = phone_seq_lens + if use_full: + # Clean CTC loss + clean_loss = self.ctc_loss( + torch.permute(clean_logits.log_softmax(2), [1, 0, 2]), + labels, + adjusted_lens, + phone_seq_lens ) - - loss = torch.mean(loss) # take mean loss over batches + clean_loss = torch.mean(clean_loss) + + # Noisy branch CTC loss(让 Noisy 更可识别,但经 GRL 对 NoiseModel 变成对抗) + noisy_loss = self.ctc_loss( + torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]), + labels, + adjusted_lens, + phone_seq_lens + ) + noisy_loss = torch.mean(noisy_loss) + + # Optional noise energy regularization + noise_l2 = torch.tensor(0.0, device=self.device) + if self.adv_noise_l2_weight > 0.0: + noise_l2 = torch.mean(noise_output.pow(2)) + + loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 + else: + loss = self.ctc_loss( + log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]), + targets = labels, + input_lengths = adjusted_lens, + target_lengths = phone_seq_lens + ) + loss = torch.mean(loss) # take mean loss over batches # Use Accelerator's backward for distributed training self.accelerator.backward(loss) @@ -673,7 +713,7 @@ class BrainToTextDecoder_Trainer: # Optionally save this validation checkpoint, regardless of performance if self.args['save_all_val_steps']: - self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER']) + self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss']) # Early stopping if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps): @@ -689,7 +729,8 @@ class BrainToTextDecoder_Trainer: # Save final model if self.args['save_final_model']: - self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1]) + last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf') + self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss) train_stats = {} train_stats['train_losses'] = train_losses