This commit is contained in:
Zchen
2025-10-14 23:54:53 +08:00
parent 4b6d680283
commit aef96f5646
2 changed files with 64 additions and 24 deletions

View File

@@ -1,5 +1,6 @@
import torch import torch
from torch import nn from torch import nn
from typing import cast
class GradientReversalFn(torch.autograd.Function): class GradientReversalFn(torch.autograd.Function):
""" """
@@ -106,9 +107,15 @@ class NoiseModel(nn.Module):
# Ensure dtype consistency after patch processing operations # Ensure dtype consistency after patch processing operations
x = x.to(original_dtype) x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization - avoid dynamic allocation # XLA-friendly hidden state initialization - avoid dynamic allocation
if states is None: if states is None:
states = self.h0.expand(2, batch_size, self.input_size).contiguous() states = self.h0.expand(2, batch_size, self.input_size).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass # GRU forward pass
output, hidden_states = self.gru(x, states) output, hidden_states = self.gru(x, states)
@@ -208,9 +215,15 @@ class CleanSpeechModel(nn.Module):
# Ensure dtype consistency after patch processing operations # Ensure dtype consistency after patch processing operations
x = x.to(original_dtype) x = x.to(original_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization # XLA-friendly hidden state initialization
if states is None: if states is None:
states = self.h0.expand(3, batch_size, self.n_units).contiguous() states = self.h0.expand(3, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass # GRU forward pass
output, hidden_states = self.gru(x, states) output, hidden_states = self.gru(x, states)
@@ -280,9 +293,21 @@ class NoisySpeechModel(nn.Module):
# Note: NoisySpeechModel doesn't need day-specific layers as it processes noise # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
batch_size = x.size(0) batch_size = x.size(0)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
gru_dtype = next(self.gru.parameters()).dtype
if x.dtype != gru_dtype:
x = x.to(gru_dtype)
# XLA-friendly hidden state initialization # XLA-friendly hidden state initialization
if states is None: if states is None:
states = self.h0.expand(2, batch_size, self.n_units).contiguous() states = self.h0.expand(2, batch_size, self.n_units).contiguous()
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
if states.dtype != gru_dtype:
states = states.to(gru_dtype)
# GRU forward pass # GRU forward pass
output, hidden_states = self.gru(x, states) output, hidden_states = self.gru(x, states)
@@ -407,11 +432,16 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)''' '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
batch_size = x_processed.size(0) batch_size = x_processed.size(0)
clean_gru_dtype = next(self.clean_speech_model.gru.parameters()).dtype
if x_processed.dtype != clean_gru_dtype:
x_processed = x_processed.to(clean_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency # XLA-friendly hidden state initialization with dtype consistency
if states is None: if states is None:
states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous() states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training # Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype) if states.dtype != clean_gru_dtype:
states = states.to(clean_gru_dtype)
# GRU forward pass (skip preprocessing since input is already processed) # GRU forward pass (skip preprocessing since input is already processed)
output, hidden_states = self.clean_speech_model.gru(x_processed, states) output, hidden_states = self.clean_speech_model.gru(x_processed, states)
@@ -424,11 +454,16 @@ class TripleGRUDecoder(nn.Module):
'''Forward pass for NoisySpeechModel with already processed input''' '''Forward pass for NoisySpeechModel with already processed input'''
batch_size = x_processed.size(0) batch_size = x_processed.size(0)
noisy_gru_dtype = next(self.noisy_speech_model.gru.parameters()).dtype
if x_processed.dtype != noisy_gru_dtype:
x_processed = x_processed.to(noisy_gru_dtype)
# XLA-friendly hidden state initialization with dtype consistency # XLA-friendly hidden state initialization with dtype consistency
if states is None: if states is None:
states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous() states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match input dtype for mixed precision training # Ensure hidden states match input dtype for mixed precision training
states = states.to(x_processed.dtype) if states.dtype != noisy_gru_dtype:
states = states.to(noisy_gru_dtype)
# GRU forward pass (NoisySpeechModel doesn't have day layers anyway) # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
output, hidden_states = self.noisy_speech_model.gru(x_processed, states) output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -458,9 +493,13 @@ class TripleGRUDecoder(nn.Module):
# 2. For residual connection, we need x in the same space as noise_output # 2. For residual connection, we need x in the same space as noise_output
# Apply the same preprocessing that the models use internally # Apply the same preprocessing that the models use internally
x_processed = self._apply_preprocessing(x, day_idx) x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency between processed input and noise output # Ensure dtype consistency between processed input and noise output
noise_output = noise_output.to(x_processed.dtype) if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
# 3. Clean speech model processes denoised signal # 3. Clean speech model processes denoised signal
denoised_input = x_processed - noise_output # Residual connection in processed space denoised_input = x_processed - noise_output # Residual connection in processed space
@@ -473,9 +512,10 @@ class TripleGRUDecoder(nn.Module):
# 4. Noisy speech model processes noise signal directly (no day layers needed) # 4. Noisy speech model processes noise signal directly (no day layers needed)
# Optionally apply Gradient Reversal to enforce adversarial training on noise output # Optionally apply Gradient Reversal to enforce adversarial training on noise output
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda and grl_lambda != 0.0 else noise_output
# Ensure dtype consistency - GradientReversalFn should preserve dtype, but ensure compatibility noisy_input = cast(torch.Tensor, noisy_input)
# Use x_processed.dtype as reference since it's the main data flow dtype noisy_dtype = next(self.noisy_speech_model.parameters()).dtype
noisy_input = noisy_input.to(x_processed.dtype) if noisy_input.dtype != noisy_dtype:
noisy_input = noisy_input.to(noisy_dtype)
noisy_logits = self._noisy_forward_with_processed_input(noisy_input, noisy_logits = self._noisy_forward_with_processed_input(noisy_input,
states['noisy'] if states else None) states['noisy'] if states else None)
@@ -493,9 +533,13 @@ class TripleGRUDecoder(nn.Module):
# 2. For residual connection, we need x in the same space as noise_output # 2. For residual connection, we need x in the same space as noise_output
x_processed = self._apply_preprocessing(x, day_idx) x_processed = self._apply_preprocessing(x, day_idx)
clean_dtype = next(self.clean_speech_model.parameters()).dtype
if x_processed.dtype != clean_dtype:
x_processed = x_processed.to(clean_dtype)
# Ensure dtype consistency for mixed precision residual connection # Ensure dtype consistency for mixed precision residual connection
noise_output = noise_output.to(x_processed.dtype) if noise_output.dtype != clean_dtype:
noise_output = noise_output.to(clean_dtype)
denoised_input = x_processed - noise_output denoised_input = x_processed - noise_output
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx, clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
states['clean'] if states else None) states['clean'] if states else None)
@@ -514,10 +558,6 @@ class TripleGRUDecoder(nn.Module):
clean_grad (tensor) - gradients from clean speech model output layer clean_grad (tensor) - gradients from clean speech model output layer
noisy_grad (tensor) - gradients from noisy speech model output layer noisy_grad (tensor) - gradients from noisy speech model output layer
if grl_lambda and grl_lambda != 0.0:
noisy_input = gradient_reverse(noise_output, grl_lambda)
else:
noisy_input = noise_output
''' '''
# Combine gradients: negative from clean model, positive from noisy model # Combine gradients: negative from clean model, positive from noisy model
combined_grad = -clean_grad + noisy_grad combined_grad = -clean_grad + noisy_grad

View File

@@ -589,9 +589,9 @@ class BrainToTextDecoder_Trainer:
# Get phoneme predictions using inference mode during training # Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss) # (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility # Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16': model_param = next(self.model.parameters()) if self.model is not None else None
# In mixed precision mode, ensure features match the expected precision if model_param is not None and features.dtype != model_param.dtype:
features = features.to(torch.float32) features = features.to(model_param.dtype)
# Forward pass: enable full adversarial mode if configured and past warmup # Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps) use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
@@ -621,7 +621,7 @@ class BrainToTextDecoder_Trainer:
noisy_loss = torch.mean(noisy_loss) noisy_loss = torch.mean(noisy_loss)
# Optional noise energy regularization # Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device) noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0: if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.pow(2)) noise_l2 = torch.mean(noise_output.pow(2))
@@ -799,9 +799,9 @@ class BrainToTextDecoder_Trainer:
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility # Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16': model_param = next(self.model.parameters()) if self.model is not None else None
# In mixed precision mode, ensure features match the expected precision if model_param is not None and features.dtype != model_param.dtype:
features = features.to(torch.float32) features = features.to(model_param.dtype)
logits = self.model(features, day_indicies, None, False, 'inference') logits = self.model(features, day_indicies, None, False, 'inference')
@@ -878,9 +878,9 @@ class BrainToTextDecoder_Trainer:
features, n_time_steps = self.transform_data(features, n_time_steps, 'val') features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility # Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16': model_param = next(self.model.parameters()) if self.model is not None else None
# In mixed precision mode, ensure features match the expected precision if model_param is not None and features.dtype != model_param.dtype:
features = features.to(torch.float32) features = features.to(model_param.dtype)
# Get phoneme predictions # Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode) logits = self.model(features, day_indicies, None, False, mode)
@@ -907,9 +907,9 @@ class BrainToTextDecoder_Trainer:
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility # Ensure features tensor matches model parameter dtype for TPU compatibility
if self.accelerator.mixed_precision == 'bf16': model_param = next(self.model.parameters()) if self.model is not None else None
# In mixed precision mode, ensure features match the expected precision if model_param is not None and features.dtype != model_param.dtype:
features = features.to(torch.float32) features = features.to(model_param.dtype)
# Get phoneme predictions # Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode) logits = self.model(features, day_indicies, None, False, mode)