b2txt25/model_training_nnn/launch_tpu_training.py
#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model
This script provides easy TPU training setup using Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.
Usage:
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""
import argparse
import os
import sys

import yaml


def update_config_for_tpu(config_path, num_cores=8):
    """
    Update the configuration file to enable TPU training.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Required for TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute the batch size across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save the updated config alongside the original
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path
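
# Example of the per-core arithmetic above (hypothetical numbers): with
# dataset.batch_size: 64 in rnn_args.yaml and num_cores=8, the generated
# rnn_args_tpu.yaml contains dataset.batch_size: 8, so the effective global
# batch size stays 64 across the 8 replicas.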


def check_tpu_environment():
    """
    Check that the TPU environment is properly set up.
    """
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check whether a TPU device is available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        print(f"TPU world size: {xm.xrt_world_size()}")
        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False
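
# Note: xm.get_ordinal() and xm.xrt_world_size() are deprecated on recent
# torch_xla releases. A sketch of the newer runtime API, assuming
# torch_xla >= 2.0 is installed:
#
#   import torch_xla.runtime as xr
#   print(f"TPU world size: {xr.world_size()}")
#   print(f"TPU ordinal: {xr.global_ordinal()}")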


def run_tpu_training(config_path, num_cores=8):
    """
    Launch TPU training using Accelerate.
    """
    # Check the TPU environment before doing anything else
    if not check_tpu_environment():
        sys.exit(1)

    # Update the config for TPU and write it next to the original
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16

    # Launch training with accelerate
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"
    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Execute training and fail loudly on a non-zero exit status
    exit_code = os.system(cmd)
    if exit_code != 0:
        sys.exit(1)
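
# Note: os.system runs the command through a shell. A subprocess-based
# variant (a sketch, not used above) avoids shell parsing of the generated
# command and raises on failure:
#
#   import subprocess
#   subprocess.run(
#       ["accelerate", "launch", "--tpu", "--num_processes", str(num_cores),
#        "train_model.py", "--config_path", tpu_config_path],
#       check=True,
#   )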


def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check the TPU environment, do not launch training')
    args = parser.parse_args()

    # Verify that the config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training
    run_tpu_training(args.config, args.num_cores)


if __name__ == "__main__":
    main()
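
# Example invocations (flags as defined in main() above):
#   python launch_tpu_training.py --check_only
#   python launch_tpu_training.py --config rnn_args.yaml --num_cores 8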