#!/usr/bin/env python3
"""
TensorFlow Training Script for Brain-to-Text RNN Model
Optimized for TPU v5e-8

This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
It provides the same functionality as the PyTorch version but with TensorFlow
operations optimized for TPU performance.

Usage:
    python train_model_tf.py --config_path rnn_args.yaml

Requirements:
    - TensorFlow >= 2.15.0
    - TPU v5e-8 environment
    - Access to brain-to-text HDF5 dataset
"""

import argparse
import os
import sys
import logging

from omegaconf import OmegaConf

# Add the current directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from trainer_tf import BrainToTextDecoderTrainerTF


def setup_tpu_environment():
    """Set default TPU/XLA environment variables and print a confirmation.

    Every variable is written with ``os.environ.setdefault``, so a value that
    is already present in the environment is never overwritten.

    NOTE(review): this is called from ``main()``, i.e. after ``trainer_tf`` has
    already been imported at module load. If that import pulls in TensorFlow,
    variables that TensorFlow reads early (e.g. ``TF_CPP_MIN_LOG_LEVEL``) may
    be set too late to take effect -- confirm.
    """
    # NOTE(review): 'PyTorch/XLA' is an odd platform tag for a TensorFlow
    # script, and XLA_USE_BF16 looks like a PyTorch/XLA setting -- confirm
    # both have the intended effect under TensorFlow.
    os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA')
    os.environ.setdefault('XLA_USE_BF16', '1')
    os.environ.setdefault('TF_XLA_FLAGS', '--tf_xla_auto_jit=2')

    # TPU memory-related settings
    os.environ.setdefault('TPU_MEGACORE', '1')
    os.environ.setdefault(
        'LIBTPU_INIT_ARGS',
        '--xla_tpu_spmd_threshold_for_allgather_cse=10000',
    )

    # Reduce TensorFlow C++ log noise
    os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

    print("TPU environment configured for v5e-8 optimizations")


def _lookup_field(config, dotted_path):
    """Return the value at ``dotted_path`` (e.g. ``'model.n_units'``).

    Raises:
        ValueError: if any key along the path is absent, or an intermediate
            node cannot be subscripted (scalar or ``None``).
    """
    value = config
    try:
        for key in dotted_path.split('.'):
            value = value[key]
    except (KeyError, TypeError) as err:
        # TypeError covers an intermediate node that is a scalar or None.
        raise ValueError(
            f"Missing required configuration field: {dotted_path}"
        ) from err
    return value


def validate_config(config):
    """Validate configuration for TensorFlow TPU training.

    Checks that the required fields are present, emits TPU efficiency
    warnings, and verifies that the dataset directory exists and holds at
    least one ``<session>/data_train.hdf5`` file.

    Args:
        config: Mapping-like configuration supporting ``[]`` and ``.get``.

    Raises:
        ValueError: if a required configuration field is missing.
        FileNotFoundError: if the dataset directory does not exist or none of
            the configured sessions has a training file.
    """
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.sessions',
        'dataset.n_classes',
        # Read unconditionally below, so validate it with the other fields.
        'dataset.dataset_dir',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir',
    ]
    for field in required_fields:
        _lookup_field(config, field)

    # TPU-specific validations
    if config.get('use_tpu', True):
        batch_size = _lookup_field(config, 'dataset.batch_size')
        if batch_size is None:
            raise ValueError(
                "Missing required configuration field: dataset.batch_size"
            )
        if batch_size < 8:
            logging.warning(
                "Small batch size may not utilize TPU efficiently. "
                "Consider batch_size >= 32"
            )
        if not config.get('use_amp', True):
            logging.warning(
                "Mixed precision disabled. "
                "Consider enabling for better TPU performance"
            )

    # Dataset validation
    dataset_dir = config['dataset']['dataset_dir']
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # At least one session must provide a training file.
    session_found = any(
        os.path.exists(os.path.join(dataset_dir, session, 'data_train.hdf5'))
        for session in config['dataset']['sessions']
    )
    if not session_found:
        raise FileNotFoundError(
            "No valid session data files found in dataset directory"
        )

    print("Configuration validation passed")


def _build_arg_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration YAML file'
    )
    parser.add_argument(
        '--output_dir',
        default=None,
        help='Override output directory from config'
    )
    parser.add_argument(
        '--checkpoint_dir',
        default=None,
        help='Override checkpoint directory from config'
    )
    parser.add_argument(
        '--resume_from',
        default=None,
        help='Path to checkpoint to resume training from'
    )
    parser.add_argument(
        '--num_batches',
        type=int,
        default=None,
        help='Override number of training batches'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=None,
        help='Override batch size'
    )
    parser.add_argument(
        '--mixed_precision',
        action='store_true',
        default=None,
        help='Enable mixed precision training (bfloat16)'
    )
    parser.add_argument(
        '--disable_mixed_precision',
        action='store_true',
        help='Disable mixed precision training'
    )
    parser.add_argument(
        '--validate_only',
        action='store_true',
        help='Only run validation, do not train'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )
    return parser


def _apply_cli_overrides(config, args):
    """Copy explicitly supplied command-line overrides into ``config``."""
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    # Compare against None so an explicit 0 is not silently ignored.
    if args.num_batches is not None:
        config.num_training_batches = args.num_batches
    if args.batch_size is not None:
        config.dataset.batch_size = args.batch_size
    if args.mixed_precision:
        config.use_amp = True
    # Applied last: if both flags are passed, disabling takes precedence.
    if args.disable_mixed_precision:
        config.use_amp = False


def _format_last(values, precision):
    """Format the final element of ``values``, or ``'n/a'`` when empty."""
    if len(values) == 0:
        return "n/a"
    return f"{values[-1]:.{precision}f}"


def _run_validation(trainer):
    """Build the distributed validation dataset, evaluate and print metrics."""
    print("Running validation only...")

    # Create validation dataset
    from dataset_tf import create_input_fn
    val_dataset = create_input_fn(
        trainer.val_dataset_tf,
        trainer.args['dataset']['data_transforms'],
        training=False
    )
    val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)

    # Run validation
    val_metrics = trainer._validate(val_dist_dataset)

    print("Validation Results:")
    print(f"  Average Loss: {val_metrics['avg_loss']:.4f}")
    print(f"  Average PER: {val_metrics['avg_per']:.4f}")
    print(f"  Total Edit Distance: {val_metrics['total_edit_distance']}")
    print(f"  Total Sequence Length: {val_metrics['total_seq_length']}")


def _run_training(trainer, config):
    """Train, print a summary and pickle the statistics into ``output_dir``."""
    import pickle

    print("Starting training...")
    train_stats = trainer.train()

    print("\nTraining completed successfully!")
    print(f"Best validation PER: {trainer.best_val_per:.5f}")
    # Either list can be empty (e.g. no validation step ran); do not let the
    # summary raise IndexError before the statistics are written to disk.
    print(f"Final training loss: {_format_last(train_stats['train_losses'], 4)}")
    print(f"Final validation loss: {_format_last(train_stats['val_losses'], 4)}")
    print(f"Total training batches: {len(train_stats['train_losses'])}")

    # Save final training statistics
    os.makedirs(config.output_dir, exist_ok=True)
    stats_path = os.path.join(config.output_dir, 'training_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(train_stats, f)
    print(f"Training statistics saved to: {stats_path}")


def main():
    """Main training function.

    Parses the command line, configures logging and the TPU environment,
    loads and validates the configuration, then either runs validation only
    or trains the model. Exits with status 1 on interruption or failure
    inside the trainer.

    Raises:
        FileNotFoundError: if the configuration file does not exist.
    """
    args = _build_arg_parser().parse_args()

    # Setup logging
    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Setup TPU environment
    setup_tpu_environment()

    # Load configuration
    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")

    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    # Apply command line overrides, then validate the merged configuration
    _apply_cli_overrides(config, args)
    validate_config(config)

    try:
        # Initialize trainer
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        # Load checkpoint if specified
        if args.resume_from:
            # The path is probed with a '.weights.h5' suffix appended.
            if not os.path.exists(args.resume_from + '.weights.h5'):
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")
            trainer.load_checkpoint(args.resume_from)
            print(f"Resumed training from checkpoint: {args.resume_from}")

        if args.validate_only:
            _run_validation(trainer)
        else:
            _run_training(trainer, config)

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report, optionally dump the traceback, exit 1.
        print(f"\nTraining failed with error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()