commit 4dad570eea (parent dfb3f7312c)
Author: Zchen
Date: 2025-10-12 21:47:30 +08:00

2 changed files with 20 additions and 6 deletions


@@ -95,12 +95,12 @@ TypeError: 'NoneType' object is not iterable
- But Accelerate expects a proper batch_sampler when iterating
- This is a fundamental incompatibility between our batching approach and Accelerate's expectations
-## COMPREHENSIVE SOLUTION ✅
+## COMPREHENSIVE SOLUTION ✅ (v2.0)
### Problem Resolution Status
1. ~~even_batches Error~~ ✅ RESOLVED with DataLoaderConfiguration
2. ~~batch_sampler None Error~~ ✅ RESOLVED with custom collate_fn
-3. ~~Data Type Mismatch Error~~ ✅ RESOLVED with bf16 conversion in dataset
+3. ~~Data Type Mismatch Error~~ ✅ RESOLVED - Fixed both input conversion and padding dtype preservation
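For orientation, a minimal sketch of how fixes 1 and 2 plausibly fit together. This assumes accelerate's `DataLoaderConfiguration` API; the pre-batched dataset is taken from the notes below, and `train_dataset` is a placeholder:

```python
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Fix 1: disable even_batches (Accelerate otherwise tries to equalize
# batch counts across processes, which clashed with this batching scheme).
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(even_batches=False)
)

# Fix 2: the dataset yields one full batch per __getitem__, so the
# DataLoader uses batch_size=1 and collate_fn simply unwraps the element.
def collate_fn(items):
    return items[0]  # illustrative; assumes pre-batched dataset items

# Usage sketch (train_dataset is a placeholder):
# loader = DataLoader(train_dataset, batch_size=1, collate_fn=collate_fn)
# loader = accelerator.prepare(loader)
```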
### Latest Error (2025-10-12 13:38)
```
@@ -115,8 +115,10 @@ INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32
- But input data remains as `f32`, causing type mismatch during forward pass
- TPU XLA compiler is strict about type matching
-### Solution: Data Type Conversion in Dataset
-Fixed in `dataset.py:130` by converting neural data to `bf16`:
+### Solution: Comprehensive Data Type Conversion in Dataset
+Fixed in `dataset.py` with two changes:
+**1. Convert input data to bf16 (line 130):**
```python
# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
@@ -125,6 +127,17 @@ input_features = torch.from_numpy(g['input_features'][:]) # defaults to f32
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility
```
+**2. Preserve bf16 dtype after padding (line 149):**
+```python
+# Before (pad_sequence converts back to f32):
+batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
+# After (explicitly maintain bf16):
+batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
+```
+**Root Cause**: in this setup, `pad_sequence` was observed to reset the dtype to the default (f32) even when the input tensors are bf16, so the cast is re-applied after padding.
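A self-contained check (not from the commit) to verify the padding dtype behavior on a given stack; the trailing cast is a cheap no-op whenever the dtype is already bf16:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [
    torch.zeros(3, 4, dtype=torch.bfloat16),
    torch.zeros(5, 4, dtype=torch.bfloat16),
]
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded.dtype)  # confirm whether bf16 survives padding here
padded = padded.to(torch.bfloat16)  # safe guard either way
```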
### Final Implementation
```python
# In rnn_trainer.py prepare_dataloaders()
@@ -163,6 +176,7 @@ self.train_loader = DataLoader(
- [rnn_trainer.py:44-46](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L44-L46): Added DataLoaderConfiguration
- [rnn_trainer.py:193-210](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\rnn_trainer.py#L193-L210): Custom collate_fn and batch_size=1
- [dataset.py:130](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L130): Convert neural data to bf16
+- [dataset.py:149](f:\BRAIN-TO-TEXT\nejm-brain-to-text.worktrees\dev2\model_training_nnn\dataset.py#L149): Preserve bf16 dtype after padding
### Next Steps
1. ~~Implement even_batches=False~~ ✅ DONE

dataset.py

@@ -145,8 +145,8 @@ class BrainToTextDataset(Dataset):
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
continue
-# Pad data to form a cohesive batch
-batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
+# Pad data to form a cohesive batch - ensure bf16 dtype is preserved
+batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
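As a defensive addition (hypothetical, not in this commit), a dtype sanity check on the collated batch would catch any silent f32 reintroduction before the batch reaches the TPU; field names follow the diff above:

```python
import torch

def assert_batch_dtypes(batch):
    # Catch silent f32 reintroduction early; field names follow the diff above.
    assert batch['input_features'].dtype == torch.bfloat16, \
        f"expected bf16 input_features, got {batch['input_features'].dtype}"
    assert batch['seq_class_ids'].dtype in (torch.int32, torch.int64)
    assert batch['n_time_steps'].dtype in (torch.int32, torch.int64)
```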