265 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			265 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | ||
|  | """
 | ||
|  | TensorFlow Training Script for Brain-to-Text RNN Model | ||
|  | Optimized for TPU v5e-8 | ||
|  | 
 | ||
|  | This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware. | ||
|  | It provides the same functionality as the PyTorch version but with TensorFlow | ||
|  | operations optimized for TPU performance. | ||
|  | 
 | ||
|  | Usage: | ||
|  |     python train_model_tf.py --config_path rnn_args.yaml | ||
|  | 
 | ||
|  | Requirements: | ||
|  |     - TensorFlow >= 2.15.0 | ||
|  |     - TPU v5e-8 environment | ||
|  |     - Access to brain-to-text HDF5 dataset | ||
|  | """
 | ||
|  | 
 | ||
|  | import argparse | ||
|  | import os | ||
|  | import sys | ||
|  | import logging | ||
|  | from omegaconf import OmegaConf | ||
|  | 
 | ||
|  | # Add the current directory to Python path for imports | ||
|  | sys.path.append(os.path.dirname(os.path.abspath(__file__))) | ||
|  | 
 | ||
|  | from trainer_tf import BrainToTextDecoderTrainerTF | ||
|  | 
 | ||
|  | 
 | ||
def setup_tpu_environment():
    """Populate TPU/XLA-related environment variables with default values.

    Each variable is written via ``os.environ.setdefault``, so any value that
    is already present in the process environment (exported by the user or by
    a launcher) is left untouched. Prints a one-line confirmation when done.

    NOTE(review): ``main()`` calls this after the module-level
    ``from trainer_tf import ...`` has already run. If that import loads
    TensorFlow, variables read at library load time (for example
    ``TF_CPP_MIN_LOG_LEVEL``) may have no effect — confirm.
    """
    env_defaults = (
        # Original intent: "Enable XLA optimizations".
        # NOTE(review): 'PyTorch/XLA' looks copied from a PyTorch/XLA setup,
        # but this is a TensorFlow script — confirm the intended value.
        ('TPU_ML_PLATFORM', 'PyTorch/XLA'),
        # Original intent: enable bfloat16 for memory efficiency.
        # NOTE(review): XLA_USE_BF16 is, as far as I know, a PyTorch/XLA
        # switch; confirm whether TensorFlow reads it at all.
        ('XLA_USE_BF16', '1'),
        # Enable XLA JIT compilation (auto-clustering level 2).
        ('TF_XLA_FLAGS', '--tf_xla_auto_jit=2'),
        # TPU memory/performance knobs (original intent: megacore mode for
        # larger models, all-gather CSE threshold for SPMD).
        ('TPU_MEGACORE', '1'),
        ('LIBTPU_INIT_ARGS', '--xla_tpu_spmd_threshold_for_allgather_cse=10000'),
        # Disable warnings for cleaner output.
        ('TF_CPP_MIN_LOG_LEVEL', '2'),
    )
    for var_name, default_value in env_defaults:
        os.environ.setdefault(var_name, default_value)

    print("TPU environment configured for v5e-8 optimizations")
|  | 
 | ||
|  | 
 | ||
def _require_config_field(config, field):
    """Return the value stored at dotted path *field* (e.g. ``'model.n_units'``).

    Raises:
        ValueError: if a path component is absent, or an intermediate value
            cannot be indexed by key (e.g. ``model: 5`` in the YAML).
    """
    value = config
    for key in field.split('.'):
        try:
            value = value[key]
        except (KeyError, TypeError) as err:
            raise ValueError(f"Missing required configuration field: {field}") from err
    return value


def validate_config(config):
    """Validate configuration for TensorFlow TPU training.

    Args:
        config: Mapping-like configuration (``main()`` passes the result of
            ``OmegaConf.load``; a plain nested ``dict`` also works).

    Raises:
        ValueError: if a required field is missing or not reachable. This now
            also covers ``dataset.dataset_dir`` and, when TPU training is
            enabled, ``dataset.batch_size``. Both were previously read
            unguarded and surfaced as a bare ``KeyError``.
        FileNotFoundError: if the dataset directory does not exist, or none of
            the listed sessions has a ``data_train.hdf5`` file.

    Emits ``logging`` warnings for settings that may under-use the TPU, and
    prints a confirmation line on success.
    """
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.sessions',
        'dataset.n_classes',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir',
    ]
    for field in required_fields:
        _require_config_field(config, field)

    # TPU-specific validations (TPU is assumed unless use_tpu is set false).
    if config.get('use_tpu', True):
        batch_size = _require_config_field(config, 'dataset.batch_size')
        # NOTE(review): the warning fires below 8 but recommends >= 32;
        # confirm which threshold is intended.
        if batch_size < 8:
            logging.warning("Small batch size may not utilize TPU efficiently. Consider batch_size >= 32")

        if not config.get('use_amp', True):
            logging.warning("Mixed precision disabled. Consider enabling for better TPU performance")

    # Dataset validation: the directory itself must exist...
    dataset_dir = _require_config_field(config, 'dataset.dataset_dir')
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # ...and at least one configured session must provide a training file.
    sessions = config['dataset']['sessions']
    has_training_file = any(
        os.path.isfile(os.path.join(dataset_dir, session, 'data_train.hdf5'))
        for session in sessions
    )
    if not has_training_file:
        raise FileNotFoundError("No valid session data files found in dataset directory")

    print("Configuration validation passed")
|  | 
 | ||
|  | 
 | ||
def _build_arg_parser():
    """Create the command-line parser for this training script."""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--config_path', default='rnn_args.yaml',
                        help='Path to configuration YAML file')
    parser.add_argument('--output_dir', default=None,
                        help='Override output directory from config')
    parser.add_argument('--checkpoint_dir', default=None,
                        help='Override checkpoint directory from config')
    parser.add_argument('--resume_from', default=None,
                        help='Path to checkpoint to resume training from')
    parser.add_argument('--num_batches', type=int, default=None,
                        help='Override number of training batches')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Override batch size')

    # The two precision flags contradict each other. Previously, passing both
    # silently let --disable_mixed_precision win; argparse now rejects it.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument('--mixed_precision', action='store_true', default=None,
                           help='Enable mixed precision training (bfloat16)')
    precision.add_argument('--disable_mixed_precision', action='store_true',
                           help='Disable mixed precision training')

    parser.add_argument('--validate_only', action='store_true',
                        help='Only run validation, do not train')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug logging')
    return parser


def _apply_overrides(config, args):
    """Copy command-line overrides from *args* onto *config* in place."""
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    # Compare with None so an explicit numeric value is never silently ignored
    # (main() rejects non-positive values before we get here).
    if args.num_batches is not None:
        config.num_training_batches = args.num_batches
    if args.batch_size is not None:
        config.dataset.batch_size = args.batch_size
    # The parser makes these two flags mutually exclusive.
    if args.mixed_precision:
        config.use_amp = True
    elif args.disable_mixed_precision:
        config.use_amp = False


def _run_validation_only(trainer):
    """Run one validation pass with *trainer* and print its metrics."""
    print("Running validation only...")
    from dataset_tf import create_input_fn

    val_dataset = create_input_fn(
        trainer.val_dataset_tf,
        trainer.args['dataset']['data_transforms'],
        training=False,
    )
    val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)
    val_metrics = trainer._validate(val_dist_dataset)

    print("Validation Results:")
    print(f"  Average Loss: {val_metrics['avg_loss']:.4f}")
    print(f"  Average PER: {val_metrics['avg_per']:.4f}")
    print(f"  Total Edit Distance: {val_metrics['total_edit_distance']}")
    print(f"  Total Sequence Length: {val_metrics['total_seq_length']}")


def _run_training(trainer, output_dir):
    """Train with *trainer*, pickle the returned statistics, and print a summary.

    The statistics are written to ``<output_dir>/training_stats.pkl`` BEFORE
    the summary is printed. Previously a summary failure (e.g. ``[-1]`` on an
    empty loss list) turned a finished run into "Training failed" and lost
    the statistics.
    """
    import pickle

    print("Starting training...")
    train_stats = trainer.train()
    print("\nTraining completed successfully!")

    os.makedirs(output_dir, exist_ok=True)
    stats_path = os.path.join(output_dir, 'training_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(train_stats, f)

    train_losses = train_stats['train_losses']
    val_losses = train_stats['val_losses']
    print(f"Best validation PER: {trainer.best_val_per:.5f}")
    # Either history may be empty (e.g. a run too short to reach a validation
    # step — TODO confirm against trainer_tf), so guard before indexing [-1].
    if len(train_losses) > 0:
        print(f"Final training loss: {train_losses[-1]:.4f}")
    else:
        print("Final training loss: n/a")
    if len(val_losses) > 0:
        print(f"Final validation loss: {val_losses[-1]:.4f}")
    else:
        print("Final validation loss: n/a")
    print(f"Total training batches: {len(train_losses)}")
    print(f"Training statistics saved to: {stats_path}")


def main():
    """Main training function.

    Parses CLI arguments, loads the YAML config with OmegaConf, applies the
    overrides, and validates the result. It then builds the trainer, resumes
    from a checkpoint if requested, and either runs validation only or trains.

    Raises:
        FileNotFoundError: if the configuration file does not exist (raised
            before any environment variables are touched).
        SystemExit: with code 2 for invalid command-line arguments, or code 1
            if trainer setup/training fails or is interrupted.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    for flag in ('num_batches', 'batch_size'):
        value = getattr(args, flag)
        if value is not None and value <= 0:
            parser.error(f"--{flag} must be a positive integer, got {value}")

    # Setup logging
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    # Fail fast on a bad config path before mutating the process environment.
    if not os.path.isfile(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    setup_tpu_environment()

    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    _apply_overrides(config, args)
    validate_config(config)

    try:
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        if args.resume_from:
            # The existence check looks for '<resume_from>.weights.h5', while
            # load_checkpoint receives the bare prefix.
            if not os.path.exists(args.resume_from + '.weights.h5'):
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")
            trainer.load_checkpoint(args.resume_from)
            print(f"Resumed training from checkpoint: {args.resume_from}")

        if args.validate_only:
            _run_validation_only(trainer)
        else:
            _run_training(trainer, config.output_dir)

    except KeyboardInterrupt:
        print("\nTraining interrupted by user", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"\nTraining failed with error: {e}", file=sys.stderr)
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()