tpu
This commit is contained in:
@@ -447,7 +447,8 @@ class TripleGRUDecoder(nn.Module):
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
|
||||
# 3. Clean speech model processes denoised signal
|
||||
-denoised_input = x_processed - noise_output # Residual connection in processed space
|
||||
+# Ensure dtype consistency for mixed precision training in residual connection
|
||||
+denoised_input = x_processed - noise_output.to(x_processed.dtype) # Residual connection in processed space
|
||||
# Clean speech model will apply its own preprocessing, so we pass the denoised processed data
|
||||
# But we need to reverse the preprocessing first, then let clean model do its own
|
||||
# Actually, it's simpler to pass the residual directly to clean model after bypassing its preprocessing
|
||||
@@ -476,7 +477,8 @@ class TripleGRUDecoder(nn.Module):
|
||||
x_processed = self._apply_preprocessing(x, day_idx)
|
||||
|
||||
# 3. Process denoised signal
|
||||
-denoised_input = x_processed - noise_output
|
||||
+# Ensure dtype consistency for mixed precision training in residual connection
+denoised_input = x_processed - noise_output.to(x_processed.dtype)
|
||||
clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
|
||||
states['clean'] if states else None)
|
||||
|
||||
|
Reference in New Issue
Block a user