# TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

## Prerequisites

### 1. Install PyTorch XLA for TPU Support

```bash
# Install PyTorch XLA (its version must match the installed PyTorch version)
pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html

# Or pin matching PyTorch and PyTorch XLA versions, e.g. 2.1:
pip install 'torch~=2.1.0' 'torch_xla[tpu]~=2.1.0' -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install Accelerate Library

```bash
pip install accelerate
```

### 3. Verify TPU Access

```bash
# Check if TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```

## Configuration Setup

### 1. Enable TPU in Configuration File

Update your `rnn_args.yaml` file with TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                        # Enable TPU training
num_tpu_cores: 8                     # Number of TPU cores (8 on a v3-8; a v4-8 exposes 4 devices)
gradient_accumulation_steps: 1       # Gradient accumulation for a larger effective batch size
dataloader_num_workers: 0            # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true                        # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                      # Per-core batch size (total = 8 cores × 8 = 64)
```

### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005                        # May need to scale with number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8                      # Per-core batch size
  days_per_batch: 4                  # Keep consistent across cores
```

## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check TPU environment only
python launch_tpu_training.py --check_only

# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```

### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```

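### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export PJRT_DEVICE=TPU   # Select the PJRT TPU runtime
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA. The older `python -m torch_xla.distributed.xla_dist`
# launcher was tied to the XRT runtime, which PyTorch XLA has deprecated in favor
# of PJRT. Under PJRT this assumes train_model.py spawns one process per core
# itself (e.g. via torch_xla.distributed.xla_multiprocessing.spawn); otherwise use Method 2.
python train_model.py --config_path rnn_args.yaml
```
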
## Key TPU Features Implemented

### 1. Distributed Training Support
- Automatic data-parallel training across 8 TPU cores (each core holds a full model replica and processes its own slice of each batch)
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training
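
Synchronized gradient updates in data-parallel training can be simulated in plain Python. The sketch below is illustrative only: the helper names are hypothetical, the model is a one-weight linear regression, and the averaging loop stands in for the all-reduce that PyTorch XLA and Accelerate perform across the TPU cores.

```python
def local_gradient(w: float, shard: list[tuple[float, float]]) -> float:
    """Mean-squared-error gradient for y = w * x over one core's data shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / len(shard)


def data_parallel_step(
    replicas: list[float], shards: list[list[tuple[float, float]]], lr: float
) -> list[float]:
    """One synchronized SGD step.

    Each simulated core computes a gradient on its own shard, the gradients
    are averaged (the role of the all-reduce), and every replica applies the
    same update, so replicas that start identical stay identical.
    """
    grads = [local_gradient(w, shard) for w, shard in zip(replicas, shards)]
    mean_grad = sum(grads) / len(grads)
    return [w - lr * mean_grad for w in replicas]


if __name__ == "__main__":
    num_cores = 8
    data = [(float(x), 3.0 * x) for x in range(1, 65)]   # true weight is 3.0
    shards = [data[rank::num_cores] for rank in range(num_cores)]
    replicas = [0.0] * num_cores                         # same init on every core
    for _ in range(50):
        replicas = data_parallel_step(replicas, shards, lr=1e-4)
    print(replicas[0])  # converges toward 3.0
```

Because every core applies the same averaged gradient, the replicas never drift apart, which is what allows a checkpoint to be saved from a single core.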

### 2. Mixed Precision Training
- Automatic bfloat16 mixed precision, the TPU's native reduced-precision format
- Faster training; because bfloat16 keeps float32's 8-bit exponent range, it generally needs no loss scaling
- Reduced memory usage
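
The stability of bfloat16 comes from its bit layout: it keeps float32's sign bit and 8 exponent bits and keeps only the top 7 mantissa bits. The pure-Python sketch below (a hypothetical helper, not part of the training code) rounds a value to bfloat16 so you can see that large magnitudes survive while small increments near 1.0 are lost.

```python
import struct


def to_bfloat16(value: float) -> float:
    """Round a value to the nearest bfloat16, returned as a Python float.

    The value is first converted to float32. The low 16 bits of the float32
    pattern are then rounded away (round-to-nearest-even), leaving the sign,
    the full 8-bit exponent and 7 mantissa bits. NaN/inf handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    lsb = (bits >> 16) & 1                      # lowest bit that survives
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


if __name__ == "__main__":
    print(to_bfloat16(1.0 + 2**-7))  # 1.0078125: representable with 7 mantissa bits
    print(to_bfloat16(1.0 + 2**-9))  # 1.0: the increment is below bfloat16 precision
    print(to_bfloat16(1e38))         # still finite and close to 1e38 (float16 overflows above 65504)
```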

### 3. TPU-Optimized Data Loading
- Single-threaded data loading (num_workers=0) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing
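
Distributing data across cores usually means each core reads a disjoint, equally sized slice of the dataset. The sketch below mirrors the non-shuffled behavior of `torch.utils.data.DistributedSampler` (with `drop_last=False`) in plain Python; Accelerate's own sharding of dataloaders may differ in detail, and `shard_indices` is a hypothetical helper.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Sample indices read by one core (rank), assuming a non-empty dataset.

    Indices are padded by wrapping around until they divide evenly across
    cores, then dealt out round-robin so every core gets the same count.
    """
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    for rank in range(4):
        print(rank, shard_indices(10, num_replicas=4, rank=rank))
```

With 10 samples on 4 cores, two indices are repeated so that every core processes 3 samples per epoch and no core waits on the others at a synchronization point.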

### 4. Inference Support
- TPU-compatible inference methods added to trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference

## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with total batch size = 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- `top` only shows host CPU load; to see what the TPU is doing, print the PyTorch XLA metrics report (`torch_xla.debug.metrics.metrics_report()`) from the training script or use a profiler

### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
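
The effective batch size formula, plus the inverse question of how many accumulation steps reach a target batch, as a small sketch (hypothetical helper names, not part of the training code):

```python
def effective_batch_size(per_core_batch: int, num_cores: int, grad_accum_steps: int = 1) -> int:
    """Samples contributing to each optimizer update."""
    return per_core_batch * num_cores * grad_accum_steps


def accumulation_steps_for(target_batch: int, per_core_batch: int, num_cores: int) -> int:
    """gradient_accumulation_steps needed to reach target_batch exactly."""
    per_step = per_core_batch * num_cores
    if target_batch % per_step != 0:
        raise ValueError(f"target batch {target_batch} is not a multiple of {per_step}")
    return target_batch // per_step


if __name__ == "__main__":
    print(effective_batch_size(8, 8))          # 64
    print(effective_batch_size(8, 8, 4))       # 256
    print(accumulation_steps_for(256, 8, 8))   # 4
```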

### 3. Learning Rate Scaling
- Consider scaling learning rate with number of cores
- Linear scaling: `lr_new = lr_base × num_cores`
- May need warmup adjustment for large batch training
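
A minimal sketch of the linear scaling rule combined with a linear warmup, assuming the base rate was tuned on a single device. The function names are hypothetical and this is not the project's actual scheduler; only the warmup phase is shown, and whatever decay the trainer applies afterwards is omitted.

```python
def scaled_lr(base_lr: float, num_cores: int) -> float:
    """Linear scaling rule: multiply a single-device learning rate by the core count."""
    return base_lr * num_cores


def lr_with_warmup(step: int, peak_lr: float, warmup_steps: int) -> float:
    """Ramp linearly from peak_lr / warmup_steps up to peak_lr, then hold."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr


if __name__ == "__main__":
    peak = scaled_lr(0.005, 8)   # 0.04, if 0.005 was tuned on a single device
    for step in (0, 499, 999, 5000):
        print(step, lr_with_warmup(step, peak, warmup_steps=1000))
```

The warmup keeps the first updates small while the optimizer statistics settle, which matters more as the scaled peak rate grows.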

### 4. Memory Management
- TPU v3-8: 128 GiB HBM total (16 GiB per core across 8 cores)
- TPU v4-8: 128 GiB HBM total (32 GiB per chip across 4 chips)
- Monitor memory usage to avoid OOM errors
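
A rough way to check whether a model fits is to compare the per-core training state with the per-core HBM. The sketch below assumes float32 weights, float32 gradients and an Adam-style optimizer with two float32 moment buffers (16 bytes per parameter), replicated on every core as in data-parallel training. Activations and XLA workspace are not counted, so the result is a lower bound; the 50M parameter count is illustrative, not the RNN's actual size.

```python
def hbm_per_core_gib(total_hbm_gib: float, num_devices: int) -> float:
    """HBM available to each device when memory is split evenly."""
    return total_hbm_gib / num_devices


def training_state_gib(num_params: int, bytes_per_param: int = 16) -> float:
    """Lower-bound per-core memory for weights, gradients and optimizer state.

    Default 16 bytes/param assumes fp32 weights (4) + fp32 gradients (4)
    + two fp32 Adam moment buffers (8). Activations are NOT included.
    """
    return num_params * bytes_per_param / 2**30


if __name__ == "__main__":
    per_core = hbm_per_core_gib(128, 8)          # v3-8: 16.0 GiB per core
    state = training_state_gib(50_000_000)       # hypothetical 50M-parameter model
    print(f"{state:.2f} GiB of {per_core:.0f} GiB per core before activations")
```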
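
## Monitoring and Debugging

### 1. TPU Utilization
```bash
# Check how many TPU devices PyTorch XLA can see (a device count, not a utilization metric).
# Run this before launching training: a second process generally cannot open the TPU while training holds it.
python -c "import torch_xla.runtime as xr; print(f'TPU devices: {xr.global_runtime_device_count()}')"
```

For compile and execution statistics during training, print `torch_xla.debug.metrics.metrics_report()` from inside the training script.
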
### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training

### 3. Common Issues and Solutions

**Issue**: "No TPU devices found"
- **Solution**: Verify TPU runtime is started and accessible

**Issue**: "DataLoader workers > 0 causes hangs"
- **Solution**: Set `dataloader_num_workers: 0` in config

**Issue**: "Mixed precision errors"
- **Solution**: Ensure `use_amp: true` and PyTorch XLA supports bfloat16

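**Issue**: "Gradient synchronization timeouts"
- **Solution**: Within a single TPU host the cores exchange gradients over the TPU interconnect, not the network, so a timeout usually means the per-core processes diverged (for example, one raised an exception or skipped a step the others reached). Check every core's log for errors; on multi-host TPU pods, also check network connectivity between hosts.
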
## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```

## Configuration Reference

### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```
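
A wrong value in the required settings typically surfaces only as a hang or an unexpectedly slow run, so it can help to check the config before launching. A minimal sketch, assuming the YAML has already been parsed into a dict (for example with PyYAML's `yaml.safe_load`); `check_tpu_config` is a hypothetical helper, not part of the training code, and `visible_devices` could be filled in from `torch_xla.runtime.global_runtime_device_count()`.

```python
from typing import Optional

# Values this guide lists as required for TPU training.
REQUIRED_TPU_SETTINGS = {
    "use_tpu": True,
    "dataloader_num_workers": 0,
    "use_amp": True,
}


def check_tpu_config(cfg: dict, visible_devices: Optional[int] = None) -> list[str]:
    """Return human-readable problems found in a parsed rnn_args-style config."""
    problems = []
    for key, expected in REQUIRED_TPU_SETTINGS.items():
        if cfg.get(key) != expected:
            problems.append(f"{key} should be {expected!r}, got {cfg.get(key)!r}")

    cores = cfg.get("num_tpu_cores")
    if not isinstance(cores, int) or cores < 1:
        problems.append(f"num_tpu_cores must be a positive integer, got {cores!r}")
    elif visible_devices is not None and cores != visible_devices:
        problems.append(f"num_tpu_cores is {cores} but {visible_devices} TPU devices are visible")

    batch = (cfg.get("dataset") or {}).get("batch_size")
    if not isinstance(batch, int) or batch < 1:
        problems.append(f"dataset.batch_size must be a positive integer, got {batch!r}")
    return problems


if __name__ == "__main__":
    cfg = {
        "use_tpu": True,
        "num_tpu_cores": 8,
        "dataloader_num_workers": 4,   # wrong on purpose
        "use_amp": True,
        "dataset": {"batch_size": 8},
    }
    for problem in check_tpu_config(cfg):
        print(problem)  # dataloader_num_workers should be 0, got 4
```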

### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8  # Per-core batch size
mixed_precision: bf16
```

This TPU implementation lets you use all 8 cores of a v3-8 TPU for both training and inference, with distributed training managed automatically by the Accelerate library.