From 603bb122208658cec5d2b4905e12f4a15ce4cb75 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Wed, 15 Oct 2025 00:18:05 +0800 Subject: [PATCH] Disable autocast around GRU calls to avoid dtype mismatches on TPU/XLA --- model_training_nnn/rnn_model.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/model_training_nnn/rnn_model.py b/model_training_nnn/rnn_model.py index 8cc0219..5aa91d5 100644 --- a/model_training_nnn/rnn_model.py +++ b/model_training_nnn/rnn_model.py @@ -117,8 +117,10 @@ class NoiseModel(nn.Module): if states.dtype != gru_dtype: states = states.to(gru_dtype) - # GRU forward pass - output, hidden_states = self.gru(x, states) + # Disable autocast for GRU to avoid dtype mismatches on XLA + device_type = x.device.type + with torch.autocast(device_type=device_type, enabled=False): + output, hidden_states = self.gru(x, states) return output, hidden_states @@ -225,8 +227,9 @@ class CleanSpeechModel(nn.Module): if states.dtype != gru_dtype: states = states.to(gru_dtype) - # GRU forward pass - output, hidden_states = self.gru(x, states) + device_type = x.device.type + with torch.autocast(device_type=device_type, enabled=False): + output, hidden_states = self.gru(x, states) # Classification logits = self.out(output) @@ -309,8 +312,9 @@ class NoisySpeechModel(nn.Module): if states.dtype != gru_dtype: states = states.to(gru_dtype) - # GRU forward pass - output, hidden_states = self.gru(x, states) + device_type = x.device.type + with torch.autocast(device_type=device_type, enabled=False): + output, hidden_states = self.gru(x, states) # Classification logits = self.out(output) @@ -444,7 +448,9 @@ class TripleGRUDecoder(nn.Module): states = states.to(clean_gru_dtype) # GRU forward pass (skip preprocessing since input is already processed) - output, hidden_states = self.clean_speech_model.gru(x_processed, states) + device_type = x_processed.device.type + with torch.autocast(device_type=device_type, enabled=False): + output, hidden_states = 
self.clean_speech_model.gru(x_processed, states) # Classification logits = self.clean_speech_model.out(output) @@ -466,7 +472,9 @@ class TripleGRUDecoder(nn.Module): states = states.to(noisy_gru_dtype) # GRU forward pass (NoisySpeechModel doesn't have day layers anyway) - output, hidden_states = self.noisy_speech_model.gru(x_processed, states) + device_type = x_processed.device.type + with torch.autocast(device_type=device_type, enabled=False): + output, hidden_states = self.noisy_speech_model.gru(x_processed, states) # Classification logits = self.noisy_speech_model.out(output)