b2txt25/TTA-E/log.txt
2025-10-06 15:17:44 +08:00

# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.6,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=8,
                    help='Number of TTA augmentation samples per trial. Increased from 5 for better coverage.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to cut from beginning in TTA.')
args = parser.parse_args()
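The TTA flags above describe noise, amplitude scaling, and cutting timesteps from the start of a trial. The sketch below shows one way these three could be combined into a batch of augmented copies. It is an illustration only: the function names, the order of operations, and the list-of-lists trial layout are assumptions, and the smoothing-kernel variation (--tta_smooth_range) is left out because the log does not describe it.

```python
import random


def augment_trial(trial, rng, noise_std=0.01, scale_range=0.05, cut_max=3):
    """Return one augmented copy of `trial`.

    `trial` is assumed to be a list of timesteps, each a list of floats.
    The cut -> scale -> noise order is an assumption, not taken from the log.
    """
    cut = rng.randint(0, cut_max)  # drop 0..cut_max leading timesteps (inclusive)
    scale = rng.uniform(1.0 - scale_range, 1.0 + scale_range)  # amplitude factor
    return [
        [x * scale + rng.gauss(0.0, noise_std) for x in frame]
        for frame in trial[cut:]
    ]


def make_tta_batch(trial, n_samples=8, seed=0, **kwargs):
    """Build `n_samples` augmented copies with a reproducible seed."""
    rng = random.Random(seed)
    return [augment_trial(trial, rng, **kwargs) for _ in range(n_samples)]
```

Setting noise_std=0.0 in such a sketch corresponds to a run without Gaussian noise; the other augmentations stay active.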
Total true phoneme length: 41392
Total edit distance: 6473
Aggregate Phoneme Error Rate (PER): 15.64%
Results saved to: /root/autodl-tmp/nejm-brain-to-text/TTA-E/TTA-E_gru0.6_lstm0.4_samples8_val_20250917_210946.csv
TTA-E configuration: GRU weight = 0.60, LSTM weight = 0.40, TTA samples = 8
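The aggregate PER reported above is the total edit distance divided by the total true phoneme length. A minimal sketch of that computation follows, assuming a plain Levenshtein distance over phoneme sequences; the function names are hypothetical and the evaluation script may compute the distance differently.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (insert/delete/substitute cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1]


def aggregate_per(pairs):
    """Return (total edit distance, total reference length, PER in percent).

    `pairs` is an iterable of (true_phonemes, predicted_phonemes) sequences.
    """
    pairs = list(pairs)
    total_dist = sum(edit_distance(ref, hyp) for ref, hyp in pairs)
    total_len = sum(len(ref) for ref, _ in pairs)
    return total_dist, total_len, 100.0 * total_dist / total_len
```

With the logged totals, 100 * 6473 / 41392 rounds to 15.64, which matches the PER line above.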
============================================================
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.8,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=8,
                    help='Number of TTA augmentation samples per trial. Increased from 5 for better coverage.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to cut from beginning in TTA.')
Increased GRU weight
Total true phoneme length: 41392
Total edit distance: 4705
Aggregate Phoneme Error Rate (PER): 11.37%
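The --gru_weight help text says the LSTM weight is 1 - gru_weight, so the ensemble is a convex combination of the two models' outputs. The sketch below shows that combination for a single frame. Mixing per-class probabilities is an assumption: the log does not say whether the script combines probabilities, log-probabilities, or logits, and the function name is hypothetical.

```python
def ensemble_frame(gru_probs, lstm_probs, gru_weight=0.8):
    """Convex combination of two per-class probability vectors for one frame.

    Assumes both inputs are probability vectors over the same classes.
    """
    if not 0.0 <= gru_weight <= 1.0:
        raise ValueError("gru_weight must lie in [0, 1]")
    if len(gru_probs) != len(lstm_probs):
        raise ValueError("class counts differ")
    lstm_weight = 1.0 - gru_weight
    return [gru_weight * g + lstm_weight * l
            for g, l in zip(gru_probs, lstm_probs)]
```

Under this sketch, gru_weight=1.0 reduces to the GRU alone and gru_weight=0.0 to the LSTM alone.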
============================================================
TTA removed
Total true phoneme length: 41392
Total edit distance: 4326
Aggregate Phoneme Error Rate (PER): 10.45%
============================================================
GRU only
Total true phoneme length: 41392
Total edit distance: 4215
Aggregate Phoneme Error Rate (PER): 10.18%
============================================================
LSTM only
Total true phoneme length: 41392
Total edit distance: 4498
Aggregate Phoneme Error Rate (PER): 10.87%
============================================================
GRU only + 3 TTA
Total true phoneme length: 41392
Total edit distance: 4213
Aggregate Phoneme Error Rate (PER): 10.18%
============================================================
GRU only + 5 TTA
Total true phoneme length: 41392
Total edit distance: 4218
Aggregate Phoneme Error Rate (PER): 10.19%
============================================================
GRU only + 4 TTA
Aggregate Phoneme Error Rate (PER): 10.13%
============================================================
GRU only + 4 TTA, Gaussian noise removed
Aggregate Phoneme Error Rate (PER): 10.14%
============================================================
parser.add_argument('--tta_weights', type=str, default='0.6,0.6,0.6,1.0,0.0',
Aggregate Phoneme Error Rate (PER): 10.16%
============================================================
parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.6,1.0,0.0',
Aggregate Phoneme Error Rate (PER): 10.13%
============================================================
parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.0,1.0,0.0',
Aggregate Phoneme Error Rate (PER): 10.23%
============================================================
parser.add_argument('--tta_weights', type=str, default='1.0,1.0,0.0,1.0,0.0',
Aggregate Phoneme Error Rate (PER): 10.14%
============================================================
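The --tta_weights runs above pass five comma-separated numbers as one string. A sketch of how such a string could be parsed and applied as a weighted average over per-augmentation outputs is shown below. What each of the five positions stands for is not stated in the log, and the normalization step and both function names are assumptions.

```python
def parse_tta_weights(spec):
    """Parse a string such as '1.0,0.6,0.6,1.0,0.0' into weights that sum to 1."""
    weights = [float(tok) for tok in spec.split(',')]
    if any(w < 0 for w in weights):
        raise ValueError("weights must be non-negative")
    total = sum(weights)
    if total == 0:
        raise ValueError("at least one weight must be positive")
    return [w / total for w in weights]


def weighted_average(rows, weights):
    """Average equal-length vectors (one per augmentation) using `weights`."""
    if len(rows) != len(weights):
        raise ValueError("need one weight per row")
    width = len(rows[0])
    return [sum(w * row[k] for w, row in zip(weights, rows))
            for k in range(width)]
```

A weight of 0.0 in such a scheme drops that augmentation from the average entirely.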
GRU only + 5 TTA, Gaussian noise removed
Total edit distance: 4308
Aggregate Phoneme Error Rate (PER): 10.41%