tpu
This commit is contained in:
@@ -189,13 +189,24 @@ class BrainToTextDecoder_Trainer:
|
||||
random_seed = self.args['dataset']['seed'],
|
||||
feature_subset = feature_subset
|
||||
)
|
||||
# Standard DataLoader configuration - let Accelerator handle device-specific optimizations
|
||||
# Custom collate function that handles pre-batched data from our dataset
|
||||
def collate_fn(batch):
|
||||
# Our dataset returns full batches, so batch will be a list of single batch dict
|
||||
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
|
||||
if len(batch) == 1 and isinstance(batch[0], dict):
|
||||
return batch[0]
|
||||
else:
|
||||
# Fallback for unexpected batch structure
|
||||
return batch
|
||||
|
||||
# DataLoader configuration compatible with Accelerate
|
||||
self.train_loader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||
shuffle = self.args['dataset']['loader_shuffle'],
|
||||
num_workers = self.args['dataset']['num_dataloader_workers'],
|
||||
pin_memory = True
|
||||
pin_memory = True,
|
||||
collate_fn = collate_fn
|
||||
)
|
||||
|
||||
# val dataset and dataloader
|
||||
@@ -209,13 +220,14 @@ class BrainToTextDecoder_Trainer:
|
||||
random_seed = self.args['dataset']['seed'],
|
||||
feature_subset = feature_subset
|
||||
)
|
||||
# Standard validation DataLoader configuration
|
||||
# Validation DataLoader with same collate function
|
||||
self.val_loader = DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
batch_size = 1, # Use batch_size=1 since dataset returns full batches
|
||||
shuffle = False,
|
||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||
pin_memory = True
|
||||
pin_memory = True,
|
||||
collate_fn = collate_fn # Use same collate function
|
||||
)
|
||||
|
||||
self.logger.info("Successfully initialized datasets")
|
||||
|
Reference in New Issue
Block a user