# TPU Training Issues Record

## Core Problem

**Primary Error**: `ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size`

This error occurs when using TPU with the Hugging Face Accelerate framework and custom DataLoaders that have `batch_size=None`.
## Root Cause Analysis

1. Our custom dataset returns full batches (not individual samples)
2. The DataLoader is created with `batch_size=None` because batching is handled by the dataset (see the sketch after this list)
3. TPU training with Accelerate requires `even_batches=False` for this configuration
4. The `even_batches` parameter must be supplied through a `DataLoaderConfiguration`, not as a direct keyword argument to `Accelerator.__init__()`
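For reference, a minimal, self-contained sketch of the configuration that triggers the error (toy dataset and shapes, not the project's actual classes):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Illustrative stand-in: __getitem__ already returns a complete batch dict."""
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {
            "input_features": torch.zeros(64, 512),        # (batch, features)
            "labels": torch.zeros(64, dtype=torch.long),
        }

# Batching already happened inside the dataset, so automatic batching is disabled.
train_loader = DataLoader(PreBatchedDataset(), batch_size=None, shuffle=True, num_workers=0)

# On TPU, passing this loader through accelerator.prepare() raises:
# ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size
```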
## Failed Solution Attempts

### Attempt 1: Adding even_batches to Accelerator.__init__()

```python
self.accelerator = Accelerator(
    mixed_precision='bf16',
    gradient_accumulation_steps=1,
    even_batches=False  # ❌ WRONG - this parameter doesn't exist in Accelerator.__init__()
)
```

**Error**: `TypeError: Accelerator.__init__() got an unexpected keyword argument 'even_batches'`
### Attempt 2: Complex TPU-specific DataLoader handling

- Created conditional TPU/GPU logic
- Manual data movement with `to(device)`
- Custom collate_fn modifications
- Result: an over-engineered solution that did not address the root cause
### Attempt 3: Memory optimization

- Reduced TPU cores from 8 to 2
- Reduced batch size
- Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)
### Attempt 4: Removing all TPU-specific logic

- Let Accelerator handle everything automatically
- Result: the same even_batches error returned
## Correct Solution

The `even_batches=False` parameter should be passed via `DataLoaderConfiguration` when initializing the Accelerator:

```python
from accelerate import Accelerator, DataLoaderConfiguration

# Configure DataLoader behavior for TPU
dataloader_config = DataLoaderConfiguration(
    even_batches=False  # Required for batch_size=None DataLoaders
)

self.accelerator = Accelerator(
    mixed_precision='bf16' if args.get('use_amp', True) else 'no',
    gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    log_with=None,
    project_dir=args.get('output_dir', './output'),
    dataloader_config=dataloader_config  # ✅ CORRECT - pass a DataLoaderConfiguration
)
```
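For context, a minimal, self-contained sketch (toy model and data, not the project's classes) of where the configuration actually takes effect, namely when the loader goes through `accelerator.prepare()`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, DataLoaderConfiguration

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(even_batches=False)
)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=None)

# prepare() wraps the DataLoader according to dataloader_config;
# this is where even_batches=False is applied.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
```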
## Technical Context

- **Model**: Brain-to-text RNN with 687M parameters
- **Dataset**: Custom dataset that returns full batches (`batch_size=None` in the DataLoader)
- **TPU Config**: 8 cores × 16GB = 128GB total memory
- **Batch Size**: 64
- **Framework**: PyTorch XLA with Hugging Face Accelerate
## Key Files Modified

- `model_training_nnn/rnn_trainer.py` - Main trainer class
- `model_training_nnn/rnn_args.yaml` - Configuration file
- `model_training_nnn/dataset.py` - Custom dataset class
## Memory Allocation Facts

- TPU v5e-8: 8 cores × 16GB = 128GB total
- Fewer cores = LESS total memory (not more per core)
## Latest Status (2025-10-12)

### After DataLoaderConfiguration Fix

✅ **even_batches Error RESOLVED** - No more `ValueError: You need to use 'even_batches=False'`

❌ **NEW ERROR**: `TypeError: 'NoneType' object is not iterable`

```
  File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
    for idx, batch in enumerate(self.batch_sampler):
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable
```

**Root Cause**: `batch_sampler` becomes `None` when our DataLoader has `batch_size=None`
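This behavior is easy to confirm outside the trainer; a tiny sketch (toy dataset, not the project's):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# With batch_size=None, PyTorch disables automatic batching and
# the DataLoader carries no batch_sampler at all.
loader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=None)
print(loader.batch_sampler)  # None -> exactly what Accelerate tries to enumerate in _iter_with_no_split
```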
### Current Investigation

- The failure is raised from Accelerate's `data_loader.py`, line 221
- Our custom dataset returns full batches, so we use `batch_size=None` in the DataLoader
- But Accelerate expects a proper batch_sampler when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
## COMPREHENSIVE SOLUTION ✅ (v2.0)

### Problem Resolution Status

1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - fixed both input conversion and padding dtype preservation
### Latest Error (2025-10-12 13:38)

```
INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].
```

**Root Cause**: Mixed precision training with `mixed_precision='bf16'` expects all tensors to be `bf16`, but our data is being loaded as `f32` (float32).

**Analysis**:

- We enabled `bf16` mixed precision in the Accelerator configuration
- Model parameters are automatically converted to `bf16`
- But input data remains `f32`, causing a type mismatch during the forward pass
- The TPU XLA compiler is strict about type matching
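A quick way to catch this class of error before XLA compilation is to audit the batch dtypes just before the forward pass. A minimal sketch (the helper name `audit_batch_dtypes` and the batch layout are assumptions, not project code):

```python
import torch

def audit_batch_dtypes(batch: dict) -> None:
    """Print the dtype of every floating-point tensor in a batch dict."""
    for key, value in batch.items():
        if torch.is_tensor(value) and value.is_floating_point():
            # Under bf16 mixed precision on TPU, these should all be torch.bfloat16.
            print(f"{key}: {value.dtype}")

# Example: a float32 tensor here would explain the f32-vs-bf16 XLA complaint.
audit_batch_dtypes({"input_features": torch.zeros(64, 7168, dtype=torch.bfloat16)})
```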
### Solution: Comprehensive Data Type Conversion in Dataset

Fixed in `dataset.py` with two changes:

**1. Convert input data to bf16 (line 130):**
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:])  # defaults to f32

# After (TPU compatible):
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16)  # convert to bf16 for TPU compatibility
```

**2. Preserve bf16 dtype after padding (line 149):**
```python
# Before (padded batch came back as f32 in our pipeline):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0)

# After (explicitly maintain bf16):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first=True, padding_value=0).to(torch.bfloat16)
```
**Root Cause**: In our pipeline, the padded batch came back as `f32` even though the individual input tensors were `bf16`, so the dtype has to be re-asserted explicitly after `pad_sequence`.
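A tiny check (assumed toy shapes, not the project's data) that the post-padding cast keeps the batch in `bf16`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.zeros(3, 5, dtype=torch.bfloat16), torch.zeros(7, 5, dtype=torch.bfloat16)]
padded = pad_sequence(seqs, batch_first=True, padding_value=0).to(torch.bfloat16)
assert padded.dtype == torch.bfloat16  # guard against any silent upcast to f32
```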
### Final Implementation

```python
# In rnn_trainer.py prepare_dataloaders()

# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
    # Our dataset returns full batches, so batch will be a list containing a single batch dict.
    # Extract the first (and only) element since our dataset.__getitem__() returns a full batch.
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    else:
        # Fallback for unexpected batch structure
        return batch

# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
    self.train_dataset,
    batch_size=1,  # Use batch_size=1 since the dataset returns full batches
    shuffle=shuffle_setting,
    num_workers=workers_setting,
    pin_memory=True,
    collate_fn=collate_fn
)
```

**Key Insight**: Our dataset's `__getitem__()` returns complete batches, but Accelerate expects individual samples. The solution is to use `batch_size=1` and a custom `collate_fn` that unwraps the pre-batched data.
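To illustrate the unwrapping behavior end to end, a self-contained sketch with a stand-in dataset (not the project's classes):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Stand-in dataset whose __getitem__ already returns a full batch dict."""
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return {"input_features": torch.zeros(64, 512, dtype=torch.bfloat16)}

def collate_fn(batch):
    # Unwrap the single pre-batched dict produced by batch_size=1.
    return batch[0] if len(batch) == 1 and isinstance(batch[0], dict) else batch

loader = DataLoader(PreBatchedDataset(), batch_size=1, collate_fn=collate_fn)
first = next(iter(loader))
print(type(first), first["input_features"].shape)  # <class 'dict'> torch.Size([64, 512])
```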
## Complete Solution Summary

### Three-Step Fix for TPU Training

1. **DataLoaderConfiguration**: Pass `even_batches=False` for DataLoaders whose batching is handled by the dataset
2. **Custom collate_fn**: Handles pre-batched data from our dataset
3. **Data Type Conversion**: Convert input data to `bf16` for mixed precision compatibility

### Files Modified

- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
### Next Steps

1. ~~Implement even_batches=False~~ ✅ DONE
2. ~~Fix batch_sampler None issue~~ ✅ DONE
3. ~~Fix data type mismatch~~ ✅ DONE
4. Test TPU training with the complete solution
5. Integrate the final solution into CLAUDE.md
## Lessons Learned

- Don't overcomplicate the TPU conversion - it should be straightforward
- Read the Accelerate documentation carefully for parameter placement
- Document issues immediately to avoid confusion
- TPU memory allocation: fewer cores = less total memory