diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index fd69cfa..201cabc 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -193,7 +193,6 @@ class BrainToTextDecoder_Trainer:
         if use_tpu:
             # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
             # TPU requires specific DataLoader configuration to avoid batch_sampler issues
-            from torch.utils.data import DataLoader
             self.train_loader = DataLoader(
                 self.train_dataset,
                 batch_size = None, # None because our Dataset returns batches