tpu支持

This commit is contained in:
Zchen
2025-10-12 18:41:26 +08:00
parent 1a906d3248
commit 40e4d00576
5 changed files with 456 additions and 226 deletions

View File

@@ -90,12 +90,12 @@ class BrainToTextDataset(Dataset):
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self):
'''
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length,
however this method is required for DataLoader to work
'''
return self.n_batches if self.n_batches is not None else 0
def __getitem__(self, idx):
'''
@@ -269,7 +269,9 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
# Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = []

View File

@@ -72,11 +72,11 @@ class BrainToTextDecoder_Trainer:
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=False)
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=False)
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging
self.logger = logging.getLogger(__name__)
@@ -188,12 +188,16 @@ class BrainToTextDecoder_Trainer:
# Use TPU-optimized dataloader settings if TPU is enabled
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
# For TPU environments or when batch_size=None causes issues, use batch_size=1
# since our dataset already returns complete batches
batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
self.train_loader = DataLoader(
self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = num_workers,
pin_memory = True
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
)
# val dataset and dataloader
@@ -209,10 +213,10 @@ class BrainToTextDecoder_Trainer:
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
)
self.logger.info("Successfully initialized datasets")

View File

@@ -1,213 +0,0 @@
2025-10-12 09:32:47,330: Using device: cpu
2025-10-12 09:33:20,119: torch.compile disabled for new TripleGRUDecoder compatibility
2025-10-12 09:33:20,120: Initialized RNN decoding model
2025-10-12 09:33:20,120: Model type: <class 'rnn_model.TripleGRUDecoder'>
2025-10-12 09:33:20,120: Model is callable: True
2025-10-12 09:33:20,120: Model has forward method: True
2025-10-12 09:33:20,120: TripleGRUDecoder(
(noise_model): NoiseModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 7168, num_layers=2, batch_first=True, dropout=0.4)
)
(clean_speech_model): CleanSpeechModel(
(day_layer_activation): Softsign()
(day_weights): ParameterList(
(0): Parameter containing: [torch.float32 of size 512x512]
(1): Parameter containing: [torch.float32 of size 512x512]
(2): Parameter containing: [torch.float32 of size 512x512]
(3): Parameter containing: [torch.float32 of size 512x512]
(4): Parameter containing: [torch.float32 of size 512x512]
(5): Parameter containing: [torch.float32 of size 512x512]
(6): Parameter containing: [torch.float32 of size 512x512]
(7): Parameter containing: [torch.float32 of size 512x512]
(8): Parameter containing: [torch.float32 of size 512x512]
(9): Parameter containing: [torch.float32 of size 512x512]
(10): Parameter containing: [torch.float32 of size 512x512]
(11): Parameter containing: [torch.float32 of size 512x512]
(12): Parameter containing: [torch.float32 of size 512x512]
(13): Parameter containing: [torch.float32 of size 512x512]
(14): Parameter containing: [torch.float32 of size 512x512]
(15): Parameter containing: [torch.float32 of size 512x512]
(16): Parameter containing: [torch.float32 of size 512x512]
(17): Parameter containing: [torch.float32 of size 512x512]
(18): Parameter containing: [torch.float32 of size 512x512]
(19): Parameter containing: [torch.float32 of size 512x512]
(20): Parameter containing: [torch.float32 of size 512x512]
(21): Parameter containing: [torch.float32 of size 512x512]
(22): Parameter containing: [torch.float32 of size 512x512]
(23): Parameter containing: [torch.float32 of size 512x512]
(24): Parameter containing: [torch.float32 of size 512x512]
(25): Parameter containing: [torch.float32 of size 512x512]
(26): Parameter containing: [torch.float32 of size 512x512]
(27): Parameter containing: [torch.float32 of size 512x512]
(28): Parameter containing: [torch.float32 of size 512x512]
(29): Parameter containing: [torch.float32 of size 512x512]
(30): Parameter containing: [torch.float32 of size 512x512]
(31): Parameter containing: [torch.float32 of size 512x512]
(32): Parameter containing: [torch.float32 of size 512x512]
(33): Parameter containing: [torch.float32 of size 512x512]
(34): Parameter containing: [torch.float32 of size 512x512]
(35): Parameter containing: [torch.float32 of size 512x512]
(36): Parameter containing: [torch.float32 of size 512x512]
(37): Parameter containing: [torch.float32 of size 512x512]
(38): Parameter containing: [torch.float32 of size 512x512]
(39): Parameter containing: [torch.float32 of size 512x512]
(40): Parameter containing: [torch.float32 of size 512x512]
(41): Parameter containing: [torch.float32 of size 512x512]
(42): Parameter containing: [torch.float32 of size 512x512]
(43): Parameter containing: [torch.float32 of size 512x512]
(44): Parameter containing: [torch.float32 of size 512x512]
)
(day_biases): ParameterList(
(0): Parameter containing: [torch.float32 of size 1x512]
(1): Parameter containing: [torch.float32 of size 1x512]
(2): Parameter containing: [torch.float32 of size 1x512]
(3): Parameter containing: [torch.float32 of size 1x512]
(4): Parameter containing: [torch.float32 of size 1x512]
(5): Parameter containing: [torch.float32 of size 1x512]
(6): Parameter containing: [torch.float32 of size 1x512]
(7): Parameter containing: [torch.float32 of size 1x512]
(8): Parameter containing: [torch.float32 of size 1x512]
(9): Parameter containing: [torch.float32 of size 1x512]
(10): Parameter containing: [torch.float32 of size 1x512]
(11): Parameter containing: [torch.float32 of size 1x512]
(12): Parameter containing: [torch.float32 of size 1x512]
(13): Parameter containing: [torch.float32 of size 1x512]
(14): Parameter containing: [torch.float32 of size 1x512]
(15): Parameter containing: [torch.float32 of size 1x512]
(16): Parameter containing: [torch.float32 of size 1x512]
(17): Parameter containing: [torch.float32 of size 1x512]
(18): Parameter containing: [torch.float32 of size 1x512]
(19): Parameter containing: [torch.float32 of size 1x512]
(20): Parameter containing: [torch.float32 of size 1x512]
(21): Parameter containing: [torch.float32 of size 1x512]
(22): Parameter containing: [torch.float32 of size 1x512]
(23): Parameter containing: [torch.float32 of size 1x512]
(24): Parameter containing: [torch.float32 of size 1x512]
(25): Parameter containing: [torch.float32 of size 1x512]
(26): Parameter containing: [torch.float32 of size 1x512]
(27): Parameter containing: [torch.float32 of size 1x512]
(28): Parameter containing: [torch.float32 of size 1x512]
(29): Parameter containing: [torch.float32 of size 1x512]
(30): Parameter containing: [torch.float32 of size 1x512]
(31): Parameter containing: [torch.float32 of size 1x512]
(32): Parameter containing: [torch.float32 of size 1x512]
(33): Parameter containing: [torch.float32 of size 1x512]
(34): Parameter containing: [torch.float32 of size 1x512]
(35): Parameter containing: [torch.float32 of size 1x512]
(36): Parameter containing: [torch.float32 of size 1x512]
(37): Parameter containing: [torch.float32 of size 1x512]
(38): Parameter containing: [torch.float32 of size 1x512]
(39): Parameter containing: [torch.float32 of size 1x512]
(40): Parameter containing: [torch.float32 of size 1x512]
(41): Parameter containing: [torch.float32 of size 1x512]
(42): Parameter containing: [torch.float32 of size 1x512]
(43): Parameter containing: [torch.float32 of size 1x512]
(44): Parameter containing: [torch.float32 of size 1x512]
)
(day_layer_dropout): Dropout(p=0.2, inplace=False)
(gru): GRU(7168, 768, num_layers=3, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
(noisy_speech_model): NoisySpeechModel(
(gru): GRU(7168, 768, num_layers=2, batch_first=True, dropout=0.4)
(out): Linear(in_features=768, out_features=41, bias=True)
)
)
2025-10-12 09:33:20,124: Model has 687,568,466 parameters
2025-10-12 09:33:20,124: Model has 23,639,040 day-specific parameters | 3.44% of total parameters