This commit is contained in:
Zchen
2025-10-12 21:36:33 +08:00
parent 580648c058
commit 6e1d8e18f7
2 changed files with 76 additions and 9 deletions

View File

@@ -74,10 +74,65 @@ self.accelerator = Accelerator(
- TPU v5e-8: 8 cores × 16GB = 128GB total
- Fewer cores = LESS total memory (not more per core)
## Latest Status (2025-10-12)
### After DataLoaderConfiguration Fix
**even_batches Error RESOLVED** - No more `ValueError: You need to use 'even_batches=False'`
**NEW ERROR**: `TypeError: 'NoneType' object is not iterable`
```
File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
for idx, batch in enumerate(self.batch_sampler):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable
```
**Root Cause**: `batch_sampler` becomes `None` when our DataLoader has `batch_size=None`
### Current Investigation
- The issue is in Accelerate's data_loader.py line 221
- Our custom dataset returns full batches, so we use `batch_size=None` in DataLoader
- But Accelerate expects a proper batch_sampler when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
## FINAL SOLUTION ✅
### Problem Resolution
1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
### Final Implementation
```python
# In rnn_trainer.py prepare_dataloaders()
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = shuffle_setting,
num_workers = workers_setting,
pin_memory = True,
collate_fn = collate_fn
)
```
**Key Insight**: Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data.
## Next Steps
1. Implement correct even_batches=False in accelerator.prepare()
2. Test TPU training without overengineering
3. Verify memory usage with 8 cores configuration
1. ~~Implement even_batches=False~~ ✅ DONE
2. ~~Fix batch_sampler None issue~~ ✅ DONE
3. Test TPU training with complete solution
4. Integrate final solution into CLAUDE.md
## Lessons Learned
- Don't overcomplicate TPU conversion - it should be straightforward

View File

@@ -189,13 +189,24 @@ class BrainToTextDecoder_Trainer:
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Standard DataLoader configuration - let Accelerator handle device-specific optimizations
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True
pin_memory = True,
collate_fn = collate_fn
)
# val dataset and dataloader
@@ -209,13 +220,14 @@ class BrainToTextDecoder_Trainer:
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Standard validation DataLoader configuration
# Validation DataLoader with same collate function
self.val_loader = DataLoader(
self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True
pin_memory = True,
collate_fn = collate_fn # Use same collate function
)
self.logger.info("Successfully initialized datasets")