tpu支持

This commit is contained in:
Zchen
2025-10-12 18:41:26 +08:00
parent 1a906d3248
commit 40e4d00576
5 changed files with 456 additions and 226 deletions

View File

@@ -90,12 +90,12 @@ class BrainToTextDataset(Dataset):
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self): def __len__(self):
'''
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
''' '''
return self.n_batches How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches if self.n_batches is not None else 0
def __getitem__(self, idx): def __getitem__(self, idx):
''' '''
@@ -269,7 +269,9 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
# Get trials in each day # Get trials in each day
trials_per_day = {} trials_per_day = {}
for i, path in enumerate(file_paths): for i, path in enumerate(file_paths):
session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0] # Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = [] good_trial_indices = []

View File

@@ -72,11 +72,11 @@ class BrainToTextDecoder_Trainer:
# Create output directory # Create output directory
if args['mode'] == 'train': if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=False) os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory # Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']: if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=False) os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging # Set up logging
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
@@ -188,12 +188,16 @@ class BrainToTextDecoder_Trainer:
# Use TPU-optimized dataloader settings if TPU is enabled # Use TPU-optimized dataloader settings if TPU is enabled
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers'] num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
# For TPU environments or when batch_size=None causes issues, use batch_size=1
# since our dataset already returns complete batches
batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
self.train_loader = DataLoader( self.train_loader = DataLoader(
self.train_dataset, self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = self.args['dataset']['loader_shuffle'], shuffle = self.args['dataset']['loader_shuffle'],
num_workers = num_workers, num_workers = num_workers,
pin_memory = True pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
) )
# val dataset and dataloader # val dataset and dataloader
@@ -209,10 +213,10 @@ class BrainToTextDecoder_Trainer:
) )
self.val_loader = DataLoader( self.val_loader = DataLoader(
self.val_dataset, self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = False, shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
) )
self.logger.info("Successfully initialized datasets") self.logger.info("Successfully initialized datasets")

View File

@@ -1,213 +0,0 @@
2025-10-12 09:32:47,330: Using device: cpu
2025-10-12 09:33:20,119: torch.compile disabled for new TripleGRUDecoder compatibility
2025-10-12 09:33:20,120: Initialized RNN decoding model
2025-10-12 09:33:20,120: Model type: <class 'rnn_model.TripleGRUDecoder'>
2025-10-12 09:33:20,120: Model is callable: True
2025-10-12 09:33:20,120: Model has forward method: True
2025-10-12 09:33:20,120: TripleGRUDecoder(
(noise_model): NoiseModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
)
(clean_speech_model): CleanSpeechModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
(noisy_speech_model): NoisySpeechModel(
(gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
)
2025-10-12 09:33:20,124: Model has 687,568,466 parameters
2025-10-12 09:33:20,124: Model has 23,639,040 day-specific parameters | 3.44% of total parameters

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,436 @@
2025-10-12 18:35:01,395: Using device: cpu
2025-10-12 18:35:01,395: Accelerator state: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cpu
Mixed precision type: bf16
2025-10-12 18:35:38,176: torch.compile disabled for new TripleGRUDecoder compatibility
2025-10-12 18:35:38,176: Initialized RNN decoding model
2025-10-12 18:35:38,176: TripleGRUDecoder(
(noise_model): NoiseModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
)
(clean_speech_model): CleanSpeechModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
(noisy_speech_model): NoisySpeechModel(
(gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
)
2025-10-12 18:35:38,190: Model has 687,568,466 parameters
2025-10-12 18:35:38,191: Model has 23,639,040 day-specific parameters | 3.44% of total parameters
2025-10-12 18:38:50,217: Using device: cpu
2025-10-12 18:38:50,217: Accelerator state: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cpu
Mixed precision type: bf16
2025-10-12 18:39:25,946: torch.compile disabled for new TripleGRUDecoder compatibility
2025-10-12 18:39:25,946: Initialized RNN decoding model
2025-10-12 18:39:25,946: TripleGRUDecoder(
(noise_model): NoiseModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
)
(clean_speech_model): CleanSpeechModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
(noisy_speech_model): NoisySpeechModel(
(gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
)
2025-10-12 18:39:25,958: Model has 687,568,466 parameters
2025-10-12 18:39:25,958: Model has 23,639,040 day-specific parameters | 3.44% of total parameters