diff --git a/model_training_nnn/dataset.py b/model_training_nnn/dataset.py index 59656a4..7b5e81d 100644 --- a/model_training_nnn/dataset.py +++ b/model_training_nnn/dataset.py @@ -90,12 +90,12 @@ class BrainToTextDataset(Dataset): self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data def __len__(self): - ''' - How many batches are in this dataset. - Because training data is sampled randomly, there is no fixed dataset length, - however this method is required for DataLoader to work ''' - return self.n_batches + How many batches are in this dataset. + Because training data is sampled randomly, there is no fixed dataset length, + however this method is required for DataLoader to work + ''' + return self.n_batches if self.n_batches is not None else 0 def __getitem__(self, idx): ''' @@ -269,7 +269,9 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_ # Get trials in each day trials_per_day = {} for i, path in enumerate(file_paths): - session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0] + # Handle both Windows and Unix path separators + path_parts = path.replace('\\', '/').split('/') + session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0] good_trial_indices = [] diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index bb9f93e..822c775 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -72,11 +72,11 @@ class BrainToTextDecoder_Trainer: # Create output directory if args['mode'] == 'train': - os.makedirs(self.args['output_dir'], exist_ok=False) + os.makedirs(self.args['output_dir'], exist_ok=True) # Create checkpoint directory - if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']: - os.makedirs(self.args['checkpoint_dir'], exist_ok=False) + if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']: + os.makedirs(self.args['checkpoint_dir'], exist_ok=True) # Set up logging self.logger = logging.getLogger(__name__) @@ -188,12 +188,16 @@ class BrainToTextDecoder_Trainer: # Use TPU-optimized dataloader settings if TPU is enabled num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers'] + # For TPU environments or when batch_size=None causes issues, use batch_size=1 + # since our dataset already returns complete batches + batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None + self.train_loader = DataLoader( self.train_dataset, - batch_size = None, # Dataset.__getitem__() already returns batches + batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches shuffle = self.args['dataset']['loader_shuffle'], num_workers = num_workers, - pin_memory = True + pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory ) # val dataset and dataloader @@ -209,10 +213,10 @@ class BrainToTextDecoder_Trainer: ) self.val_loader = DataLoader( self.val_dataset, - batch_size = None, # Dataset.__getitem__() already returns batches + batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches shuffle = False, num_workers = 0, # Keep validation dataloader single-threaded for consistency - pin_memory = True + pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory ) self.logger.info("Successfully initialized datasets") diff --git a/model_training_nnn/trained_models/baseline_rnn/training_log b/model_training_nnn/trained_models/baseline_rnn/training_log deleted file mode 100644 index c835d6e..0000000 --- a/model_training_nnn/trained_models/baseline_rnn/training_log +++ /dev/null @@ -1,213 +0,0 @@ -2025-10-12 09:32:47,330: Using device: cpu -2025-10-12 09:33:20,119: torch.compile disabled for new TripleGRUDecoder compatibility -2025-10-12 09:33:20,120: Initialized RNN decoding model -2025-10-12 09:33:20,120: Model type: -2025-10-12 09:33:20,120: Model is callable: True -2025-10-12 09:33:20,120: Model has forward method: True -2025-10-12 09:33:20,120: TripleGRUDecoder( - (noise_model): NoiseModel( - (day_layer_activation): Softsign() - (day_weights): ParameterList( - (0): Parameter containing: [torch.float32 of size 512x512] - (1): Parameter containing: [torch.float32 of size 512x512] - (2): Parameter containing: [torch.float32 of size 512x512] - (3): Parameter containing: [torch.float32 of size 512x512] - (4): Parameter containing: [torch.float32 of size 512x512] - (5): Parameter containing: [torch.float32 of size 512x512] - (6): Parameter containing: [torch.float32 of size 512x512] - (7): Parameter containing: [torch.float32 of size 512x512] - (8): Parameter containing: [torch.float32 of size 512x512] - (9): Parameter containing: [torch.float32 of size 512x512] - (10): Parameter containing: [torch.float32 of size 512x512] - (11): Parameter containing: [torch.float32 of size 512x512] - (12): Parameter containing: [torch.float32 of size 512x512] - (13): Parameter containing: [torch.float32 of size 512x512] - (14): Parameter containing: [torch.float32 of size 512x512] - (15): Parameter containing: [torch.float32 of size 512x512] - (16): Parameter containing: [torch.float32 of size 512x512] - (17): Parameter containing: [torch.float32 of size 512x512] - (18): Parameter containing: [torch.float32 of size 512x512] - (19): Parameter containing: [torch.float32 of size 512x512] - (20): Parameter containing: [torch.float32 of size 512x512] - (21): Parameter containing: [torch.float32 of size 512x512] - (22): Parameter containing: [torch.float32 of size 512x512] - (23): Parameter containing: [torch.float32 of size 512x512] - (24): Parameter containing: [torch.float32 of size 512x512] - (25): Parameter containing: [torch.float32 of size 512x512] - (26): Parameter containing: [torch.float32 of size 512x512] - (27): Parameter containing: [torch.float32 of size 512x512] - (28): Parameter containing: [torch.float32 of size 512x512] - (29): Parameter containing: [torch.float32 of size 512x512] - (30): Parameter containing: [torch.float32 of size 512x512] - (31): Parameter containing: [torch.float32 of size 512x512] - (32): Parameter containing: [torch.float32 of size 512x512] - (33): Parameter containing: [torch.float32 of size 512x512] - (34): Parameter containing: [torch.float32 of size 512x512] - (35): Parameter containing: [torch.float32 of size 512x512] - (36): Parameter containing: [torch.float32 of size 512x512] - (37): Parameter containing: [torch.float32 of size 512x512] - (38): Parameter containing: [torch.float32 of size 512x512] - (39): Parameter containing: [torch.float32 of size 512x512] - (40): Parameter containing: [torch.float32 of size 512x512] - (41): Parameter containing: [torch.float32 of size 512x512] - (42): Parameter containing: [torch.float32 of size 512x512] - (43): Parameter containing: [torch.float32 of size 512x512] - (44): Parameter containing: [torch.float32 of size 512x512] - ) - (day_biases): ParameterList( - (0): Parameter containing: [torch.float32 of size 1x512] - (1): Parameter containing: [torch.float32 of size 1x512] - (2): Parameter containing: [torch.float32 of size 1x512] - (3): Parameter containing: [torch.float32 of size 1x512] - (4): Parameter containing: [torch.float32 of size 1x512] - (5): Parameter containing: [torch.float32 of size 1x512] - (6): Parameter containing: [torch.float32 of size 1x512] - (7): Parameter containing: [torch.float32 of size 1x512] - (8): Parameter containing: [torch.float32 of size 1x512] - (9): Parameter containing: [torch.float32 of size 1x512] - (10): Parameter containing: [torch.float32 of size 1x512] - (11): Parameter containing: [torch.float32 of size 1x512] - (12): Parameter containing: [torch.float32 of size 1x512] - (13): Parameter containing: [torch.float32 of size 1x512] - (14): Parameter containing: [torch.float32 of size 1x512] - (15): Parameter containing: [torch.float32 of size 1x512] - (16): Parameter containing: [torch.float32 of size 1x512] - (17): Parameter containing: [torch.float32 of size 1x512] - (18): Parameter containing: [torch.float32 of size 1x512] - (19): Parameter containing: [torch.float32 of size 1x512] - (20): Parameter containing: [torch.float32 of size 1x512] - (21): Parameter containing: [torch.float32 of size 1x512] - (22): Parameter containing: [torch.float32 of size 1x512] - (23): Parameter containing: [torch.float32 of size 1x512] - (24): Parameter containing: [torch.float32 of size 1x512] - (25): Parameter containing: [torch.float32 of size 1x512] - (26): Parameter containing: [torch.float32 of size 1x512] - (27): Parameter containing: [torch.float32 of size 1x512] - (28): Parameter containing: [torch.float32 of size 1x512] - (29): Parameter containing: [torch.float32 of size 1x512] - (30): Parameter containing: [torch.float32 of size 1x512] - (31): Parameter containing: [torch.float32 of size 1x512] - (32): Parameter containing: [torch.float32 of size 1x512] - (33): Parameter containing: [torch.float32 of size 1x512] - (34): Parameter containing: [torch.float32 of size 1x512] - (35): Parameter containing: [torch.float32 of size 1x512] - (36): Parameter containing: [torch.float32 of size 1x512] - (37): Parameter containing: [torch.float32 of size 1x512] - (38): Parameter containing: [torch.float32 of size 1x512] - (39): Parameter containing: [torch.float32 of size 1x512] - (40): Parameter containing: [torch.float32 of size 1x512] - (41): Parameter containing: [torch.float32 of size 1x512] - (42): Parameter containing: [torch.float32 of size 1x512] - (43): Parameter containing: [torch.float32 of size 1x512] - (44): Parameter containing: [torch.float32 of size 1x512] - ) - (day_layer_dropout): Dropout(p=0.2, inplace=False) - (gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4) - ) - (clean_speech_model): CleanSpeechModel( - (day_layer_activation): Softsign() - (day_weights): ParameterList( - (0): Parameter containing: [torch.float32 of size 512x512] - (1): Parameter containing: [torch.float32 of size 512x512] - (2): Parameter containing: [torch.float32 of size 512x512] - (3): Parameter containing: [torch.float32 of size 512x512] - (4): Parameter containing: [torch.float32 of size 512x512] - (5): Parameter containing: [torch.float32 of size 512x512] - (6): Parameter containing: [torch.float32 of size 512x512] - (7): Parameter containing: [torch.float32 of size 512x512] - (8): Parameter containing: [torch.float32 of size 512x512] - (9): Parameter containing: [torch.float32 of size 512x512] - (10): Parameter containing: [torch.float32 of size 512x512] - (11): Parameter containing: [torch.float32 of size 512x512] - (12): Parameter containing: [torch.float32 of size 512x512] - (13): Parameter containing: [torch.float32 of size 512x512] - (14): Parameter containing: [torch.float32 of size 512x512] - (15): Parameter containing: [torch.float32 of size 512x512] - (16): Parameter containing: [torch.float32 of size 512x512] - (17): Parameter containing: [torch.float32 of size 512x512] - (18): Parameter containing: [torch.float32 of size 512x512] - (19): Parameter containing: [torch.float32 of size 512x512] - (20): Parameter containing: [torch.float32 of size 512x512] - (21): Parameter containing: [torch.float32 of size 512x512] - (22): Parameter containing: [torch.float32 of size 512x512] - (23): Parameter containing: [torch.float32 of size 512x512] - (24): Parameter containing: [torch.float32 of size 512x512] - (25): Parameter containing: [torch.float32 of size 512x512] - (26): Parameter containing: [torch.float32 of size 512x512] - (27): Parameter containing: [torch.float32 of size 512x512] - (28): Parameter containing: [torch.float32 of size 512x512] - (29): Parameter containing: [torch.float32 of size 512x512] - (30): Parameter containing: [torch.float32 of size 512x512] - (31): Parameter containing: [torch.float32 of size 512x512] - (32): Parameter containing: [torch.float32 of size 512x512] - (33): Parameter containing: [torch.float32 of size 512x512] - (34): Parameter containing: [torch.float32 of size 512x512] - (35): Parameter containing: [torch.float32 of size 512x512] - (36): Parameter containing: [torch.float32 of size 512x512] - (37): Parameter containing: [torch.float32 of size 512x512] - (38): Parameter containing: [torch.float32 of size 512x512] - (39): Parameter containing: [torch.float32 of size 512x512] - (40): Parameter containing: [torch.float32 of size 512x512] - (41): Parameter containing: [torch.float32 of size 512x512] - (42): Parameter containing: [torch.float32 of size 512x512] - (43): Parameter containing: [torch.float32 of size 512x512] - (44): Parameter containing: [torch.float32 of size 512x512] - ) - (day_biases): ParameterList( - (0): Parameter containing: [torch.float32 of size 1x512] - (1): Parameter containing: [torch.float32 of size 1x512] - (2): Parameter containing: [torch.float32 of size 1x512] - (3): Parameter containing: [torch.float32 of size 1x512] - (4): Parameter containing: [torch.float32 of size 1x512] - (5): Parameter containing: [torch.float32 of size 1x512] - (6): Parameter containing: [torch.float32 of size 1x512] - (7): Parameter containing: [torch.float32 of size 1x512] - (8): Parameter containing: [torch.float32 of size 1x512] - (9): Parameter containing: [torch.float32 of size 1x512] - (10): Parameter containing: [torch.float32 of size 1x512] - (11): Parameter containing: [torch.float32 of size 1x512] - (12): Parameter containing: [torch.float32 of size 1x512] - (13): Parameter containing: [torch.float32 of size 1x512] - (14): Parameter containing: [torch.float32 of size 1x512] - (15): Parameter containing: [torch.float32 of size 1x512] - (16): Parameter containing: [torch.float32 of size 1x512] - (17): Parameter containing: [torch.float32 of size 1x512] - (18): Parameter containing: [torch.float32 of size 1x512] - (19): Parameter containing: [torch.float32 of size 1x512] - (20): Parameter containing: [torch.float32 of size 1x512] - (21): Parameter containing: [torch.float32 of size 1x512] - (22): Parameter containing: [torch.float32 of size 1x512] - (23): Parameter containing: [torch.float32 of size 1x512] - (24): Parameter containing: [torch.float32 of size 1x512] - (25): Parameter containing: [torch.float32 of size 1x512] - (26): Parameter containing: [torch.float32 of size 1x512] - (27): Parameter containing: [torch.float32 of size 1x512] - (28): Parameter containing: [torch.float32 of size 1x512] - (29): Parameter containing: [torch.float32 of size 1x512] - (30): Parameter containing: [torch.float32 of size 1x512] - (31): Parameter containing: [torch.float32 of size 1x512] - (32): Parameter containing: [torch.float32 of size 1x512] - (33): Parameter containing: [torch.float32 of size 1x512] - (34): Parameter containing: [torch.float32 of size 1x512] - (35): Parameter containing: [torch.float32 of size 1x512] - (36): Parameter containing: [torch.float32 of size 1x512] - (37): Parameter containing: [torch.float32 of size 1x512] - (38): Parameter containing: [torch.float32 of size 1x512] - (39): Parameter containing: [torch.float32 of size 1x512] - (40): Parameter containing: [torch.float32 of size 1x512] - (41): Parameter containing: [torch.float32 of size 1x512] - (42): Parameter containing: [torch.float32 of size 1x512] - (43): Parameter containing: [torch.float32 of size 1x512] - (44): Parameter containing: [torch.float32 of size 1x512] - ) - (day_layer_dropout): Dropout(p=0.2, inplace=False) - (gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4) - (out): Linear(in_features=768, out_features=41, bias=True) - ) - (noisy_speech_model): NoisySpeechModel( - (gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4) - (out): Linear(in_features=768, out_features=41, bias=True) - ) -) -2025-10-12 09:33:20,124: Model has 687,568,466 parameters -2025-10-12 09:33:20,124: Model has 23,639,040 day-specific parameters | 3.44% of total parameters diff --git a/trained_models/baseline_rnn/train_val_trials.json b/trained_models/baseline_rnn/train_val_trials.json new file mode 100644 index 0000000..4c7c183 --- /dev/null +++ b/trained_models/baseline_rnn/train_val_trials.json @@ -0,0 +1 @@ +{"train": {"0": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.11\\data_train.hdf5"}, "1": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.13\\data_train.hdf5"}, "2": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.18\\data_train.hdf5"}, "3": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.20\\data_train.hdf5"}, "4": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.25\\data_train.hdf5"}, "5": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.27\\data_train.hdf5"}, "6": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.01\\data_train.hdf5"}, "7": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.03\\data_train.hdf5"}, "8": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.24\\data_train.hdf5"}, "9": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.29\\data_train.hdf5"}, "10": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.01\\data_train.hdf5"}, "11": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.06\\data_train.hdf5"}, "12": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.08\\data_train.hdf5"}, "13": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.13\\data_train.hdf5"}, "14": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.15\\data_train.hdf5"}, "15": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.20\\data_train.hdf5"}, "16": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.22\\data_train.hdf5"}, "17": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.03\\data_train.hdf5"}, "18": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.04\\data_train.hdf5"}, "19": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.17\\data_train.hdf5"}, "20": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.19\\data_train.hdf5"}, "21": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.26\\data_train.hdf5"}, "22": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.03\\data_train.hdf5"}, "23": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.08\\data_train.hdf5"}, "24": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.10\\data_train.hdf5"}, "25": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.17\\data_train.hdf5"}, "26": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.29\\data_train.hdf5"}, "27": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.02.25\\data_train.hdf5"}, "28": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.03\\data_train.hdf5"}, "29": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.08\\data_train.hdf5"}, "30": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.15\\data_train.hdf5"}, "31": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.17\\data_train.hdf5"}, "32": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.04.25\\data_train.hdf5"}, "33": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.04.28\\data_train.hdf5"}, "34": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.05.10\\data_train.hdf5"}, "35": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.06.14\\data_train.hdf5"}, "36": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.19\\data_train.hdf5"}, "37": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.21\\data_train.hdf5"}, "38": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.28\\data_train.hdf5"}, "39": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.01.10\\data_train.hdf5"}, "40": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.01.12\\data_train.hdf5"}, "41": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.14\\data_train.hdf5"}, "42": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.16\\data_train.hdf5"}, "43": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.30\\data_train.hdf5"}, "44": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.04.13\\data_train.hdf5"}}, "val": {"0": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.11\\data_val.hdf5"}, "1": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.13\\data_val.hdf5"}, "2": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.18\\data_val.hdf5"}, "3": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.20\\data_val.hdf5"}, "4": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.25\\data_val.hdf5"}, "5": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.08.27\\data_val.hdf5"}, "6": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.01\\data_val.hdf5"}, "7": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.03\\data_val.hdf5"}, "8": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.24\\data_val.hdf5"}, "9": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.09.29\\data_val.hdf5"}, "10": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.01\\data_val.hdf5"}, "11": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.06\\data_val.hdf5"}, "12": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.08\\data_val.hdf5"}, "13": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.13\\data_val.hdf5"}, "14": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.15\\data_val.hdf5"}, "15": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.20\\data_val.hdf5"}, "16": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.10.22\\data_val.hdf5"}, "17": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.03\\data_val.hdf5"}, "18": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.04\\data_val.hdf5"}, "19": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.17\\data_val.hdf5"}, "20": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.19\\data_val.hdf5"}, "21": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.11.26\\data_val.hdf5"}, "22": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.03\\data_val.hdf5"}, "23": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.08\\data_val.hdf5"}, "24": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.10\\data_val.hdf5"}, "25": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.17\\data_val.hdf5"}, "26": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2023.12.29\\data_val.hdf5"}, "27": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.02.25\\data_val.hdf5"}, "28": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.03\\data_val.hdf5"}, "29": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.08\\data_val.hdf5"}, "30": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.15\\data_val.hdf5"}, "31": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.03.17\\data_val.hdf5"}, "32": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.04.25\\data_val.hdf5"}, "33": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.04.28\\data_val.hdf5"}, "34": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.05.10\\data_val.hdf5"}, "35": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.06.14\\data_val.hdf5"}, "36": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.19\\data_val.hdf5"}, "37": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.21\\data_val.hdf5"}, "38": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2024.07.28\\data_val.hdf5"}, "39": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.01.10\\data_val.hdf5"}, "40": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.01.12\\data_val.hdf5"}, "41": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.14\\data_val.hdf5"}, "42": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.16\\data_val.hdf5"}, "43": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.03.30\\data_val.hdf5"}, "44": {"trials": [], "session_path": "../data/hdf5_data_final\\t15.2025.04.13\\data_val.hdf5"}}} \ No newline at end of file diff --git a/trained_models/baseline_rnn/training_log b/trained_models/baseline_rnn/training_log new file mode 100644 index 0000000..c215d6f --- /dev/null +++ b/trained_models/baseline_rnn/training_log @@ -0,0 +1,436 @@ +2025-10-12 18:35:01,395: Using device: cpu +2025-10-12 18:35:01,395: Accelerator state: Distributed environment: NO +Num processes: 1 +Process index: 0 +Local process index: 0 +Device: cpu + +Mixed precision type: bf16 + +2025-10-12 18:35:38,176: torch.compile disabled for new TripleGRUDecoder compatibility +2025-10-12 18:35:38,176: Initialized RNN decoding model +2025-10-12 18:35:38,176: TripleGRUDecoder( + (noise_model): NoiseModel( + (day_layer_activation): Softsign() + (day_weights): ParameterList( + (0): Parameter containing: [torch.float32 of size 512x512] + (1): Parameter containing: [torch.float32 of size 512x512] + (2): Parameter containing: [torch.float32 of size 512x512] + (3): Parameter containing: [torch.float32 of size 512x512] + (4): Parameter containing: [torch.float32 of size 512x512] + (5): Parameter containing: [torch.float32 of size 512x512] + (6): Parameter containing: [torch.float32 of size 512x512] + (7): Parameter containing: [torch.float32 of size 512x512] + (8): Parameter containing: [torch.float32 of size 512x512] + (9): Parameter containing: [torch.float32 of size 512x512] + (10): Parameter containing: [torch.float32 of size 512x512] + (11): Parameter containing: [torch.float32 of size 512x512] + (12): Parameter containing: [torch.float32 of size 512x512] + (13): Parameter containing: [torch.float32 of size 512x512] + (14): Parameter containing: [torch.float32 of size 512x512] + (15): Parameter containing: [torch.float32 of size 512x512] + (16): Parameter containing: [torch.float32 of size 512x512] + (17): Parameter containing: [torch.float32 of size 512x512] + (18): Parameter containing: [torch.float32 of size 512x512] + (19): Parameter containing: [torch.float32 of size 512x512] + (20): Parameter containing: [torch.float32 of size 512x512] + (21): Parameter containing: [torch.float32 of size 512x512] + (22): Parameter containing: [torch.float32 of size 512x512] + (23): Parameter containing: [torch.float32 of size 512x512] + (24): Parameter containing: [torch.float32 of size 512x512] + (25): Parameter containing: [torch.float32 of size 512x512] + (26): Parameter containing: [torch.float32 of size 512x512] + (27): Parameter containing: [torch.float32 of size 512x512] + (28): Parameter containing: [torch.float32 of size 512x512] + (29): Parameter containing: [torch.float32 of size 512x512] + (30): Parameter containing: [torch.float32 of size 512x512] + (31): Parameter containing: [torch.float32 of size 512x512] + (32): Parameter containing: [torch.float32 of size 512x512] + (33): Parameter containing: [torch.float32 of size 512x512] + (34): Parameter containing: [torch.float32 of size 512x512] + (35): Parameter containing: [torch.float32 of size 512x512] + (36): Parameter containing: [torch.float32 of size 512x512] + (37): Parameter containing: [torch.float32 of size 512x512] + (38): Parameter containing: [torch.float32 of size 512x512] + (39): Parameter containing: [torch.float32 of size 512x512] + (40): Parameter containing: [torch.float32 of size 512x512] + (41): Parameter containing: [torch.float32 of size 512x512] + (42): Parameter containing: [torch.float32 of size 512x512] + (43): Parameter containing: [torch.float32 of size 512x512] + (44): Parameter containing: [torch.float32 of size 512x512] + ) + (day_biases): ParameterList( + (0): Parameter containing: [torch.float32 of size 1x512] + (1): Parameter containing: [torch.float32 of size 1x512] + (2): Parameter containing: [torch.float32 of size 1x512] + (3): Parameter containing: [torch.float32 of size 1x512] + (4): Parameter containing: [torch.float32 of size 1x512] + (5): Parameter containing: [torch.float32 of size 1x512] + (6): Parameter containing: [torch.float32 of size 1x512] + (7): Parameter containing: [torch.float32 of size 1x512] + (8): Parameter containing: [torch.float32 of size 1x512] + (9): Parameter containing: [torch.float32 of size 1x512] + (10): Parameter containing: [torch.float32 of size 1x512] + (11): Parameter containing: [torch.float32 of size 1x512] + (12): Parameter containing: [torch.float32 of size 1x512] + (13): Parameter containing: [torch.float32 of size 1x512] + (14): Parameter containing: [torch.float32 of size 1x512] + (15): Parameter containing: [torch.float32 of size 1x512] + (16): Parameter containing: [torch.float32 of size 1x512] + (17): Parameter containing: [torch.float32 of size 1x512] + (18): Parameter containing: [torch.float32 of size 1x512] + (19): Parameter containing: [torch.float32 of size 1x512] + (20): Parameter containing: [torch.float32 of size 1x512] + (21): Parameter containing: [torch.float32 of size 1x512] + (22): Parameter containing: [torch.float32 of size 1x512] + (23): Parameter containing: [torch.float32 of size 1x512] + (24): Parameter containing: [torch.float32 of size 1x512] + (25): Parameter containing: [torch.float32 of size 1x512] + (26): Parameter containing: [torch.float32 of size 1x512] + (27): Parameter containing: [torch.float32 of size 1x512] + (28): Parameter containing: [torch.float32 of size 1x512] + (29): Parameter containing: [torch.float32 of size 1x512] + (30): Parameter containing: [torch.float32 of size 1x512] + (31): Parameter containing: [torch.float32 of size 1x512] + (32): Parameter containing: [torch.float32 of size 1x512] + (33): Parameter containing: [torch.float32 of size 1x512] + (34): Parameter containing: [torch.float32 of size 1x512] + (35): Parameter containing: [torch.float32 of size 1x512] + (36): Parameter containing: [torch.float32 of size 1x512] + (37): Parameter containing: [torch.float32 of size 1x512] + (38): Parameter containing: [torch.float32 of size 1x512] + (39): Parameter containing: [torch.float32 of size 1x512] + (40): Parameter containing: [torch.float32 of size 1x512] + (41): Parameter containing: [torch.float32 of size 1x512] + (42): Parameter containing: [torch.float32 of size 1x512] + (43): Parameter containing: [torch.float32 of size 1x512] + (44): Parameter containing: [torch.float32 of size 1x512] + ) + (day_layer_dropout): Dropout(p=0.2, inplace=False) + (gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4) + ) + (clean_speech_model): CleanSpeechModel( + (day_layer_activation): Softsign() + (day_weights): ParameterList( + (0): Parameter containing: [torch.float32 of size 512x512] + (1): Parameter containing: [torch.float32 of size 512x512] + (2): Parameter containing: [torch.float32 of size 512x512] + (3): Parameter containing: [torch.float32 of size 512x512] + (4): Parameter containing: [torch.float32 of size 512x512] + (5): Parameter containing: [torch.float32 of size 512x512] + (6): Parameter containing: [torch.float32 of size 512x512] + (7): Parameter containing: [torch.float32 of size 512x512] + (8): Parameter containing: [torch.float32 of size 512x512] + (9): Parameter containing: [torch.float32 of size 512x512] + (10): Parameter containing: [torch.float32 of size 512x512] + (11): Parameter containing: [torch.float32 of size 512x512] + (12): Parameter containing: [torch.float32 of size 512x512] + (13): Parameter containing: [torch.float32 of size 512x512] + (14): Parameter containing: [torch.float32 of size 512x512] + (15): Parameter containing: [torch.float32 of size 512x512] + (16): Parameter containing: [torch.float32 of size 512x512] + (17): Parameter containing: [torch.float32 of size 512x512] + (18): Parameter containing: [torch.float32 of size 512x512] + (19): Parameter containing: [torch.float32 of size 512x512] + (20): Parameter containing: [torch.float32 of size 512x512] + (21): Parameter containing: [torch.float32 of size 512x512] + (22): Parameter containing: [torch.float32 of size 512x512] + (23): Parameter containing: [torch.float32 of size 512x512] + (24): Parameter containing: [torch.float32 of size 512x512] + (25): Parameter containing: [torch.float32 of size 512x512] + (26): Parameter containing: [torch.float32 of size 512x512] + (27): Parameter containing: [torch.float32 of size 512x512] + (28): Parameter containing: [torch.float32 of size 512x512] + (29): Parameter containing: [torch.float32 of size 512x512] + (30): Parameter containing: [torch.float32 of size 512x512] + (31): Parameter containing: [torch.float32 of size 512x512] + (32): Parameter containing: [torch.float32 of size 512x512] + (33): Parameter containing: [torch.float32 of size 512x512] + (34): Parameter containing: [torch.float32 of size 512x512] + (35): Parameter containing: [torch.float32 of size 512x512] + (36): Parameter containing: [torch.float32 of size 512x512] + (37): Parameter containing: [torch.float32 of size 512x512] + (38): Parameter containing: [torch.float32 of size 512x512] + (39): Parameter containing: [torch.float32 of size 512x512] + (40): Parameter containing: [torch.float32 of size 512x512] + (41): Parameter containing: [torch.float32 of size 512x512] + (42): Parameter containing: [torch.float32 of size 512x512] + (43): Parameter containing: [torch.float32 of size 512x512] + (44): Parameter containing: [torch.float32 of size 512x512] + ) + (day_biases): ParameterList( + (0): Parameter containing: [torch.float32 of size 1x512] + (1): Parameter containing: [torch.float32 of size 1x512] + (2): Parameter containing: [torch.float32 of size 1x512] + (3): Parameter containing: [torch.float32 of size 1x512] + (4): Parameter containing: [torch.float32 of size 1x512] + (5): Parameter containing: [torch.float32 of size 1x512] + (6): Parameter containing: [torch.float32 of size 1x512] + (7): Parameter containing: [torch.float32 of size 1x512] + (8): Parameter containing: [torch.float32 of size 1x512] + (9): Parameter containing: [torch.float32 of size 1x512] + (10): Parameter containing: [torch.float32 of size 1x512] + (11): Parameter containing: [torch.float32 of size 1x512] + (12): Parameter containing: [torch.float32 of size 1x512] + (13): Parameter containing: [torch.float32 of size 1x512] + (14): Parameter containing: [torch.float32 of size 1x512] + (15): Parameter containing: [torch.float32 of size 1x512] + (16): Parameter containing: [torch.float32 of size 1x512] + (17): Parameter containing: [torch.float32 of size 1x512] + (18): Parameter containing: [torch.float32 of size 1x512] + (19): Parameter containing: [torch.float32 of size 1x512] + (20): Parameter containing: [torch.float32 of size 1x512] + (21): Parameter containing: [torch.float32 of size 1x512] + (22): Parameter containing: [torch.float32 of size 1x512] + (23): Parameter containing: [torch.float32 of size 1x512] + (24): Parameter containing: [torch.float32 of size 1x512] + (25): Parameter containing: [torch.float32 of size 1x512] + (26): Parameter containing: [torch.float32 of size 1x512] + (27): Parameter containing: [torch.float32 of size 1x512] + (28): Parameter containing: [torch.float32 of size 1x512] + (29): Parameter containing: [torch.float32 of size 1x512] + (30): Parameter containing: [torch.float32 of size 1x512] + (31): Parameter containing: [torch.float32 of size 1x512] + (32): Parameter containing: [torch.float32 of size 1x512] + (33): Parameter containing: [torch.float32 of size 1x512] + (34): Parameter containing: [torch.float32 of size 1x512] + (35): Parameter containing: [torch.float32 of size 1x512] + (36): Parameter containing: [torch.float32 of size 1x512] + (37): Parameter containing: [torch.float32 of size 1x512] + (38): Parameter containing: [torch.float32 of size 1x512] + (39): Parameter containing: [torch.float32 of size 1x512] + (40): Parameter containing: [torch.float32 of size 1x512] + (41): Parameter containing: [torch.float32 of size 1x512] + (42): Parameter containing: [torch.float32 of size 1x512] + (43): Parameter containing: [torch.float32 of size 1x512] + (44): Parameter containing: [torch.float32 of size 1x512] + ) + (day_layer_dropout): Dropout(p=0.2, inplace=False) + (gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4) + (out): Linear(in_features=768, out_features=41, bias=True) + ) + (noisy_speech_model): NoisySpeechModel( + (gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4) + (out): Linear(in_features=768, out_features=41, bias=True) + ) +) +2025-10-12 18:35:38,190: Model has 687,568,466 parameters +2025-10-12 18:35:38,191: Model has 23,639,040 day-specific parameters | 3.44% of total parameters +2025-10-12 18:38:50,217: Using device: cpu +2025-10-12 18:38:50,217: Accelerator state: Distributed environment: NO +Num processes: 1 +Process index: 0 +Local process index: 0 +Device: cpu + +Mixed precision type: bf16 + +2025-10-12 18:39:25,946: torch.compile disabled for new TripleGRUDecoder compatibility +2025-10-12 18:39:25,946: Initialized RNN decoding model +2025-10-12 18:39:25,946: TripleGRUDecoder( + (noise_model): NoiseModel( + (day_layer_activation): Softsign() + (day_weights): ParameterList( + (0): Parameter containing: [torch.float32 of size 512x512] + (1): Parameter containing: [torch.float32 of size 512x512] + (2): Parameter containing: [torch.float32 of size 512x512] + (3): Parameter containing: [torch.float32 of size 512x512] + (4): Parameter containing: [torch.float32 of size 512x512] + (5): Parameter containing: [torch.float32 of size 512x512] + (6): Parameter containing: [torch.float32 of size 512x512] + (7): Parameter containing: [torch.float32 of size 512x512] + (8): Parameter containing: [torch.float32 of size 512x512] + (9): Parameter containing: [torch.float32 of size 512x512] + (10): Parameter containing: [torch.float32 of size 512x512] + (11): Parameter containing: [torch.float32 of size 512x512] + (12): Parameter containing: [torch.float32 of size 512x512] + (13): Parameter containing: [torch.float32 of size 512x512] + (14): Parameter containing: [torch.float32 of size 512x512] + (15): Parameter containing: [torch.float32 of size 512x512] + (16): Parameter containing: [torch.float32 of size 512x512] + (17): Parameter containing: [torch.float32 of size 512x512] + (18): Parameter containing: [torch.float32 of size 512x512] + (19): Parameter containing: [torch.float32 of size 512x512] + (20): Parameter containing: [torch.float32 of size 512x512] + (21): Parameter containing: [torch.float32 of size 512x512] + (22): Parameter containing: [torch.float32 of size 512x512] + (23): Parameter containing: [torch.float32 of size 512x512] + (24): Parameter containing: [torch.float32 of size 512x512] + (25): Parameter containing: [torch.float32 of size 512x512] + (26): Parameter containing: [torch.float32 of size 512x512] + (27): Parameter containing: [torch.float32 of size 512x512] + (28): Parameter containing: [torch.float32 of size 512x512] + (29): Parameter containing: [torch.float32 of size 512x512] + (30): Parameter containing: [torch.float32 of size 512x512] + (31): Parameter containing: [torch.float32 of size 512x512] + (32): Parameter containing: [torch.float32 of size 512x512] + (33): Parameter containing: [torch.float32 of size 512x512] + (34): Parameter containing: [torch.float32 of size 512x512] + (35): Parameter containing: [torch.float32 of size 512x512] + (36): Parameter containing: [torch.float32 of size 512x512] + (37): Parameter containing: [torch.float32 of size 512x512] + (38): Parameter containing: [torch.float32 of size 512x512] + (39): Parameter containing: [torch.float32 of size 512x512] + (40): Parameter containing: [torch.float32 of size 512x512] + (41): Parameter containing: [torch.float32 of size 512x512] + (42): Parameter containing: [torch.float32 of size 512x512] + (43): Parameter containing: [torch.float32 of size 512x512] + (44): Parameter containing: [torch.float32 of size 512x512] + ) + (day_biases): ParameterList( + (0): Parameter containing: [torch.float32 of size 1x512] + (1): Parameter containing: [torch.float32 of size 1x512] + (2): Parameter containing: [torch.float32 of size 1x512] + (3): Parameter containing: [torch.float32 of size 1x512] + (4): Parameter containing: [torch.float32 of size 1x512] + (5): Parameter containing: [torch.float32 of size 1x512] + (6): Parameter containing: [torch.float32 of size 1x512] + (7): Parameter containing: [torch.float32 of size 1x512] + (8): Parameter containing: [torch.float32 of size 1x512] + (9): Parameter containing: [torch.float32 of size 1x512] + (10): Parameter containing: [torch.float32 of size 1x512] + (11): Parameter containing: [torch.float32 of size 1x512] + (12): Parameter containing: [torch.float32 of size 1x512] + (13): Parameter containing: [torch.float32 of size 1x512] + (14): Parameter containing: [torch.float32 of size 1x512] + (15): Parameter containing: [torch.float32 of size 1x512] + (16): Parameter containing: [torch.float32 of size 1x512] + (17): Parameter containing: [torch.float32 of size 1x512] + (18): Parameter containing: [torch.float32 of size 1x512] + (19): Parameter containing: [torch.float32 of size 1x512] + (20): Parameter containing: [torch.float32 of size 1x512] + (21): Parameter containing: [torch.float32 of size 1x512] + (22): Parameter containing: [torch.float32 of size 1x512] + (23): Parameter containing: [torch.float32 of size 1x512] + (24): Parameter containing: [torch.float32 of size 1x512] + (25): Parameter containing: [torch.float32 of size 1x512] + (26): Parameter containing: [torch.float32 of size 1x512] + (27): Parameter containing: [torch.float32 of size 1x512] + (28): Parameter containing: [torch.float32 of size 1x512] + (29): Parameter containing: [torch.float32 of size 1x512] + (30): Parameter containing: [torch.float32 of size 1x512] + (31): Parameter containing: [torch.float32 of size 1x512] + (32): Parameter containing: [torch.float32 of size 1x512] + (33): Parameter containing: [torch.float32 of size 1x512] + (34): Parameter containing: [torch.float32 of size 1x512] + (35): Parameter containing: [torch.float32 of size 1x512] + (36): Parameter containing: [torch.float32 of size 1x512] + (37): Parameter containing: [torch.float32 of size 1x512] + (38): Parameter containing: [torch.float32 of size 1x512] + (39): Parameter containing: [torch.float32 of size 1x512] + (40): Parameter containing: [torch.float32 of size 1x512] + (41): Parameter containing: [torch.float32 of size 1x512] + (42): Parameter containing: [torch.float32 of size 1x512] + (43): Parameter containing: [torch.float32 of size 1x512] + (44): Parameter containing: [torch.float32 of size 1x512] + ) + (day_layer_dropout): Dropout(p=0.2, inplace=False) + (gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4) + ) + (clean_speech_model): CleanSpeechModel( + (day_layer_activation): Softsign() + (day_weights): ParameterList( + (0): Parameter containing: [torch.float32 of size 512x512] + (1): Parameter containing: [torch.float32 of size 512x512] + (2): Parameter containing: [torch.float32 of size 512x512] + (3): Parameter containing: [torch.float32 of size 512x512] + (4): Parameter containing: [torch.float32 of size 512x512] + (5): Parameter containing: [torch.float32 of size 512x512] + (6): Parameter containing: [torch.float32 of size 512x512] + (7): Parameter containing: [torch.float32 of size 512x512] + (8): Parameter containing: [torch.float32 of size 512x512] + (9): Parameter containing: [torch.float32 of size 512x512] + (10): Parameter containing: [torch.float32 of size 512x512] + (11): Parameter containing: [torch.float32 of size 512x512] + (12): Parameter containing: [torch.float32 of size 512x512] + (13): Parameter containing: [torch.float32 of size 512x512] + (14): Parameter containing: [torch.float32 of size 512x512] + (15): Parameter containing: [torch.float32 of size 512x512] + (16): Parameter containing: [torch.float32 of size 512x512] + (17): Parameter containing: [torch.float32 of size 512x512] + (18): Parameter containing: [torch.float32 of size 512x512] + (19): Parameter containing: [torch.float32 of size 512x512] + (20): Parameter containing: [torch.float32 of size 512x512] + (21): Parameter containing: [torch.float32 of size 512x512] + (22): Parameter containing: [torch.float32 of size 512x512] + (23): Parameter containing: [torch.float32 of size 512x512] + (24): Parameter containing: [torch.float32 of size 512x512] + (25): Parameter containing: [torch.float32 of size 512x512] + (26): Parameter containing: [torch.float32 of size 512x512] + (27): Parameter containing: [torch.float32 of size 512x512] + (28): Parameter containing: [torch.float32 of size 512x512] + (29): Parameter containing: [torch.float32 of size 512x512] + (30): Parameter containing: [torch.float32 of size 512x512] + (31): Parameter containing: [torch.float32 of size 512x512] + (32): Parameter containing: [torch.float32 of size 512x512] + (33): Parameter containing: [torch.float32 of size 512x512] + (34): Parameter containing: [torch.float32 of size 512x512] + (35): Parameter containing: [torch.float32 of size 512x512] + (36): Parameter containing: [torch.float32 of size 512x512] + (37): Parameter containing: [torch.float32 of size 512x512] + (38): Parameter containing: [torch.float32 of size 512x512] + (39): Parameter containing: [torch.float32 of size 512x512] + (40): Parameter containing: [torch.float32 of size 512x512] + (41): Parameter containing: [torch.float32 of size 512x512] + (42): Parameter containing: [torch.float32 of size 512x512] + (43): Parameter containing: [torch.float32 of size 512x512] + (44): Parameter containing: [torch.float32 of size 512x512] + ) + (day_biases): ParameterList( + (0): Parameter containing: [torch.float32 of size 1x512] + (1): Parameter containing: [torch.float32 of size 1x512] + (2): Parameter containing: [torch.float32 of size 1x512] + (3): Parameter containing: [torch.float32 of size 1x512] + (4): Parameter containing: [torch.float32 of size 1x512] + (5): Parameter containing: [torch.float32 of size 1x512] + (6): Parameter containing: [torch.float32 of size 1x512] + (7): Parameter containing: [torch.float32 of size 1x512] + (8): Parameter containing: [torch.float32 of size 1x512] + (9): Parameter containing: [torch.float32 of size 1x512] + (10): Parameter containing: [torch.float32 of size 1x512] + (11): Parameter containing: [torch.float32 of size 1x512] + (12): Parameter containing: [torch.float32 of size 1x512] + (13): Parameter containing: [torch.float32 of size 1x512] + (14): Parameter containing: [torch.float32 of size 1x512] + (15): Parameter containing: [torch.float32 of size 1x512] + (16): Parameter containing: [torch.float32 of size 1x512] + (17): Parameter containing: [torch.float32 of size 1x512] + (18): Parameter containing: [torch.float32 of size 1x512] + (19): Parameter containing: [torch.float32 of size 1x512] + (20): Parameter containing: [torch.float32 of size 1x512] + (21): Parameter containing: [torch.float32 of size 1x512] + (22): Parameter containing: [torch.float32 of size 1x512] + (23): Parameter containing: [torch.float32 of size 1x512] + (24): Parameter containing: [torch.float32 of size 1x512] + (25): Parameter containing: [torch.float32 of size 1x512] + (26): Parameter containing: [torch.float32 of size 1x512] + (27): Parameter containing: [torch.float32 of size 1x512] + (28): Parameter containing: [torch.float32 of size 1x512] + (29): Parameter containing: [torch.float32 of size 1x512] + (30): Parameter containing: [torch.float32 of size 1x512] + (31): Parameter containing: [torch.float32 of size 1x512] + (32): Parameter containing: [torch.float32 of size 1x512] + (33): Parameter containing: [torch.float32 of size 1x512] + (34): Parameter containing: [torch.float32 of size 1x512] + (35): Parameter containing: [torch.float32 of size 1x512] + (36): Parameter containing: [torch.float32 of size 1x512] + (37): Parameter containing: [torch.float32 of size 1x512] + (38): Parameter containing: [torch.float32 of size 1x512] + (39): Parameter containing: [torch.float32 of size 1x512] + (40): Parameter containing: [torch.float32 of size 1x512] + (41): Parameter containing: [torch.float32 of size 1x512] + (42): Parameter containing: [torch.float32 of size 1x512] + (43): Parameter containing: [torch.float32 of size 1x512] + (44): Parameter containing: [torch.float32 of size 1x512] + ) + (day_layer_dropout): Dropout(p=0.2, inplace=False) + (gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4) + (out): Linear(in_features=768, out_features=41, bias=True) + ) + (noisy_speech_model): NoisySpeechModel( + (gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4) + (out): Linear(in_features=768, out_features=41, bias=True) + ) +) +2025-10-12 18:39:25,958: Model has 687,568,466 parameters +2025-10-12 18:39:25,958: Model has 23,639,040 day-specific parameters | 3.44% of total parameters