2025-10-12 15:31:45 +08:00
|
|
|
import argparse
import os
|
2025-10-12 09:11:32 +08:00
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from rnn_trainer import BrainToTextDecoder_Trainer
|
|
|
|
|
2025-10-12 15:31:45 +08:00
|
|
|
def main(argv=None):
    """Parse CLI arguments, load the YAML config, and run training.

    Args:
        argv: Optional list of command-line arguments (excluding the program
            name). When ``None`` (the default), argparse falls back to
            ``sys.argv[1:]``, so running the script from a shell behaves
            exactly as before; passing a list allows programmatic use.

    Raises:
        SystemExit: If argument parsing fails or ``--config_path`` does not
            name an existing file (argparse exits with status 2).
    """
    parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
    parser.add_argument('--config_path', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    args = parser.parse_args(argv)

    # Fail fast with a clear CLI error instead of a traceback from the
    # config loader when the path is missing or mistyped.
    if not os.path.isfile(args.config_path):
        parser.error(f"config file not found: {args.config_path}")

    # Load configuration
    config = OmegaConf.load(args.config_path)

    # Initialize trainer
    trainer = BrainToTextDecoder_Trainer(config)

    # Start training
    trainer.train()

    print("Training completed successfully!")
    # NOTE(review): assumes the trainer populates `best_val_PER` as a float
    # during train() — confirm against rnn_trainer.
    print(f"Best validation PER: {trainer.best_val_PER:.5f}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, never on import.
    main()
|