tpu
This commit is contained in:
25
model_training_nnn_tpu/train_model.py
Normal file
25
model_training_nnn_tpu/train_model.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import argparse
|
||||
from omegaconf import OmegaConf
|
||||
from rnn_trainer import BrainToTextDecoder_Trainer
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model')
|
||||
parser.add_argument('--config_path', default='rnn_args.yaml',
|
||||
help='Path to configuration file (default: rnn_args.yaml)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load configuration
|
||||
config = OmegaConf.load(args.config_path)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = BrainToTextDecoder_Trainer(config)
|
||||
|
||||
# Start training
|
||||
trainer.train()
|
||||
|
||||
print("Training completed successfully!")
|
||||
print(f"Best validation PER: {trainer.best_val_PER:.5f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user