# argument parser for command line arguments parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.') parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline', help='Path to the pretrained GRU model directory.') parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn', help='Path to the pretrained LSTM model directory.') parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final', help='Path to the dataset directory (relative to the current working directory).') parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'], help='Evaluation type: "val" for validation set, "test" for test set.') parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv', help='Path to the CSV file with metadata about the dataset.') parser.add_argument('--gpu_number', type=int, default=0, help='GPU number to use for model inference. Set to -1 to use CPU.') parser.add_argument('--gru_weight', type=float, default=0.6, help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.') # TTA parameters parser.add_argument('--tta_samples', type=int, default=8, help='Number of TTA augmentation samples per trial. 
Increased from 5 for better coverage.') parser.add_argument('--tta_noise_std', type=float, default=0.01, help='Standard deviation for TTA noise augmentation.') parser.add_argument('--tta_smooth_range', type=float, default=0.5, help='Range for TTA smoothing kernel variation (±range from default).') parser.add_argument('--tta_scale_range', type=float, default=0.05, help='Range for TTA amplitude scaling (±range from 1.0).') parser.add_argument('--tta_cut_max', type=int, default=3, help='Maximum number of timesteps to cut from beginning in TTA.') args = parser.parse_args() Total true phoneme length: 41392 Total edit distance: 6473 Aggregate Phoneme Error Rate (PER): 15.64% Results saved to: /root/autodl-tmp/nejm-brain-to-text/TTA-E/TTA-E_gru0.6_lstm0.4_samples8_val_20250917_210946.csv TTA-E configuration: GRU weight = 0.60, LSTM weight = 0.40, TTA samples = 8 ============================================================ parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.') parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline', help='Path to the pretrained GRU model directory.') parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn', help='Path to the pretrained LSTM model directory.') parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final', help='Path to the dataset directory (relative to the current working directory).') parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'], help='Evaluation type: "val" for validation set, "test" for test set.') parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv', help='Path to the CSV file with metadata about the dataset.') parser.add_argument('--gpu_number', type=int, default=0, help='GPU number to use for model 
inference. Set to -1 to use CPU.') parser.add_argument('--gru_weight', type=float, default=0.8, help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.') # TTA parameters parser.add_argument('--tta_samples', type=int, default=8, help='Number of TTA augmentation samples per trial. Increased from 5 for better coverage.') parser.add_argument('--tta_noise_std', type=float, default=0.01, help='Standard deviation for TTA noise augmentation.') parser.add_argument('--tta_smooth_range', type=float, default=0.5, help='Range for TTA smoothing kernel variation (±range from default).') parser.add_argument('--tta_scale_range', type=float, default=0.05, help='Range for TTA amplitude scaling (±range from 1.0).') parser.add_argument('--tta_cut_max', type=int, default=3, help='Maximum number of timesteps to cut from beginning in TTA.') 提高GRU权重 (increased GRU weight) Total true phoneme length: 41392 Total edit distance: 4705 Aggregate Phoneme Error Rate (PER): 11.37% ============================================================ 去TTA (TTA removed) Total true phoneme length: 41392 Total edit distance: 4326 Aggregate Phoneme Error Rate (PER): 10.45% ============================================================ 纯GRU (GRU only) Total true phoneme length: 41392 Total edit distance: 4215 Aggregate Phoneme Error Rate (PER): 10.18% ============================================================ 纯LSTM (LSTM only) Total true phoneme length: 41392 Total edit distance: 4498 Aggregate Phoneme Error Rate (PER): 10.87% ============================================================ 纯GRU + 3TTA (GRU only + 3 TTA) Total true phoneme length: 41392 Total edit distance: 4213 Aggregate Phoneme Error Rate (PER): 10.18% ============================================================ 纯GRU + 5TTA (GRU only + 5 TTA) Total true phoneme length: 41392 Total edit distance: 4218 Aggregate Phoneme Error Rate (PER): 10.19% ============================================================ 纯GRU + 4TTA (GRU only + 4 TTA) Aggregate Phoneme Error Rate (PER): 10.13% 
============================================================ 纯GRU + 4TTA 去高斯噪声 (GRU only + 4 TTA, Gaussian noise removed) Aggregate Phoneme Error Rate (PER): 10.14% ============================================================ parser.add_argument('--tta_weights', type=str, default='0.6,0.6,0.6,1.0,0.0', Aggregate Phoneme Error Rate (PER): 10.16% ============================================================ parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.6,1.0,0.0', Aggregate Phoneme Error Rate (PER): 10.13% ============================================================ parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.0,1.0,0.0', Aggregate Phoneme Error Rate (PER): 10.23% ============================================================ parser.add_argument('--tta_weights', type=str, default='1.0,1.0,0.0,1.0,0.0', Aggregate Phoneme Error Rate (PER): 10.14% ============================================================ ============================================================ ============================================================ 纯GRU + 5TTA 去高斯噪声 (GRU only + 5 TTA, Gaussian noise removed) Total edit distance: 4308 Aggregate Phoneme Error Rate (PER): 10.41%