Files
b2txt25/model_training_nnn_tpu/train_model.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

25 lines
711 B
Python

import argparse
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
def main():
    """Run one training session driven by a YAML configuration file.

    Reads the ``--config_path`` command-line option (falling back to
    ``rnn_args.yaml``), loads that file with ``OmegaConf.load``, hands the
    resulting config to ``BrainToTextDecoder_Trainer``, and calls its
    ``train()`` method. When training returns, a completion message and the
    trainer's ``best_val_PER`` attribute (formatted to five decimals) are
    printed to stdout.
    """
    cli = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
    cli.add_argument(
        '--config_path',
        default='rnn_args.yaml',
        help='Path to configuration file (default: rnn_args.yaml)',
    )
    config_path = cli.parse_args().config_path

    # The trainer is constructed directly from the loaded configuration.
    decoder_trainer = BrainToTextDecoder_Trainer(OmegaConf.load(config_path))

    # Blocks until the trainer's training loop returns.
    decoder_trainer.train()

    print("Training completed successfully!")
    print(f"Best validation PER: {decoder_trainer.best_val_PER:.5f}")
# Only start training when this file is executed as a script; importing it
# as a module must not trigger a run.
if __name__ == "__main__":
    main()