TPU
This commit is contained in:
265
model_training_nnn_tpu/train_model_tf.py
Normal file
265
model_training_nnn_tpu/train_model_tf.py
Normal file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TensorFlow Training Script for Brain-to-Text RNN Model
|
||||
Optimized for TPU v5e-8
|
||||
|
||||
This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
|
||||
It provides the same functionality as the PyTorch version but with TensorFlow
|
||||
operations optimized for TPU performance.
|
||||
|
||||
Usage:
|
||||
python train_model_tf.py --config_path rnn_args.yaml
|
||||
|
||||
Requirements:
|
||||
- TensorFlow >= 2.15.0
|
||||
- TPU v5e-8 environment
|
||||
- Access to brain-to-text HDF5 dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Add the current directory to Python path for imports
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from trainer_tf import BrainToTextDecoderTrainerTF
|
||||
|
||||
|
||||
def setup_tpu_environment():
|
||||
"""Setup TPU environment variables for optimal performance"""
|
||||
# TPU v5e-8 optimizations
|
||||
os.environ.setdefault('TPU_ML_PLATFORM', 'PyTorch/XLA') # Enable XLA optimizations
|
||||
os.environ.setdefault('XLA_USE_BF16', '1') # Enable bfloat16 for memory efficiency
|
||||
os.environ.setdefault('TF_XLA_FLAGS', '--tf_xla_auto_jit=2') # Enable XLA JIT compilation
|
||||
|
||||
# TPU memory optimizations
|
||||
os.environ.setdefault('TPU_MEGACORE', '1') # Enable megacore mode for larger models
|
||||
os.environ.setdefault('LIBTPU_INIT_ARGS', '--xla_tpu_spmd_threshold_for_allgather_cse=10000')
|
||||
|
||||
# Disable warnings for cleaner output
|
||||
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
|
||||
|
||||
print("TPU environment configured for v5e-8 optimizations")
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
"""Validate configuration for TensorFlow TPU training"""
|
||||
required_fields = [
|
||||
'model.n_input_features',
|
||||
'model.n_units',
|
||||
'dataset.sessions',
|
||||
'dataset.n_classes',
|
||||
'num_training_batches',
|
||||
'output_dir',
|
||||
'checkpoint_dir'
|
||||
]
|
||||
|
||||
for field in required_fields:
|
||||
keys = field.split('.')
|
||||
value = config
|
||||
try:
|
||||
for key in keys:
|
||||
value = value[key]
|
||||
except KeyError:
|
||||
raise ValueError(f"Missing required configuration field: {field}")
|
||||
|
||||
# TPU-specific validations
|
||||
if config.get('use_tpu', True):
|
||||
if config['dataset']['batch_size'] < 8:
|
||||
logging.warning("Small batch size may not utilize TPU efficiently. Consider batch_size >= 32")
|
||||
|
||||
if not config.get('use_amp', True):
|
||||
logging.warning("Mixed precision disabled. Consider enabling for better TPU performance")
|
||||
|
||||
# Dataset validation
|
||||
dataset_dir = config['dataset']['dataset_dir']
|
||||
if not os.path.exists(dataset_dir):
|
||||
raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
|
||||
|
||||
# Check if at least one session file exists
|
||||
session_found = False
|
||||
for session in config['dataset']['sessions']:
|
||||
train_path = os.path.join(dataset_dir, session, 'data_train.hdf5')
|
||||
if os.path.exists(train_path):
|
||||
session_found = True
|
||||
break
|
||||
|
||||
if not session_found:
|
||||
raise FileNotFoundError("No valid session data files found in dataset directory")
|
||||
|
||||
print("Configuration validation passed")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training function"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
default='rnn_args.yaml',
|
||||
help='Path to configuration YAML file'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
default=None,
|
||||
help='Override output directory from config'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--checkpoint_dir',
|
||||
default=None,
|
||||
help='Override checkpoint directory from config'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--resume_from',
|
||||
default=None,
|
||||
help='Path to checkpoint to resume training from'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--num_batches',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Override number of training batches'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Override batch size'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--mixed_precision',
|
||||
action='store_true',
|
||||
default=None,
|
||||
help='Enable mixed precision training (bfloat16)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--disable_mixed_precision',
|
||||
action='store_true',
|
||||
help='Disable mixed precision training'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--validate_only',
|
||||
action='store_true',
|
||||
help='Only run validation, do not train'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--debug',
|
||||
action='store_true',
|
||||
help='Enable debug logging'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
log_level = logging.DEBUG if args.debug else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# Setup TPU environment
|
||||
setup_tpu_environment()
|
||||
|
||||
# Load configuration
|
||||
if not os.path.exists(args.config_path):
|
||||
raise FileNotFoundError(f"Configuration file not found: {args.config_path}")
|
||||
|
||||
config = OmegaConf.load(args.config_path)
|
||||
print(f"Loaded configuration from: {args.config_path}")
|
||||
|
||||
# Apply command line overrides
|
||||
if args.output_dir:
|
||||
config.output_dir = args.output_dir
|
||||
if args.checkpoint_dir:
|
||||
config.checkpoint_dir = args.checkpoint_dir
|
||||
if args.num_batches:
|
||||
config.num_training_batches = args.num_batches
|
||||
if args.batch_size:
|
||||
config.dataset.batch_size = args.batch_size
|
||||
if args.mixed_precision:
|
||||
config.use_amp = True
|
||||
if args.disable_mixed_precision:
|
||||
config.use_amp = False
|
||||
|
||||
# Validate configuration
|
||||
validate_config(config)
|
||||
|
||||
try:
|
||||
# Initialize trainer
|
||||
print("Initializing TensorFlow Brain-to-Text trainer...")
|
||||
trainer = BrainToTextDecoderTrainerTF(config)
|
||||
|
||||
# Load checkpoint if specified
|
||||
if args.resume_from:
|
||||
if os.path.exists(args.resume_from + '.weights.h5'):
|
||||
trainer.load_checkpoint(args.resume_from)
|
||||
print(f"Resumed training from checkpoint: {args.resume_from}")
|
||||
else:
|
||||
raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")
|
||||
|
||||
if args.validate_only:
|
||||
print("Running validation only...")
|
||||
# Create validation dataset
|
||||
from dataset_tf import create_input_fn
|
||||
val_dataset = create_input_fn(
|
||||
trainer.val_dataset_tf,
|
||||
trainer.args['dataset']['data_transforms'],
|
||||
training=False
|
||||
)
|
||||
val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)
|
||||
|
||||
# Run validation
|
||||
val_metrics = trainer._validate(val_dist_dataset)
|
||||
|
||||
print(f"Validation Results:")
|
||||
print(f" Average Loss: {val_metrics['avg_loss']:.4f}")
|
||||
print(f" Average PER: {val_metrics['avg_per']:.4f}")
|
||||
print(f" Total Edit Distance: {val_metrics['total_edit_distance']}")
|
||||
print(f" Total Sequence Length: {val_metrics['total_seq_length']}")
|
||||
|
||||
else:
|
||||
# Start training
|
||||
print("Starting training...")
|
||||
train_stats = trainer.train()
|
||||
|
||||
print("\nTraining completed successfully!")
|
||||
print(f"Best validation PER: {trainer.best_val_per:.5f}")
|
||||
print(f"Final training loss: {train_stats['train_losses'][-1]:.4f}")
|
||||
print(f"Final validation loss: {train_stats['val_losses'][-1]:.4f}")
|
||||
print(f"Total training batches: {len(train_stats['train_losses'])}")
|
||||
|
||||
# Save final training statistics
|
||||
import pickle
|
||||
stats_path = os.path.join(config.output_dir, 'training_stats.pkl')
|
||||
with open(stats_path, 'wb') as f:
|
||||
pickle.dump(train_stats, f)
|
||||
print(f"Training statistics saved to: {stats_path}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nTraining interrupted by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nTraining failed with error: {e}")
|
||||
if args.debug:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user