TPU Training Issues Record

Core Problem

Primary Error: ValueError: You need to use 'even_batches=False' when the batch sampler has no batch size

This error occurs when training on TPU with the Hugging Face Accelerate framework and a custom DataLoader that uses batch_size=None.

Root Cause Analysis

  1. Our custom dataset returns full batches (not individual samples)
  2. DataLoader is created with batch_size=None because batching is handled by the dataset (see the sketch after this list)
  3. TPU training with Accelerate requires even_batches=False for this configuration
  4. The even_batches parameter needs to be set in the DataLoader preparation, not Accelerator initialization
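
For context, a minimal sketch of this batching pattern (class and variable names are illustrative, not the actual dataset.py code):

from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    """Illustrative stand-in for our dataset: __getitem__ returns a full batch dict."""
    def __init__(self, batches):
        self.batches = batches  # list of pre-assembled batch dicts

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        # Returns an entire batch, e.g. {'input_features': (B, T, F) tensor, ...}
        return self.batches[idx]

# batch_size=None disables automatic batching; each dataset item passes through as-is
loader = DataLoader(PreBatchedDataset(batches=[]), batch_size=None)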

Failed Solution Attempts

Attempt 1: Adding even_batches to Accelerator.init()

self.accelerator = Accelerator(
    mixed_precision='bf16',
    gradient_accumulation_steps=1,
    even_batches=False  # ❌ WRONG - This parameter doesn't exist in Accelerator.__init__()
)

Error: TypeError: Accelerator.__init__() got an unexpected keyword argument 'even_batches'

Attempt 2: Complex TPU-specific DataLoader handling

  • Created conditional TPU/GPU logic
  • Manual data movement with to(device)
  • Custom collate_fn modifications
  • Result: Overengineered solution that didn't address root cause

Attempt 3: Memory optimization

  • Reduced TPU cores from 8 to 2
  • Reduced batch size
  • Misunderstood TPU memory allocation (fewer cores = less total memory, not more per core)

Attempt 4: Removing all TPU-specific logic

  • Let Accelerator handle everything automatically
  • Result: Same even_batches error returned

Correct Solution

The even_batches=False parameter should be passed using DataLoaderConfiguration when initializing the Accelerator:

from accelerate import Accelerator, DataLoaderConfiguration

# Configure DataLoader behavior for TPU
dataloader_config = DataLoaderConfiguration(
    even_batches=False  # Required for batch_size=None DataLoaders
)

self.accelerator = Accelerator(
    mixed_precision='bf16' if args.get('use_amp', True) else 'no',
    gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    log_with=None,
    project_dir=args.get('output_dir', './output'),
    dataloader_config=dataloader_config  # ✅ CORRECT - Pass DataLoaderConfiguration
)

Technical Context

  • Model: Brain-to-text RNN with 687M parameters
  • Dataset: Custom dataset that returns full batches (batch_size=None in DataLoader)
  • TPU Config: 8 cores × 16GB = 128GB total memory
  • Batch Size: 64
  • Framework: PyTorch XLA with Hugging Face Accelerate

Key Files Modified

  • model_training_nnn/rnn_trainer.py - Main trainer class
  • model_training_nnn/rnn_args.yaml - Configuration file
  • model_training_nnn/dataset.py - Custom dataset class

Memory Allocation Facts

  • TPU v5e-8: 8 cores × 16GB = 128GB total
  • Fewer cores = LESS total memory (not more per core)

Latest Status (2025-10-12)

After DataLoaderConfiguration Fix

even_batches Error RESOLVED - No more ValueError: You need to use 'even_batches=False'

NEW ERROR: TypeError: 'NoneType' object is not iterable

File "/usr/local/lib/python3.12/site-packages/accelerate/data_loader.py", line 221, in _iter_with_no_split
    for idx, batch in enumerate(self.batch_sampler):
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object is not iterable

Root Cause: batch_sampler becomes None when our DataLoader has batch_size=None

Current Investigation

  • The issue is in Accelerate's data_loader.py line 221
  • Our custom dataset returns full batches, so we use batch_size=None in DataLoader
  • But Accelerate expects a proper batch_sampler when iterating
  • This is a fundamental incompatibility between our batching approach and Accelerate's expectations (demonstrated in the snippet below)
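
The incompatibility can be reproduced in isolation (standalone snippet, names illustrative): with automatic batching disabled, the DataLoader's batch_sampler attribute is None, which is exactly what Accelerate's _iter_with_no_split tries to enumerate at data_loader.py line 221.

import torch
from torch.utils.data import Dataset, DataLoader

class TinyPreBatchedDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.zeros(2, 3)  # pretend each item is already a full batch

loader = DataLoader(TinyPreBatchedDataset(), batch_size=None)  # automatic batching disabled
print(loader.batch_sampler)  # None -> `enumerate(self.batch_sampler)` in Accelerate fails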

COMPREHENSIVE SOLUTION (v2.0)

Problem Resolution Status

  1. even_batches Error RESOLVED with DataLoaderConfiguration
  2. batch_sampler None Error RESOLVED with custom collate_fn
  3. Data Type Mismatch Error RESOLVED - Fixed both input conversion and padding dtype preservation

Latest Error (2025-10-12 13:38)

INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[64,7168], argument shape: bf16[64,7168].

Root Cause: Mixed precision training with mixed_precision='bf16' expects all tensors to be bf16, but our data is being loaded as f32 (float32).

Analysis:

  • We enabled bf16 mixed precision in Accelerator configuration
  • Model parameters are automatically converted to bf16
  • But input data remains as f32, causing type mismatch during forward pass
  • TPU XLA compiler is strict about type matching (a quick dtype check is sketched below)
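
A small diagnostic sketch (hypothetical helper, not part of the trainer) that makes this visible before the XLA compile error fires; it just prints the dtype of every floating-point tensor in a batch so a stray f32 stands out:

import torch

def report_batch_dtypes(batch):
    """Print the dtype of every floating-point tensor in a batch dict."""
    for name, value in batch.items():
        if torch.is_tensor(value) and value.is_floating_point():
            print(f"{name}: {value.dtype}")  # anything still torch.float32 will clash with bf16

# Illustrative batch with one tensor accidentally left as f32
report_batch_dtypes({
    'input_features': torch.zeros(64, 7168, dtype=torch.float32),  # offender
    'some_other_feature': torch.zeros(64, 1, dtype=torch.bfloat16),
})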

Solution: Comprehensive Data Type Conversion in Dataset

Fixed in dataset.py with two changes:

1. Convert input data to bf16 (line 130):

# Before (causes type mismatch):
input_features = torch.from_numpy(g['input_features'][:]) # inherits f32 from the stored array

# After (TPU compatible):
input_features = torch.from_numpy(g['input_features'][:]).to(torch.bfloat16) # convert to bf16 for TPU compatibility

2. Preserve bf16 dtype after padding (line 149):

# Before (pad_sequence converts back to f32):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)

# After (explicitly maintain bf16):
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0).to(torch.bfloat16)

Root Cause: in our pipeline the padded batch ended up as f32 even when the input tensors were bf16 (pad_sequence itself normally preserves the input dtype), so the dtype is re-asserted explicitly after padding.

Final Implementation

# In rnn_trainer.py prepare_dataloaders()

# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
    # Our dataset returns full batches, so batch will be a list of single batch dict
    # Extract the first (and only) element since our dataset.__getitem__() returns a full batch
    if len(batch) == 1 and isinstance(batch[0], dict):
        return batch[0]
    else:
        # Fallback for unexpected batch structure
        return batch

# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
    self.train_dataset,
    batch_size = 1,  # Use batch_size=1 since dataset returns full batches
    shuffle = shuffle_setting,
    num_workers = workers_setting,
    pin_memory = True,
    collate_fn = collate_fn
)

Key Insight: Our dataset's __getitem__() returns complete batches, but Accelerate expects individual samples. The solution is to use batch_size=1 and a custom collate_fn that unwraps the pre-batched data.
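
For completeness, a condensed sketch of how the prepared loader would then be consumed inside the trainer (assuming the usual accelerator.prepare call; the loss computation is schematic, not the actual RNN/CTC code):

# Sketch only: wiring the pieces together (names assumed, loss computation schematic)
model, optimizer, train_loader = self.accelerator.prepare(
    self.model, self.optimizer, self.train_loader
)

for step, batch in enumerate(train_loader):
    # collate_fn has already unwrapped the pre-batched dict, so `batch` is a plain dict
    with self.accelerator.accumulate(model):
        loss = compute_loss(model, batch)      # placeholder for the trainer's forward/loss step
        self.accelerator.backward(loss)        # XLA-aware backward pass
        optimizer.step()
        optimizer.zero_grad()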

Complete Solution Summary

Three-Step Fix for TPU Training

  1. DataLoaderConfiguration: Added even_batches=False for batch_size=1 DataLoaders
  2. Custom collate_fn: Handles pre-batched data from our dataset
  3. Data Type Conversion: Convert input data to bf16 for mixed precision compatibility

Files Modified

  • model_training_nnn/rnn_trainer.py - Accelerator/DataLoaderConfiguration setup, custom collate_fn, DataLoader with batch_size=1
  • model_training_nnn/dataset.py - bf16 conversion of input features and dtype preservation after padding

Next Steps

  1. Implement even_batches=False DONE
  2. Fix batch_sampler None issue DONE
  3. Fix data type mismatch DONE
  4. Test TPU training with complete solution
  5. Integrate final solution into CLAUDE.md

Lessons Learned

  • Don't overcomplicate TPU conversion - it should be straightforward
  • Read Accelerate documentation carefully for parameter placement
  • Document issues immediately to avoid confusion
  • TPU memory allocation: fewer cores = less total memory