# TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

## Prerequisites

### 1. Install PyTorch XLA for TPU Support
```bash
# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or for a specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install Accelerate Library
```bash
pip install accelerate
```

### 3. Verify TPU Access
```bash
# Check if TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```
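
If you want a slightly stronger check than the one-liner above, the short script below runs a small matrix multiply on the TPU and reports the device and world size. It is a sketch using standard `torch_xla` calls, not a file from this repo:

```python
# verify_tpu.py -- standalone sanity check (illustrative, not part of the repo)
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                      # first TPU core visible to this process
x = torch.randn(128, 128, device=device)
y = (x @ x).sum()
xm.mark_step()                                # force XLA to compile and run the graph
print(f"device={device}  result={y.item():.2f}  world_size={xm.xrt_world_size()}")
```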

## Configuration Setup

### 1. Enable TPU in Configuration File

Update your `rnn_args.yaml` file with TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                        # Enable TPU training
num_tpu_cores: 8                     # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1       # Gradient accumulation for large effective batch size
dataloader_num_workers: 0            # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true                        # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                      # Per-core batch size (total = 8 cores × 8 = 64)
```
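
For orientation, the sketch below shows how these fields typically map onto an Accelerate setup. It is an illustration, not the repo's actual trainer code: the tiny linear model, dataset, and variable names are placeholders, while the `Accelerator` arguments correspond directly to `use_amp` and `gradient_accumulation_steps` above.

```python
# Sketch: wiring the TPU-related config fields into Accelerate.
# The model/dataset here are placeholders, not the brain-to-text RNN.
import torch
import yaml
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

with open("rnn_args.yaml") as f:
    args = yaml.safe_load(f)

accelerator = Accelerator(
    mixed_precision="bf16" if args["use_amp"] else "no",
    gradient_accumulation_steps=args["gradient_accumulation_steps"],
)

model = torch.nn.Linear(512, 41)                          # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=args["lr_max"])
loader = DataLoader(
    TensorDataset(torch.randn(256, 512)),                 # placeholder data
    batch_size=args["dataset"]["batch_size"],             # per-core batch size
    num_workers=args["dataloader_num_workers"],           # must stay 0 on TPU
)

# prepare() moves everything to the XLA device; each core then sees its own
# shard of every batch, giving the 8 cores × 8 = 64 effective batch described above
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
print(accelerator.device, accelerator.num_processes)
```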

### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005                       # May need to scale with number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8                     # Per-core batch size
  days_per_batch: 4                 # Keep consistent across cores
```

## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check TPU environment only
python launch_tpu_training.py --check_only

# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```

### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use the provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```

### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```

## Key TPU Features Implemented

### 1. Distributed Training Support
- Automatic data-parallel training across 8 TPU cores (the model is replicated on each core)
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training (see the sketch below)
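
The checkpointing pattern referenced above usually looks like the following under Accelerate. This is a sketch with illustrative function names (`save_checkpoint`, `load_checkpoint`), not the exact code in the trainer class:

```python
# Sketch of distributed-safe checkpointing with Accelerate.
import torch

def save_checkpoint(accelerator, model, path):
    accelerator.wait_for_everyone()              # let all 8 cores finish the step
    unwrapped = accelerator.unwrap_model(model)  # strip the distributed wrapper
    if accelerator.is_main_process:              # write the file exactly once
        accelerator.save(unwrapped.state_dict(), path)

def load_checkpoint(accelerator, model, path):
    state = torch.load(path, map_location="cpu")
    accelerator.unwrap_model(model).load_state_dict(state)
```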

### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization (see the training-step sketch below)
- Faster training while maintaining numerical stability
- Reduced memory usage
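
A single training step under this setup typically combines `accelerator.autocast()` (bfloat16 forward pass) with `accelerator.accumulate()` and `accelerator.backward()`. The sketch below reuses the placeholder objects from the configuration sketch earlier; the squared-output loss is illustrative, not the real training objective:

```python
# Sketch of one mixed-precision training step with gradient accumulation.
def train_one_epoch(accelerator, model, optimizer, loader):
    model.train()
    for (batch,) in loader:
        # accumulate() defers the optimizer step until enough micro-batches
        # have been processed (gradient_accumulation_steps)
        with accelerator.accumulate(model):
            with accelerator.autocast():          # bfloat16 forward pass on TPU
                output = model(batch)
                loss = output.pow(2).mean()       # placeholder loss
            accelerator.backward(loss)            # handles synchronization across cores
            optimizer.step()
            optimizer.zero_grad()
```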

### 3. TPU-Optimized Data Loading
- Data loading in the main process (`num_workers=0`) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing

### 4. Inference Support
- TPU-compatible inference methods added to the trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference

## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with a total batch size of 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization (host-side `top` only reflects CPU usage; use your platform's TPU monitoring tools)

### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps

### 3. Learning Rate Scaling
- Consider scaling the learning rate with the number of cores
- Linear scaling: `lr_new = lr_base × num_cores` (see the worked example below)
- May need warmup adjustment for large-batch training
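
As a quick worked example of both rules above (the accumulation value of 2 is purely illustrative; the config shown earlier uses 1):

```python
# Effective batch size and linearly scaled learning rate for a v3-8/v4-8 setup.
batch_size_per_core = 8
num_cores = 8
gradient_accumulation_steps = 2       # illustrative value

effective_batch = batch_size_per_core * num_cores * gradient_accumulation_steps
print(effective_batch)                # 8 * 8 * 2 = 128

lr_base = 0.005                       # single-core lr_max from the config
lr_scaled = lr_base * num_cores
print(lr_scaled)                      # linear scaling -> 0.04
```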

### 4. Memory Management
- TPU v3-8: 128GB HBM total (16GB per core)
- TPU v4-8: 128GB HBM total (32GB per chip)
- Monitor memory usage to avoid OOM errors (see the snippet below)
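
From inside a training process you can query the device's HBM statistics with a `torch_xla` utility; the exact dictionary keys vary between torch_xla versions, so treat this as a sketch:

```python
# Sketch: print HBM statistics for the current XLA device.
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(xm.get_memory_info(device))   # dict of memory statistics (e.g. free/total)
```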

## Monitoring and Debugging

### 1. TPU Utilization
```bash
# Report the number of TPU cores visible to the process (not a utilization meter)
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
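
For deeper debugging, torch_xla's built-in metrics report is usually more informative than polling core counts: it lists compile counts, execution times, and host-device transfer statistics for the current process. A minimal sketch:

```python
# Run a small op, then dump torch_xla's metrics report.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(64, 64, device=device)
(x @ x).sum()
xm.mark_step()                      # execute the pending XLA graph

print(met.metrics_report())         # includes CompileTime, ExecuteTime, transfer stats
```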

### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training

### 3. Common Issues and Solutions

**Issue**: "No TPU devices found"
- **Solution**: Verify that the TPU runtime is started and accessible

**Issue**: "DataLoader workers > 0 causes hangs"
- **Solution**: Set `dataloader_num_workers: 0` in the config

**Issue**: "Mixed precision errors"
- **Solution**: Ensure `use_amp: true` and that your PyTorch XLA version supports bfloat16

**Issue**: "Gradient synchronization timeouts"
- **Solution**: Check network connectivity between TPU hosts (mainly relevant for multi-host pod slices)

## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```

## Configuration Reference

### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```

### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8  # Per-core batch size
mixed_precision: bf16
```

This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.