# TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

## Prerequisites

### 1. Install PyTorch XLA for TPU Support

```bash
# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or for a specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install Accelerate Library

```bash
pip install accelerate
```

### 3. Verify TPU Access

```bash
# Check if a TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```

## Configuration Setup

### 1. Enable TPU in Configuration File

Update your `rnn_args.yaml` file with TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                    # Enable TPU training
num_tpu_cores: 8                 # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1   # Gradient accumulation for a larger effective batch size
dataloader_num_workers: 0        # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true                    # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                  # Per-core batch size (total = 8 cores × 8 = 64)
```

### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005                    # May need to scale with the number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8                  # Per-core batch size
  days_per_batch: 4              # Keep consistent across cores
```

## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check the TPU environment only
python launch_tpu_training.py --check_only

# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```

### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use the provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```

### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```

## Key TPU Features Implemented

### 1. Distributed Training Support
- Data-parallel training with the model replicated across all 8 TPU cores
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training

### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization
- Faster training with maintained numerical stability
- Reduced memory usage

### 3. TPU-Optimized Data Loading
- Single-threaded data loading (num_workers=0) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing

### 4. Inference Support
- TPU-compatible inference methods added to the trainer class
- `inference()` and `inference_batch()` methods for production use (see the sketch after this list)
- Automatic mixed precision during inference
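For context, the snippet below is a minimal sketch of what single-core TPU inference looks like with PyTorch/XLA. The `TinyRNN` stand-in, its dimensions, and the random input batch are illustrative assumptions rather than the repository's actual model; in this codebase the trainer's `inference()` and `inference_batch()` methods are the intended entry points.

```python
# Minimal single-core TPU inference sketch with PyTorch/XLA.
# TinyRNN and the random input are stand-ins, not the repository's model;
# use the trainer's inference()/inference_batch() methods in practice.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class TinyRNN(nn.Module):
    """Stand-in model: GRU over neural features -> per-timestep logits."""

    def __init__(self, n_channels=512, n_hidden=256, n_classes=41):
        super().__init__()
        self.gru = nn.GRU(n_channels, n_hidden, batch_first=True)
        self.out = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        h, _ = self.gru(x)
        return self.out(h)


device = xm.xla_device()                    # first available TPU core
model = TinyRNN().to(device).eval()         # in practice, load your trained checkpoint here

# Placeholder batch: (batch, time steps, feature channels)
neural_features = torch.randn(2, 500, 512, device=device)

with torch.no_grad():
    logits = model(neural_features)
    xm.mark_step()                          # flush the lazily built XLA graph to the TPU

print(logits.shape)                         # torch.Size([2, 500, 41])
```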
## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with a total batch size of 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor host CPU and memory with `top`; use Google Cloud monitoring for TPU utilization

### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps

### 3. Learning Rate Scaling
- Consider scaling the learning rate with the number of cores
- Linear scaling: `lr_new = lr_base × num_cores`
- May need warmup adjustment for large-batch training

### 4. Memory Management
- TPU v3-8: 128GB HBM total (16GB per core)
- TPU v4-8: 128GB HBM total (32GB per chip)
- Monitor memory usage to avoid OOM errors

## Monitoring and Debugging

### 1. TPU Utilization

```bash
# Periodically check that the XLA runtime is reachable and report its world size
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```

### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training

### 3. Common Issues and Solutions

**Issue**: "No TPU devices found"
- **Solution**: Verify the TPU runtime is started and accessible

**Issue**: "DataLoader workers > 0 causes hangs"
- **Solution**: Set `dataloader_num_workers: 0` in the config

**Issue**: "Mixed precision errors"
- **Solution**: Ensure `use_amp: true` and that your PyTorch XLA version supports bfloat16

**Issue**: "Gradient synchronization timeouts"
- **Solution**: Check network connectivity between TPU cores

## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update the config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```

## Configuration Reference

### Required TPU Settings

```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```

### Optional TPU Optimizations

```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8      # Per-core batch size
mixed_precision: bf16
```

This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with distributed training managed automatically through the Accelerate library.
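For reference, the sketch below illustrates the kind of Accelerate-managed training loop this setup relies on. The GRU model, optimizer settings, and random dataset are toy placeholders; the real integration lives in `train_model.py` and may differ in its details.

```python
# Sketch of an Accelerate-managed TPU training loop (toy model and data;
# the actual integration in train_model.py may differ in its details).
import torch
import torch.nn as nn
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(mixed_precision="bf16",       # bfloat16 on TPU
                          gradient_accumulation_steps=1)

model = nn.GRU(input_size=512, hidden_size=256, batch_first=True)  # toy stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)

# Toy dataset: 64 random sequences of 100 timesteps x 512 channels
dataset = TensorDataset(torch.randn(64, 100, 512), torch.randn(64, 100, 256))
loader = DataLoader(dataset, batch_size=8, num_workers=0)           # per-core batch size

# prepare() moves everything to the XLA device and shards the loader per core
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for features, targets in loader:
    optimizer.zero_grad()
    outputs, _ = model(features)
    loss = nn.functional.mse_loss(outputs, targets)   # placeholder loss
    accelerator.backward(loss)                        # use instead of loss.backward()
    optimizer.step()

# Save from the main process only to avoid write races
if accelerator.is_main_process:
    accelerator.save(accelerator.unwrap_model(model).state_dict(), "toy_checkpoint.pt")
```

Launched via `accelerate launch --config_file accelerate_config_tpu.yaml`, the same script runs once per TPU core, with data sharding and gradient synchronization handled by the library.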