This commit is contained in:
Zchen
2025-10-12 21:36:33 +08:00
parent 580648c058
commit 6e1d8e18f7
2 changed files with 76 additions and 9 deletions

View File

@@ -189,13 +189,24 @@ class BrainToTextDecoder_Trainer:
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Standard DataLoader configuration - let Accelerator handle device-specific optimizations
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True
pin_memory = True,
collate_fn = collate_fn
)
# val dataset and dataloader
@@ -209,13 +220,14 @@ class BrainToTextDecoder_Trainer:
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Standard validation DataLoader configuration
# Validation DataLoader with same collate function
self.val_loader = DataLoader(
self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True
pin_memory = True,
collate_fn = collate_fn # Use same collate function
)
self.logger.info("Successfully initialized datasets")