25 lines
		
	
	
		
			711 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			25 lines
		
	
	
		
			711 B
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import argparse | ||
|  | from omegaconf import OmegaConf | ||
|  | from rnn_trainer import BrainToTextDecoder_Trainer | ||
|  | 
 | ||
def main(argv=None):
    """Parse command-line arguments, load the configuration, and run training.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            so the command-line behaviour is unchanged. Passing an explicit
            list lets other code or tests drive this entry point.

    On success, prints a completion message and the trainer's
    ``best_val_PER`` value formatted to five decimal places.
    """
    parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
    parser.add_argument('--config_path', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')

    args = parser.parse_args(argv)

    # Load configuration. A missing file is a user error, so report it through
    # argparse (usage message + exit status 2) rather than a raw traceback.
    try:
        config = OmegaConf.load(args.config_path)
    except FileNotFoundError:
        parser.error(f"config file not found: {args.config_path}")

    # Initialize trainer
    trainer = BrainToTextDecoder_Trainer(config)

    # Start training
    trainer.train()

    print("Training completed successfully!")
    print(f"Best validation PER: {trainer.best_val_PER:.5f}")
|  | 
 | ||
if __name__ == '__main__':
    # Entry point: start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()