commit 0cbb83e052 (parent 4dad570eea)
Author: Zchen
Date: 2025-10-12 21:56:34 +08:00

2 changed files with 67 additions and 24 deletions


@@ -25,8 +25,8 @@ class NoiseModel(nn.Module):

         # Day-specific input layers
         self.day_layer_activation = nn.Softsign()
-        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
-        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
+        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
+        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
         self.day_layer_dropout = nn.Dropout(input_dropout)

         # Calculate input size after patching
@@ -52,7 +52,7 @@ class NoiseModel(nn.Module):
                 nn.init.xavier_uniform_(param)

         # Learnable initial hidden state
-        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
+        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size, dtype=torch.bfloat16)))

     def forward(self, x, day_idx, states=None):
         # Apply day-specific transformation
@@ -110,8 +110,8 @@ class CleanSpeechModel(nn.Module):

         # Day-specific input layers
         self.day_layer_activation = nn.Softsign()
-        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)])
-        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)])
+        self.day_weights = nn.ParameterList([nn.Parameter(torch.eye(self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
+        self.day_biases = nn.ParameterList([nn.Parameter(torch.zeros(1, self.neural_dim, dtype=torch.bfloat16)) for _ in range(self.n_days)])
         self.day_layer_dropout = nn.Dropout(input_dropout)

         # Calculate input size after patching
@@ -141,7 +141,7 @@ class CleanSpeechModel(nn.Module):
         nn.init.xavier_uniform_(self.out.weight)

         # Learnable initial hidden state
-        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
+        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))

     def forward(self, x, day_idx, states=None, return_state=False):
         # Apply day-specific transformation
@@ -229,7 +229,7 @@ class NoisySpeechModel(nn.Module):
         nn.init.xavier_uniform_(self.out.weight)

         # Learnable initial hidden state
-        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
+        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units, dtype=torch.bfloat16)))

     def forward(self, x, states=None, return_state=False):
         # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
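
The change is uniform across all three sub-models: every explicitly constructed parameter (the day-specific weights and biases, and the learnable initial hidden state h0) now passes dtype=torch.bfloat16 at creation time. Below is a minimal, standalone sketch of that pattern, not the repository's module; neural_dim, n_days, and input_size are illustrative placeholders. Note that nn.init.xavier_uniform_ fills the tensor in place and preserves its dtype, so h0 stays in bfloat16 after initialization.

import torch
import torch.nn as nn

# Standalone sketch of the initialization pattern used in this commit.
# Dimensions are illustrative placeholders, not values from the repo.
neural_dim = 512   # hypothetical per-timestep feature dimension
n_days = 4         # hypothetical number of recording days
input_size = 512   # hypothetical GRU input size

# Day-specific transforms created directly in bfloat16 (identity weights, zero biases).
day_weights = nn.ParameterList(
    [nn.Parameter(torch.eye(neural_dim, dtype=torch.bfloat16)) for _ in range(n_days)]
)
day_biases = nn.ParameterList(
    [nn.Parameter(torch.zeros(1, neural_dim, dtype=torch.bfloat16)) for _ in range(n_days)]
)

# xavier_uniform_ operates in place and keeps the tensor's dtype,
# so the learnable initial hidden state remains bfloat16.
h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, input_size, dtype=torch.bfloat16)))

print(day_weights[0].dtype, day_biases[0].dtype, h0.dtype)
# torch.bfloat16 torch.bfloat16 torch.bfloat16

An alternative design would keep these parameters in float32 and rely on torch.autocast to run the forward pass in bfloat16; this commit instead stores the day-specific parameters and hidden-state initializers natively in bfloat16.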