473 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			473 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | |||
|  | """
 | |||
|  | TTA-E参数搜索脚本 | |||
|  | 先运行所有基础配置获取预测结果,然后搜索最优的参数组合 | |||
|  | 避免重复模型推理,提高搜索效率 | |||
|  | """
 | |||
|  | 
 | |||
|  | import os | |||
|  | import sys | |||
|  | import torch | |||
|  | import numpy as np | |||
|  | import pandas as pd | |||
|  | from omegaconf import OmegaConf | |||
|  | import time | |||
|  | from tqdm import tqdm | |||
|  | import editdistance | |||
|  | import argparse | |||
|  | import itertools | |||
|  | import json | |||
|  | from concurrent.futures import ProcessPoolExecutor | |||
|  | import pickle | |||
|  | 
 | |||
|  | # Add parent directories to path to import models | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training')) | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm')) | |||
|  | 
 | |||
|  | from model_training.rnn_model import GRUDecoder | |||
|  | from model_training_lstm.rnn_model import LSTMDecoder | |||
|  | from model_training.evaluate_model_helpers import * | |||
|  | 
 | |||
def parse_arguments():
    """Declare the command-line options of the TTA-E search and parse them.

    The options fall into four groups: model/data locations, the
    comma-separated weight grids that are searched, the fixed augmentation
    strengths used while generating predictions, and output/cache control.

    Returns:
        argparse.Namespace: the parsed options (read from ``sys.argv``).
    """
    cli = argparse.ArgumentParser(description='TTA-E Parameter Search for optimal configuration.')

    option_table = [
        # Model, data and evaluation selection.
        ('--gru_model_path', dict(
            type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
            help='Path to the pretrained GRU model directory.')),
        ('--lstm_model_path', dict(
            type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
            help='Path to the pretrained LSTM model directory.')),
        ('--data_dir', dict(
            type=str, default='../data/hdf5_data_final',
            help='Path to the dataset directory.')),
        ('--csv_path', dict(
            type=str, default='../data/t15_copyTaskData_description.csv',
            help='Path to the CSV file with metadata.')),
        ('--gpu_number', dict(
            type=int, default=0,
            help='GPU number to use for model inference.')),
        ('--eval_type', dict(
            type=str, default='val', choices=['val', 'test'],
            help='Evaluation type.')),

        # Search space: every option is a comma-separated list of floats.
        ('--gru_weights', dict(
            type=str, default='0.4,0.5,0.6,0.7,0.8,1.0',
            help='Comma-separated GRU weights to search.')),
        ('--tta_noise_weights', dict(
            type=str, default='0.0,0.5,1.0',
            help='Comma-separated noise weights to search.')),
        ('--tta_scale_weights', dict(
            type=str, default='0.0,0.5,1.0',
            help='Comma-separated scale weights to search.')),
        ('--tta_shift_weights', dict(
            type=str, default='0.0,0.5,1.0',
            help='Comma-separated shift weights to search.')),
        ('--tta_smooth_weights', dict(
            type=str, default='0.0,0.5,1.0',
            help='Comma-separated smooth weights to search.')),

        # Fixed augmentation strengths used when generating predictions.
        ('--tta_noise_std', dict(
            type=float, default=0.01,
            help='Standard deviation for TTA noise augmentation.')),
        ('--tta_smooth_range', dict(
            type=float, default=0.5,
            help='Range for TTA smoothing kernel variation.')),
        ('--tta_scale_range', dict(
            type=float, default=0.05,
            help='Range for TTA amplitude scaling.')),
        ('--tta_cut_max', dict(
            type=int, default=3,
            help='Maximum timesteps for TTA shift.')),

        # Output and cache control.
        ('--cache_file', dict(
            type=str, default='tta_predictions_cache.pkl',
            help='File to cache model predictions.')),
        ('--results_file', dict(
            type=str, default='parameter_search_results.json',
            help='File to save search results.')),
        ('--force_recache', dict(
            action='store_true',
            help='Force re-computation of predictions cache.')),
    ]

    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|  | 
 | |||
def generate_all_base_configs():
    """Return the base configurations, one per augmentation run in isolation.

    Each entry is ``{'name': <aug>, 'tta_weights': {...}}`` where the weight
    of ``<aug>`` is 1.0 and the weight of every other augmentation is 0.0.
    """
    augmentations = ('original', 'noise', 'scale', 'shift', 'smooth')
    return [
        {
            'name': active,
            'tta_weights': {aug: (1.0 if aug == active else 0.0) for aug in augmentations},
        }
        for active in augmentations
    ]
|  | 
 | |||
def run_single_tta_prediction(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                             device, aug_type, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max):
    """Run the GRU and LSTM decoders on one trial under a single TTA augmentation.

    Args:
        x: Input tensor for one trial. The shift augmentation rolls it along
            dim 1, so dim 1 is treated as the time axis (presumably
            ``(1, time, features)`` -- confirm against the caller).
        input_layer: Integer passed to both models as ``day_idx``.
        gru_model: Callable invoked as
            ``model(x=..., day_idx=..., states=None, return_state=True)`` and
            expected to return ``(logits, state)``.
        lstm_model: Same calling convention as ``gru_model``.
        gru_model_args: Config providing
            ``['dataset']['data_transforms']['smooth_kernel_std']``,
            ``[...]['smooth_kernel_size']`` and ``['use_amp']``.
        lstm_model_args: Unused; kept so the signature stays stable.
        device: Device handed to ``gauss_smooth`` and used for ``day_idx``.
        aug_type: ``'original'``, ``'noise'``, ``'scale'``, ``'shift'`` or
            ``'smooth'``. Any other value leaves the input unchanged.
        tta_noise_std: Upper bound of the Gaussian noise std; the std actually
            used is drawn from ``[0.5, 1.0) * tta_noise_std``.
        tta_smooth_range: The smoothing-kernel std is perturbed uniformly
            within ``+/- tta_smooth_range`` and floored at 0.3.
        tta_scale_range: The amplitude gain is drawn uniformly from
            ``1 +/- tta_scale_range``.
        tta_cut_max: Maximum number of steps for the circular time shift;
            values ``<= 0`` disable the shift.

    Returns:
        tuple: ``(gru_probs, lstm_probs)`` -- softmax outputs of the two
        models as float32 NumPy arrays.
    """
    transforms = gru_model_args['dataset']['data_transforms']
    default_smooth_std = transforms['smooth_kernel_std']
    default_smooth_size = transforms['smooth_kernel_size']

    # Only the 'smooth' augmentation changes the kernel std.
    smooth_std = default_smooth_std
    x_augmented = x.clone()

    if aug_type == 'noise':
        noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
        x_augmented = x_augmented + torch.randn_like(x_augmented) * noise_scale
    elif aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
        x_augmented = x_augmented * scale_factor
    elif aug_type == 'shift' and tta_cut_max > 0:
        # The shift is capped at 1/8 of the sequence length. For very short
        # trials the exclusive upper bound drops to <= 1 and
        # np.random.randint(1, upper) would raise ValueError, so the trial is
        # left unshifted in that case.
        shift_upper = min(tta_cut_max + 1, x_augmented.shape[1] // 8)
        if shift_upper > 1:
            shift_amount = np.random.randint(1, shift_upper)
            # Circular roll: the first `shift_amount` steps wrap to the end.
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
        smooth_std = max(0.3, default_smooth_std + smooth_variation)

    # NOTE(review): autocast is pinned to the CUDA backend even though the
    # caller may select a CPU device -- confirm CPU runs behave as intended.
    with torch.autocast(device_type="cuda", enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
        # Gaussian smoothing is always applied; only its std may vary.
        x_smoothed = gauss_smooth(
            inputs=x_augmented,
            device=device,
            smooth_kernel_std=smooth_std,
            smooth_kernel_size=default_smooth_size,
            padding='valid',
        )

        with torch.no_grad():
            day_idx = torch.tensor([input_layer], device=device)

            gru_logits, _ = gru_model(
                x=x_smoothed,
                day_idx=day_idx,
                states=None,
                return_state=True,
            )
            gru_probs = torch.softmax(gru_logits, dim=-1)

            lstm_logits, _ = lstm_model(
                x=x_smoothed,
                day_idx=day_idx,
                states=None,
                return_state=True,
            )
            lstm_probs = torch.softmax(lstm_logits, dim=-1)

            # Cast to float32 first: NumPy has no bfloat16 dtype.
            return gru_probs.float().cpu().numpy(), lstm_probs.float().cpu().numpy()
|  | 
 | |||
def _strip_wrapper_prefixes(state_dict):
    """Remove ``module.`` and ``_orig_mod.`` from every key of ``state_dict`` in place."""
    # Both prefixes are stripped in a single rename. Popping the same key
    # twice (once per prefix) raises KeyError as soon as the first rename
    # actually changes the key.
    for key in list(state_dict.keys()):
        clean_key = key.replace("module.", "").replace("_orig_mod.", "")
        state_dict[clean_key] = state_dict.pop(key)


def _load_decoder(decoder_cls, model_dir, device):
    """Build ``decoder_cls`` from ``<model_dir>/checkpoint`` and return ``(model, model_args)`` in eval mode."""
    model_args = OmegaConf.load(os.path.join(model_dir, 'checkpoint/args.yaml'))

    model = decoder_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )

    # SECURITY: weights_only=False unpickles arbitrary objects; only point
    # this at checkpoints from a trusted source.
    checkpoint = torch.load(os.path.join(model_dir, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
    _strip_wrapper_prefixes(checkpoint['model_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()
    return model, model_args


def cache_model_predictions(args):
    """Run both decoders once per base augmentation and pickle the outputs.

    Stage one of the search: model inference is done here so the parameter
    search can work purely on cached probabilities.

    Args:
        args: Parsed options. Reads ``gpu_number``, ``gru_model_path``,
            ``lstm_model_path``, ``csv_path``, ``data_dir``, ``eval_type``,
            the ``tta_*`` augmentation strengths and ``cache_file``.

    Returns:
        dict: ``cache[config_name][session]`` is a list with one dict per
        trial holding ``'gru_probs'``, ``'lstm_probs'`` and ``'trial_info'``
        (session, block/trial numbers, and the labels when
        ``eval_type == 'val'``, otherwise ``None``). The same object is
        pickled to ``args.cache_file``.
    """
    print("=== 第一阶段:缓存模型预测结果 ===")

    # Select the device.
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')
    print(f'Using device: {device}')

    # Load both models.
    print("Loading models...")
    gru_model, gru_model_args = _load_decoder(GRUDecoder, args.gru_model_path, device)
    lstm_model, lstm_model_args = _load_decoder(LSTMDecoder, args.lstm_model_path, device)

    # Load the evaluation split of every session that provides one.
    print("Loading data...")
    b2txt_csv_df = pd.read_csv(args.csv_path)
    eval_filename = f'data_{args.eval_type}.hdf5'
    test_data = {}
    total_trials = 0

    for session in gru_model_args['dataset']['sessions']:
        session_dir = os.path.join(args.data_dir, session)
        files = [f for f in os.listdir(session_dir) if f.endswith('.hdf5')]
        if eval_filename in files:
            test_data[session] = load_h5py_file(os.path.join(session_dir, eval_filename), b2txt_csv_df)
            n_trials = len(test_data[session]["neural_features"])
            total_trials += n_trials
            print(f'Loaded {n_trials} {args.eval_type} trials for session {session}.')

    print(f'Total trials: {total_trials}')

    # Generate predictions for every base configuration.
    has_labels = args.eval_type == 'val'
    predictions_cache = {}

    for config in generate_all_base_configs():
        config_name = config['name']
        print(f"\nGenerating predictions for: {config_name}")

        predictions_cache[config_name] = {}

        with tqdm(total=total_trials, desc=f'Processing {config_name}', unit='trial') as pbar:
            for session, data in test_data.items():
                # NOTE(review): the day index comes from the GRU session list
                # and is also fed to the LSTM -- assumes both models were
                # trained with the same session order; confirm.
                input_layer = gru_model_args['dataset']['sessions'].index(session)
                session_predictions = []

                for trial, features in enumerate(data['neural_features']):
                    # Add a leading batch axis of size 1.
                    neural_input = np.expand_dims(features, axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    gru_probs, lstm_probs = run_single_tta_prediction(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device, config_name,
                        args.tta_noise_std, args.tta_smooth_range,
                        args.tta_scale_range, args.tta_cut_max
                    )

                    session_predictions.append({
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                            'seq_class_ids': data['seq_class_ids'][trial] if has_labels else None,
                            'seq_len': data['seq_len'][trial] if has_labels else None,
                            'sentence_label': data['sentence_label'][trial] if has_labels else None,
                        }
                    })

                    pbar.update(1)

                predictions_cache[config_name][session] = session_predictions

    # Persist the cache.
    print(f"\nSaving predictions cache to {args.cache_file}...")
    with open(args.cache_file, 'wb') as f:
        pickle.dump(predictions_cache, f)

    print("✓ Predictions cache saved successfully!")
    return predictions_cache
|  | 
 | |||
def ensemble_and_evaluate(predictions_cache, gru_weight, tta_weights, eval_type='val'):
    """Combine cached probabilities for one parameter setting and score it.

    For every trial, the per-augmentation probabilities are averaged with the
    normalised positive ``tta_weights`` (separately for GRU and LSTM). The two
    averages are then fused by a weighted geometric mean, using ``gru_weight``
    and ``1 - gru_weight``, and decoded greedily.

    Args:
        predictions_cache: ``cache[aug][session][trial]`` dicts holding
            ``'gru_probs'``, ``'lstm_probs'`` and ``'trial_info'``.
        gru_weight: Exponent of the GRU probabilities in the geometric mean.
        tta_weights: Mapping from augmentation name to weight; entries with a
            weight ``<= 0`` are ignored.
        eval_type: ``'val'`` computes the phoneme error rate; any other value
            only runs the decoding.

    Returns:
        float: PER in percent for ``'val'`` when at least one reference label
        exists; ``inf`` when no augmentation has a positive weight; ``0.0``
        otherwise.
    """
    # Keep only augmentations with a positive weight, in their given order.
    active = {name: w for name, w in tta_weights.items() if w > 0}
    if not active:
        return float('inf')  # invalid configuration

    weight_sum = sum(active.values())
    mix = {name: w / weight_sum for name, w in active.items()}

    lstm_weight = 1.0 - gru_weight
    eps = 1e-8  # keeps log() finite for zero probabilities
    scoring = eval_type == 'val'

    edit_sum = 0
    reference_len_sum = 0

    for session, reference_trials in predictions_cache['original'].items():
        for trial_idx, reference in enumerate(reference_trials):
            info = reference['trial_info']

            if scoring and info['seq_class_ids'] is None:
                continue

            # Weighted arithmetic mean over the enabled augmentations.
            gru_mix = None
            lstm_mix = None
            for name, w in mix.items():
                cached = predictions_cache[name][session][trial_idx]
                gru_p = torch.tensor(cached['gru_probs'])
                lstm_p = torch.tensor(cached['lstm_probs'])

                if gru_mix is None:
                    gru_mix = w * gru_p
                    lstm_mix = w * lstm_p
                    continue

                # Outputs may differ in length; truncate to the shorter one.
                steps = min(gru_mix.shape[1], gru_p.shape[1])
                gru_mix = gru_mix[:, :steps, :] + w * gru_p[:, :steps, :]
                lstm_mix = lstm_mix[:, :steps, :] + w * lstm_p[:, :steps, :]

            # Weighted geometric mean of GRU and LSTM, renormalised per step.
            log_fused = (gru_weight * torch.log(gru_mix + eps) +
                         lstm_weight * torch.log(lstm_mix + eps))
            fused = torch.exp(log_fused)
            fused = fused / fused.sum(dim=-1, keepdim=True)

            # Greedy decode: drop id 0 (presumably the CTC blank), then merge
            # adjacent repeats.
            # NOTE(review): removing id 0 before merging repeats collapses a
            # label that is repeated with only id-0 frames in between;
            # standard greedy CTC decoding merges first. Confirm this matches
            # the project's reference evaluation before changing it.
            frame_ids = torch.argmax(fused[0], dim=-1).numpy()
            kept = [int(t) for t in frame_ids if t != 0]
            merged = [t for i, t in enumerate(kept) if i == 0 or t != kept[i - 1]]
            pred_phonemes = [LOGIT_TO_PHONEME[t] for t in merged]

            if scoring:
                # Accumulate edit distance and reference length for the PER.
                true_seq = info['seq_class_ids'][0:info['seq_len']]
                true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

                edit_sum += editdistance.eval(true_phonemes, pred_phonemes)
                reference_len_sum += len(true_phonemes)

    if scoring and reference_len_sum > 0:
        return 100 * edit_sum / reference_len_sum
    return 0.0  # test mode (or no labelled trials)
|  | 
 | |||
def search_optimal_parameters(predictions_cache, args):
    """Grid-search the GRU weight and the TTA weights on cached predictions.

    Args:
        predictions_cache: Cached per-augmentation predictions, passed
            unchanged to ``ensemble_and_evaluate``.
        args: Parsed options. Reads the comma-separated grids
            ``gru_weights`` and ``tta_{noise,scale,shift,smooth}_weights``,
            plus ``eval_type``.

    Returns:
        tuple: ``(results, best_config)``. ``results`` lists one dict per
        combination (``gru_weight``, ``lstm_weight``, ``tta_weights``,
        ``per``) in evaluation order. ``best_config`` is the first entry with
        the lowest ``per``, or ``None`` if no entry has a finite ``per``.
    """
    print("\n=== 第二阶段:参数搜索 ===")

    def as_floats(csv_text):
        # "0.0,0.5" -> [0.0, 0.5]
        return [float(x) for x in csv_text.split(',')]

    # Parse the search space.
    gru_grid = as_floats(args.gru_weights)
    noise_grid = as_floats(args.tta_noise_weights)
    scale_grid = as_floats(args.tta_scale_weights)
    shift_grid = as_floats(args.tta_shift_weights)
    smooth_grid = as_floats(args.tta_smooth_weights)

    print("Search space:")
    print(f"  GRU weights: {gru_grid}")
    print(f"  Noise weights: {noise_grid}")
    print(f"  Scale weights: {scale_grid}")
    print(f"  Shift weights: {shift_grid}")
    print(f"  Smooth weights: {smooth_grid}")

    # Cartesian product of all grids.
    grid = list(itertools.product(gru_grid, noise_grid, scale_grid, shift_grid, smooth_grid))
    print(f"Total combinations to evaluate: {len(grid)}")

    results = []
    best_config = None
    best_per = float('inf')

    with tqdm(total=len(grid), desc='Parameter search', unit='config') as progress:
        for gru_w, noise_w, scale_w, shift_w, smooth_w in grid:
            tta_weights = {
                'original': 1.0,  # the un-augmented input is always included
                'noise': noise_w,
                'scale': scale_w,
                'shift': shift_w,
                'smooth': smooth_w
            }

            per = ensemble_and_evaluate(predictions_cache, gru_w, tta_weights, args.eval_type)

            entry = {
                'gru_weight': gru_w,
                'lstm_weight': 1.0 - gru_w,
                'tta_weights': tta_weights,
                'per': per
            }
            results.append(entry)

            # Strict '<' keeps the earliest configuration on ties.
            if per < best_per:
                best_per, best_config = per, entry
                print(f"\n🎯 New best PER: {per:.3f}%")
                print(f"   GRU weight: {gru_w:.1f}")
                print(f"   TTA weights: {tta_weights}")

            progress.update(1)

    return results, best_config
|  | 
 | |||
def main():
    """Script entry point: obtain the prediction cache, search, report, save.

    The cache is recomputed when ``--force_recache`` is given or the cache
    file does not exist; otherwise it is loaded from disk. The full search
    result is written as JSON to ``args.results_file`` and the ten best
    configurations are printed.
    """
    args = parse_arguments()

    print("TTA-E Parameter Search")
    print("=" * 50)

    # Stage 1: obtain the cached predictions.
    reuse_cache = not args.force_recache and os.path.exists(args.cache_file)
    if reuse_cache:
        print(f"Loading existing predictions cache from {args.cache_file}...")
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # cache files produced by this script.
        with open(args.cache_file, 'rb') as f:
            predictions_cache = pickle.load(f)
        print("✓ Cache loaded successfully!")
    else:
        predictions_cache = cache_model_predictions(args)

    # Stage 2: parameter search.
    results, best_config = search_optimal_parameters(predictions_cache, args)

    # Report the winner.
    print("\n=== 搜索完成 ===")
    print("Best configuration:")
    print(f"  PER: {best_config['per']:.3f}%")
    print(f"  GRU weight: {best_config['gru_weight']:.1f}")
    print(f"  LSTM weight: {best_config['lstm_weight']:.1f}")
    print(f"  TTA weights: {best_config['tta_weights']}")

    # Save everything, including the options used for this run.
    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(args.results_file, 'w') as f:
        json.dump(search_results, f, indent=2)

    print(f"\n✓ Results saved to {args.results_file}")

    # Show the ten best configurations (stable sort: ties keep search order).
    print("\nTop 10 configurations:")
    ranked = sorted(results, key=lambda entry: entry['per'])
    for rank, entry in enumerate(ranked[:10], start=1):
        print(f"{rank:2d}. PER={entry['per']:6.3f}% | GRU={entry['gru_weight']:.1f} | TTA={entry['tta_weights']}")


if __name__ == "__main__":
    main()