#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model

This script provides an easy TPU training setup using the Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.

Usage:
    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Requirements:
    - PyTorch XLA installed
    - Accelerate library installed
    - TPU runtime available
"""
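
# Illustrative invocations (the flags mirror the argparse options defined in
# main() below):
#
#   python launch_tpu_training.py --check_only      # verify torch_xla / TPU runtime only
#   python launch_tpu_training.py --num_cores 1     # single-core debugging run
#   python launch_tpu_training.py --config rnn_args.yaml --num_cores 8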

import argparse
import yaml
import os
import sys
from pathlib import Path


def update_config_for_tpu(config_path, num_cores=8):
    """
    Update configuration file to enable TPU training
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Required for TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute batch size across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))

        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save updated config
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path
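
# For illustration: assuming the source config has dataset.batch_size: 64, the
# rnn_args_tpu.yaml written by update_config_for_tpu() above would contain
# values along these lines (all other keys are carried over unchanged):
#
#   use_tpu: true
#   num_tpu_cores: 8
#   dataloader_num_workers: 0
#   use_amp: true
#   gradient_accumulation_steps: 1
#   dataset:
#     batch_size: 8    # 64 // 8 cores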

def check_tpu_environment():
    """
    Check if TPU environment is properly set up
    """
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check if TPUs are available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
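        # Note: xm.xrt_world_size() is deprecated in newer torch_xla releases,
        # where torch_xla.runtime.world_size() is the replacement.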
        print(f"TPU world size: {xm.xrt_world_size()}")

        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False

def run_tpu_training(config_path, num_cores=8):
    """
    Launch TPU training using accelerate
    """
    # Check TPU environment
    if not check_tpu_environment():
        sys.exit(1)

    # Update config for TPU
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16

    # Launch training with accelerate
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"

    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Execute training
    os.system(cmd)
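
# With the defaults above, the executed command expands to roughly:
#
#   accelerate launch --tpu --num_processes 8 train_model.py --config_path rnn_args_tpu.yaml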

def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training
    run_tpu_training(args.config, args.num_cores)


if __name__ == "__main__":
    main()