25 lines
711 B
Python
25 lines
711 B
Python
![]() |
import argparse
|
||
|
from omegaconf import OmegaConf
|
||
|
from rnn_trainer import BrainToTextDecoder_Trainer
|
||
|
|
||
|
def main(argv=None):
    """Train the Brain-to-Text RNN model from a YAML configuration file.

    Parses ``--config_path`` from the command line (or from *argv* when it is
    given), loads that file with ``OmegaConf.load``, builds a
    ``BrainToTextDecoder_Trainer`` from the resulting config, runs
    ``trainer.train()``, and finally prints the trainer's ``best_val_PER``
    attribute to five decimal places.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly like the original CLI;
            passing an explicit list allows programmatic use and testing.

    Raises:
        SystemExit: Raised by argparse (status 2) when the arguments are
            invalid or when ``--config_path`` does not exist.
    """
    from pathlib import Path  # local import keeps this change self-contained

    parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
    parser.add_argument('--config_path', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    args = parser.parse_args(argv)

    # Fail fast with a clear CLI error instead of an uncaught exception from
    # OmegaConf.load. exists() (not is_file()) so that special paths such as
    # /dev/stdin or process-substitution fds keep working.
    if not Path(args.config_path).exists():
        parser.error(f"config file not found: {args.config_path}")

    # Load configuration
    config = OmegaConf.load(args.config_path)

    # Initialize trainer
    trainer = BrainToTextDecoder_Trainer(config)

    # Start training
    trainer.train()

    print("Training completed successfully!")
    print(f"Best validation PER: {trainer.best_val_PER:.5f}")


if __name__ == "__main__":
    main()
|