diff --git a/TPU_ISSUES_RECORD.md b/TPU_ISSUES_RECORD.md
new file mode 100644
index 0000000..ab09f80
--- /dev/null
+++ b/TPU_ISSUES_RECORD.md
@@ -0,0 +1,86 @@
+# TPU Training Issues Record
+
+## Core Problem
+**Primary Error**: `ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size`
+
+This error occurs when training on TPU with the Hugging Face Accelerate framework and custom DataLoaders that use `batch_size=None`.
+
+## Root Cause Analysis
+1. Our custom dataset returns full batches (not individual samples)
+2. The DataLoader is created with `batch_size=None` because batching is handled by the dataset (see the sketch below)
+3. TPU training with Accelerate requires `even_batches=False` for this configuration
+4. The `even_batches` flag is not a direct `Accelerator.__init__()` argument; it must be supplied through a `DataLoaderConfiguration` object
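+
+To make the problematic configuration concrete, here is a minimal sketch of the pattern described above. The class name, tensor shapes, and sizes are placeholders for illustration, not the actual code in `dataset.py`:
+
+```python
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+class PreBatchedDataset(Dataset):
+    """Illustrative placeholder: every item is already a full batch."""
+    def __len__(self):
+        return 100  # number of pre-built batches, not individual samples
+
+    def __getitem__(self, idx):
+        # Placeholder shapes: (batch, time, features) and (batch, target_len)
+        return torch.randn(64, 200, 512), torch.randint(0, 40, (64, 50))
+
+# batch_size=None disables automatic batching, so there is no batch size for
+# Accelerate to shard evenly across TPU cores -- this is what raises the
+# even_batches error unless even_batches=False is configured.
+train_loader = DataLoader(PreBatchedDataset(), batch_size=None)
+```
+
+This is why the Correct Solution below passes `even_batches=False` through `DataLoaderConfiguration` at Accelerator initialization rather than reworking the DataLoader itself.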
+
+## Failed Solution Attempts
+
+### Attempt 1: Adding even_batches to Accelerator.__init__()
+```python
+self.accelerator = Accelerator(
+    mixed_precision='bf16',
+    gradient_accumulation_steps=1,
+    even_batches=False  # ❌ WRONG - This parameter doesn't exist in Accelerator.__init__()
+)
+```
+**Error**: `TypeError: Accelerator.__init__() got an unexpected keyword argument 'even_batches'`
+
+### Attempt 2: Complex TPU-specific DataLoader handling
+- Created conditional TPU/GPU logic
+- Manual data movement with `to(device)`
+- Custom collate_fn modifications
+- Result: Overengineered solution that didn't address the root cause
+
+### Attempt 3: Memory optimization
+- Reduced TPU cores from 8 to 2
+- Reduced batch size
+- Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)
+
+### Attempt 4: Removing all TPU-specific logic
+- Let Accelerator handle everything automatically
+- Result: Same even_batches error returned
+
+## Correct Solution
+The `even_batches=False` setting should be passed via `DataLoaderConfiguration` when initializing the Accelerator:
+
+```python
+from accelerate import Accelerator, DataLoaderConfiguration
+
+# Configure DataLoader behavior for TPU
+dataloader_config = DataLoaderConfiguration(
+    even_batches=False  # Required for batch_size=None DataLoaders
+)
+
+self.accelerator = Accelerator(
+    mixed_precision='bf16' if args.get('use_amp', True) else 'no',
+    gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
+    log_with=None,
+    project_dir=args.get('output_dir', './output'),
+    dataloader_config=dataloader_config  # ✅ CORRECT - Pass DataLoaderConfiguration
+)
+```
+
+## Technical Context
+- **Model**: Brain-to-text RNN with 687M parameters
+- **Dataset**: Custom dataset that returns full batches (`batch_size=None` in the DataLoader)
+- **TPU Config**: 8 cores × 16GB = 128GB total memory
+- **Batch Size**: 64
+- **Framework**: PyTorch XLA with Hugging Face Accelerate
+
+## Key Files Modified
+- `model_training_nnn/rnn_trainer.py` - Main trainer class
+- `model_training_nnn/rnn_args.yaml` - Configuration file
+- `model_training_nnn/dataset.py` - Custom dataset class
+
+## Memory Allocation Facts
+- TPU v5e-8: 8 cores × 16GB = 128GB total
+- Fewer cores = LESS total memory (not more per core)
+
+## Next Steps
+1. Implement the even_batches=False fix via DataLoaderConfiguration at Accelerator initialization
+2. Test TPU training without overengineering
+3. Verify memory usage with the 8-core configuration
+
+## Lessons Learned
+- Don't overcomplicate TPU conversion - it should be straightforward
+- Read the Accelerate documentation carefully for parameter placement
+- Document issues immediately to avoid confusion
+- TPU memory allocation: fewer cores = less total memory
\ No newline at end of file
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 9a99998..5db7a28 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -19,7 +19,7 @@ import torchaudio.functional as F # for edit distance
 from omegaconf import OmegaConf
 
 # Import Accelerate for TPU support
-from accelerate import Accelerator
+from accelerate import Accelerator, DataLoaderConfiguration
 from accelerate.utils import set_seed
 
 torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
@@ -40,12 +40,18 @@ class BrainToTextDecoder_Trainer:
             args : dictionary of training arguments
         '''
+        # Configure DataLoader behavior for TPU compatibility
+        dataloader_config = DataLoaderConfiguration(
+            even_batches=False  # Required for batch_size=None DataLoaders on TPU
+        )
+
         # Initialize Accelerator for TPU/multi-device support
         self.accelerator = Accelerator(
            mixed_precision='bf16' if args.get('use_amp', True) else 'no',
            gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
            log_with=None,  # We'll use our own logging
            project_dir=args.get('output_dir', './output'),
+            dataloader_config=dataloader_config,
         )