diff --git a/TPU_ISSUES_RECORD.md b/TPU_ISSUES_RECORD.md
index ab09f80..f0db3a5 100644
--- a/TPU_ISSUES_RECORD.md
+++ b/TPU_ISSUES_RECORD.md
@@ -74,10 +74,65 @@ self.accelerator = Accelerator(
 - TPU v5e-8: 8 cores × 16GB = 128GB total
 - Fewer cores = LESS total memory (not more per core)
 
+## Latest Status (2025-10-12)
+
+### After DataLoaderConfiguration Fix
+✅ **even_batches Error RESOLVED** - No more `ValueError: You need to use 'even_batches=False'`
+
+❌ **NEW ERROR**: `TypeError: 'NoneType' object is not iterable`
+```
+File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
+    for idx, batch in enumerate(self.batch_sampler):
+                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+TypeError: 'NoneType' object is not iterable
+```
+
+**Root Cause**: `batch_sampler` becomes `None` when our DataLoader has `batch_size=None`
+
+### Current Investigation
+- The issue is in Accelerate's data_loader.py line 221
+- Our custom dataset returns full batches, so we use `batch_size=None` in DataLoader
+- But Accelerate expects a proper batch_sampler when iterating
+- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
+
+## FINAL SOLUTION ✅
+
+### Problem Resolution
+1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
+2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
+
+### Final Implementation
+```python
+# In rnn_trainer.py prepare_dataloaders()
+
+# Custom collate function that handles pre-batched data from our dataset
+def collate_fn(batch):
+    # Our dataset returns full batches, so batch will be a list of single batch dict
+    # Extract the first (and only) element since our dataset.__getitem__() returns a full batch
+    if len(batch) == 1 and isinstance(batch[0], dict):
+        return batch[0]
+    else:
+        # Fallback for unexpected batch structure
+        return batch
+
+# DataLoader configuration compatible with Accelerate
+self.train_loader = DataLoader(
+    self.train_dataset,
+    batch_size = 1, # Use batch_size=1 since dataset returns full batches
+    shuffle = shuffle_setting,
+    num_workers = workers_setting,
+    pin_memory = True,
+    collate_fn = collate_fn
+)
+```
+
+**Key Insight**: Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data.
+
 ## Next Steps
-1. Implement correct even_batches=False in accelerator.prepare()
-2. Test TPU training without overengineering
-3. Verify memory usage with 8 cores configuration
+1. ~~Implement even_batches=False~~ ✅ DONE
+2. ~~Fix batch_sampler None issue~~ ✅ DONE
+3. Test TPU training with complete solution
+4. Integrate final solution into CLAUDE.md
 
 ## Lessons Learned
 - Don't overcomplicate TPU conversion - it should be straightforward
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 5db7a28..8d4c77c 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -189,13 +189,24 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        # Standard DataLoader configuration - let Accelerator handle device-specific optimizations
+        # Custom collate function that handles pre-batched data from our dataset
+        def collate_fn(batch):
+            # Our dataset returns full batches, so batch will be a list of single batch dict
+            # Extract the first (and only) element since our dataset.__getitem__() returns a full batch
+            if len(batch) == 1 and isinstance(batch[0], dict):
+                return batch[0]
+            else:
+                # Fallback for unexpected batch structure
+                return batch
+
+        # DataLoader configuration compatible with Accelerate
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = 1, # Use batch_size=1 since dataset returns full batches
             shuffle = self.args['dataset']['loader_shuffle'],
             num_workers = self.args['dataset']['num_dataloader_workers'],
-            pin_memory = True
+            pin_memory = True,
+            collate_fn = collate_fn
         )
 
         # val dataset and dataloader
@@ -209,13 +220,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        # Standard validation DataLoader configuration
+        # Validation DataLoader with same collate function
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = 1, # Use batch_size=1 since dataset returns full batches
             shuffle = False,
             num_workers = 0, # Keep validation dataloader single-threaded for consistency
-            pin_memory = True
+            pin_memory = True,
+            collate_fn = collate_fn # Use same collate function
         )
 
         self.logger.info("Successfully initialized datasets")
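
Note (appended after the diff, not part of it): the status section above says the `even_batches` error was "RESOLVED with DataLoaderConfiguration", but the accelerator change itself is not shown in this hunk. Below is a minimal sketch of what that fix presumably looks like, assuming an accelerate version (>= 0.28) where `Accelerator` accepts a `dataloader_config` argument; the exact arguments used in `rnn_trainer.py` may differ.

```python
# Hedged sketch of the DataLoaderConfiguration fix referenced in the status
# section above; not copied from this repo.
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# even_batches=False is the setting the earlier ValueError explicitly asked for.
dataloader_config = DataLoaderConfiguration(even_batches=False)
accelerator = Accelerator(dataloader_config=dataloader_config)
```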
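Also appended for reference: a minimal, self-contained repro of the "Root Cause" claim above. With `batch_size=None`, PyTorch disables automatic batching and leaves `DataLoader.batch_sampler` as `None`, which is what Accelerate then tries to iterate. The toy `PreBatchedDataset` below is a hypothetical stand-in for the project's real dataset, not code from this repo.

```python
# Toy demonstration only; PreBatchedDataset is a made-up stand-in for a
# dataset whose __getitem__() already returns a full batch dict.
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Returns an entire batch, mirroring the project's batching approach
        return {"features": torch.zeros(8, 16), "labels": torch.zeros(8, dtype=torch.long)}

# batch_size=None disables automatic batching -> batch_sampler is None,
# which Accelerate's _iter_with_no_split then fails to enumerate.
broken = DataLoader(PreBatchedDataset(), batch_size=None)
print(broken.batch_sampler)  # None

# The workaround from the diff: batch_size=1 plus a collate_fn that unwraps
# the single pre-batched dict, so Accelerate sees a real BatchSampler.
def collate_fn(batch):
    return batch[0] if len(batch) == 1 and isinstance(batch[0], dict) else batch

fixed = DataLoader(PreBatchedDataset(), batch_size=1, collate_fn=collate_fn)
print(type(fixed.batch_sampler).__name__)  # BatchSampler
next(iter(fixed))  # yields the full batch dict, same shape the trainer expects
```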