490 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			490 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | |||
"""
Efficient TTA-E parameter search.

Caches every base prediction first, then quickly searches the weight combinations.
"""
|  | 
 | |||
|  | import os | |||
|  | import sys | |||
|  | import torch | |||
|  | import numpy as np | |||
|  | import pandas as pd | |||
|  | from omegaconf import OmegaConf | |||
|  | import time | |||
|  | from tqdm import tqdm | |||
|  | import editdistance | |||
|  | import argparse | |||
|  | import itertools | |||
|  | import json | |||
|  | import pickle | |||
|  | from concurrent.futures import ThreadPoolExecutor, as_completed | |||
|  | 
 | |||
# GPU acceleration backend: use CuPy when it can be imported; otherwise alias
# NumPy as ``cp`` so the rest of the module can call ``cp.*`` unconditionally.
try:
    import cupy as cp
    GPU_AVAILABLE = True
    print("✅ CuPy available - GPU acceleration enabled")

    def to_cpu(arr):
        """Return ``arr`` as a host array when it is a CuPy array.

        Any other object is returned unchanged. The check is an explicit
        ``isinstance`` test rather than ``hasattr(arr, 'device')`` because
        other array types (NumPy >= 2.0 arrays, torch tensors) also expose a
        ``device`` attribute and must not be routed through ``cp.asnumpy``.
        """
        return cp.asnumpy(arr) if isinstance(arr, cp.ndarray) else arr
except ImportError:
    print("⚠️ CuPy not available - falling back to CPU numpy")
    import numpy as cp
    GPU_AVAILABLE = False

    def to_cpu(arr):
        """Return ``arr`` unchanged; without CuPy everything is already on the host."""
        return arr
|  | 
 | |||
|  | # Add parent directories to path to import models | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training')) | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm')) | |||
|  | 
 | |||
|  | from model_training.rnn_model import GRUDecoder | |||
|  | from model_training_lstm.rnn_model import LSTMDecoder | |||
|  | from model_training.evaluate_model_helpers import * | |||
|  | from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME | |||
|  | 
 | |||
def parse_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows the function to be called programmatically.

    Returns:
        The ``argparse.Namespace`` holding model/data paths, the
        comma-separated weight grids, TTA parameters, cache/output file names
        and the thread-pool settings.
    """
    parser = argparse.ArgumentParser(description='高效TTA-E参数搜索')

    # Model and data locations
    parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline')
    parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'])
    parser.add_argument('--gpu_number', type=int, default=0)

    # Search space: each option is a comma-separated list of candidate weights
    parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')

    # Augmentation strengths used when the prediction cache is generated
    parser.add_argument('--tta_noise_std', type=float, default=0.01)
    parser.add_argument('--tta_smooth_range', type=float, default=0.5)
    parser.add_argument('--tta_scale_range', type=float, default=0.05)
    parser.add_argument('--tta_cut_max', type=int, default=3)

    # Cache and output control
    parser.add_argument('--cache_file', type=str, default='tta_cache.pkl')
    parser.add_argument('--force_recache', action='store_true')
    parser.add_argument('--output_file', type=str, default='search_results.json')

    # Parallel processing
    parser.add_argument('--max_workers', type=int, default=25, help='最大并行工作线程数')
    parser.add_argument('--batch_size', type=int, default=100, help='每批处理的配置数量')

    return parser.parse_args(argv)
|  | 
 | |||
def _strip_wrapper_prefixes(state_dict):
    """Return a copy of ``state_dict`` with ``module.`` / ``_orig_mod.`` removed from every key.

    Presumably these substrings come from DataParallel/DDP and torch.compile
    wrappers (TODO confirm). Both are removed in a single pass; the previous
    in-place version popped the same key twice and raised ``KeyError`` for
    any key that contained ``module.``.
    """
    return {
        key.replace("module.", "").replace("_orig_mod.", ""): value
        for key, value in state_dict.items()
    }


def _load_decoder(decoder_cls, model_path, device):
    """Build ``decoder_cls`` from ``<model_path>/checkpoint`` and return ``(model, model_args)`` in eval mode on ``device``."""
    model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))

    model = decoder_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
    model.load_state_dict(_strip_wrapper_prefixes(checkpoint['model_state_dict']))

    model.to(device)
    model.eval()
    return model, model_args


def _apply_augmentation(x, aug_type, args, base_smooth_std):
    """Return ``(augmented_input, smooth_kernel_std)`` for one trial.

    ``x`` is indexed as a 3-D tensor. 'noise', 'scale' and 'shift' modify a
    clone of ``x``; 'smooth' leaves the input alone and only varies the
    smoothing std; any other ``aug_type`` (e.g. 'original') returns an
    unmodified clone together with ``base_smooth_std``.
    """
    x_aug = x.clone()
    smooth_std = base_smooth_std

    if aug_type == 'noise':
        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
        x_aug = x_aug + torch.randn_like(x_aug) * noise_scale
    elif aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
        x_aug = x_aug * scale_factor
    elif aug_type == 'shift' and args.tta_cut_max > 0:
        # randint(1, high) needs high > 1; for inputs whose dim 1 is shorter
        # than 16 the bound collapses to <= 1, so the input is left unshifted
        # instead of raising ValueError.
        max_shift = min(args.tta_cut_max + 1, x_aug.shape[1] // 8)
        if max_shift > 1:
            shift_amount = np.random.randint(1, max_shift)
            # Circular shift along dim 1: the first `shift_amount` steps wrap to the end.
            x_aug = torch.cat([x_aug[:, shift_amount:, :], x_aug[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
        smooth_std = max(0.3, base_smooth_std + smooth_variation)

    return x_aug, smooth_std


def generate_base_predictions(args):
    """Run both decoders on every trial under each augmentation and cache the outputs.

    Loads the GRU and LSTM decoders, reads ``data_<eval_type>.hdf5`` for every
    session listed in the GRU config that has such a file, and for each of the
    augmentation types ('original', 'noise', 'scale', 'shift', 'smooth')
    stores the softmax outputs of both models.

    Args:
        args: Parsed command-line namespace (model/data paths, ``eval_type``,
            ``gpu_number``, TTA parameters and ``cache_file``).

    Returns:
        dict: ``cache[aug_type][session]`` is a list with one entry per trial,
        each a dict with 'gru_probs', 'lstm_probs' (NumPy arrays) and
        'trial_info' (session, block/trial numbers and, for 'val', the label
        fields). The same object is also pickled to ``args.cache_file``.

    Note:
        The augmentations draw from the unseeded NumPy and torch RNGs, so two
        runs produce different caches.
    """
    print("🔄 生成基础预测缓存...")

    # Select the device
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')

    # Load both models; the GRU config drives sessions, smoothing and AMP below
    gru_model, gru_model_args = _load_decoder(GRUDecoder, args.gru_model_path, device)
    lstm_model, _ = _load_decoder(LSTMDecoder, args.lstm_model_path, device)

    sessions = gru_model_args['dataset']['sessions']
    transforms = gru_model_args['dataset']['data_transforms']
    base_smooth_std = transforms['smooth_kernel_std']
    smooth_kernel_size = transforms['smooth_kernel_size']

    # Load the evaluation data for every session that provides it
    b2txt_csv_df = pd.read_csv(args.csv_path)
    eval_filename = f'data_{args.eval_type}.hdf5'
    test_data = {}
    total_trials = 0

    for session in sessions:
        session_dir = os.path.join(args.data_dir, session)
        if eval_filename in os.listdir(session_dir):
            eval_file = os.path.join(session_dir, eval_filename)
            test_data[session] = load_h5py_file(eval_file, b2txt_csv_df)
            total_trials += len(test_data[session]["neural_features"])

    print(f"总试验数: {total_trials}")

    # Generate predictions for every base augmentation type
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
    cache = {}

    for aug_type in augmentation_types:
        print(f"处理增强类型: {aug_type}")
        cache[aug_type] = {}

        with tqdm(total=total_trials, desc=f'{aug_type}', unit='trial') as pbar:
            for session, data in test_data.items():
                # Both models are indexed with the session's position in the GRU config
                input_layer = sessions.index(session)
                day_idx = torch.tensor([input_layer], device=device)
                session_cache = cache[aug_type][session] = []

                for trial in range(len(data['neural_features'])):
                    # Add a leading axis of size 1 before feeding the model
                    neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    x_aug, smooth_std = _apply_augmentation(neural_input, aug_type, args, base_smooth_std)

                    with torch.autocast(device_type="cuda", enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
                        x_smoothed = gauss_smooth(
                            inputs=x_aug, device=device,
                            smooth_kernel_std=smooth_std,
                            smooth_kernel_size=smooth_kernel_size,
                            padding='valid',
                        )

                        with torch.no_grad():
                            gru_logits, _ = gru_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)
                            lstm_logits, _ = lstm_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)

                            gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
                            lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()

                    # Keep the predictions together with the trial identifiers
                    trial_data = {
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                        }
                    }

                    # Label fields are only stored for the validation split
                    if args.eval_type == 'val':
                        trial_data['trial_info'].update({
                            'seq_class_ids': data['seq_class_ids'][trial],
                            'seq_len': data['seq_len'][trial],
                            'sentence_label': data['sentence_label'][trial],
                        })

                    session_cache.append(trial_data)
                    pbar.update(1)

    # Persist the cache
    with open(args.cache_file, 'wb') as f:
        pickle.dump(cache, f)

    print(f"✅ 缓存已保存到 {args.cache_file}")
    return cache
|  | 
 | |||
def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
    """Score one (GRU weight, TTA weights) configuration against the cached predictions.

    For every labelled trial the cached probabilities of the enabled
    augmentations are averaged per model (weights normalised to sum to 1),
    the two models are combined as a weighted geometric mean
    (``gru_weight`` / ``1 - gru_weight`` in log space), the result is decoded
    greedily and compared with the reference phoneme sequence.

    Args:
        cache: ``cache[aug_type][session][trial_idx]`` dicts holding
            'gru_probs', 'lstm_probs' and 'trial_info'. The probability
            arrays are indexed as 3-D; only index 0 of the first axis is
            decoded and the last axis holds the class probabilities.
        gru_weight: Ensemble weight of the GRU model; the LSTM gets the rest.
        tta_weights: Mapping of augmentation name to weight; entries with a
            weight <= 0 are ignored.
        eval_type: 'val' computes a phoneme error rate; any other value has
            no labels to score against.

    Returns:
        float: PER in percent for 'val'. ``float('inf')`` when no augmentation
        is enabled or when no labelled trial was found (an empty or unlabelled
        cache must not look like a perfect 0% configuration). ``0.0`` for
        non-'val' eval types.
    """
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8

    enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
    if not enabled_augmentations:
        return float('inf')

    # Without labels nothing can be scored; skip the decoding work entirely.
    if eval_type != 'val':
        return 0.0

    # Normalise the TTA weights over the enabled augmentations
    total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
    norm_tta_weights = {k: tta_weights[k] / total_tta_weight for k in enabled_augmentations}

    total_edit_distance = 0
    total_true_length = 0

    for session, original_trials in cache['original'].items():
        for trial_idx, original_trial in enumerate(original_trials):
            trial_info = original_trial['trial_info']
            if 'seq_class_ids' not in trial_info:
                continue

            # Weighted average over the enabled augmentations, per model.
            # `cp` is CuPy when available and the NumPy alias otherwise, so
            # the same code path serves both backends.
            weighted_gru_probs = None
            weighted_lstm_probs = None
            for aug_type in enabled_augmentations:
                weight = norm_tta_weights[aug_type]
                trial_pred = cache[aug_type][session][trial_idx]
                gru_probs = cp.asarray(trial_pred['gru_probs'])
                lstm_probs = cp.asarray(trial_pred['lstm_probs'])

                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Augmentations may disagree on the length of axis 1;
                    # truncate both sides to the shorter one.
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
                    weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]

            # GRU + LSTM ensemble in log space; epsilon keeps log() finite
            log_ensemble_probs = (gru_weight * cp.log(weighted_gru_probs + epsilon) +
                                  lstm_weight * cp.log(weighted_lstm_probs + epsilon))
            ensemble_probs = cp.exp(log_ensemble_probs)
            ensemble_probs = ensemble_probs / ensemble_probs.sum(axis=-1, keepdims=True)

            # Greedy decode, then move to the host for the Python list work
            pred_ids = to_cpu(cp.argmax(ensemble_probs[0], axis=-1))

            # NOTE(review): class id 0 (presumably the CTC blank) is dropped
            # BEFORE repeats are collapsed, so "A, 0, A" decodes to a single A.
            # Kept as-is; confirm this matches the reference evaluation.
            pred_ids = [int(p) for p in pred_ids if p != 0]
            pred_ids = [p for i, p in enumerate(pred_ids) if i == 0 or p != pred_ids[i - 1]]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_ids]

            true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

            total_edit_distance += editdistance.eval(true_phonemes, pred_phonemes)
            total_true_length += len(true_phonemes)

    if total_true_length == 0:
        return float('inf')
    return 100 * total_edit_distance / total_true_length
|  | 
 | |||
def evaluate_config_wrapper(args_tuple):
    """Thread-pool adapter: score one task tuple and return its result record.

    Args:
        args_tuple: ``(cache, gru_weight, tta_weights, eval_type)``.

    Returns:
        dict with 'gru_weight', 'lstm_weight' (``1 - gru_weight``), a copy of
        'tta_weights' and the score under 'per'.
    """
    _, gru_weight, tta_weights, _ = args_tuple
    score = evaluate_config(*args_tuple)

    record = {'gru_weight': gru_weight}
    record['lstm_weight'] = 1.0 - gru_weight
    record['tta_weights'] = dict(tta_weights)
    record['per'] = score
    return record
|  | 
 | |||
def search_parameters(cache, args):
    """Grid-search the ensemble and TTA weights over the cached predictions.

    Every combination of the comma-separated weight lists in ``args`` is
    scored with ``evaluate_config_wrapper`` on a thread pool. The 'original'
    augmentation always takes part with weight 1.0.

    Args:
        cache: Prediction cache handed through to the evaluation function.
        args: Namespace providing the ``*_weights`` strings, ``eval_type``,
            ``max_workers`` and ``batch_size`` (progress-report interval).

    Returns:
        tuple: ``(all_results, best_config)``. ``all_results`` holds one
        record per configuration in completion order; failed configurations
        carry ``per=float('inf')`` and an 'error' message. ``best_config`` is
        the record with the lowest finite PER, or ``None`` if there is none.
    """
    print("🔍 开始参数搜索...")

    def _parse_weights(spec):
        # "0.1,0.2" -> [0.1, 0.2]
        return [float(x) for x in spec.split(',')]

    # Parse the search space
    gru_weights = _parse_weights(args.gru_weights)
    noise_weights = _parse_weights(args.tta_noise_weights)
    scale_weights = _parse_weights(args.tta_scale_weights)
    shift_weights = _parse_weights(args.tta_shift_weights)
    smooth_weights = _parse_weights(args.tta_smooth_weights)

    # Build every combination; product() varies the last list fastest, which
    # is the same order as five nested loops.
    configs = [
        (
            cache,
            gru_w,
            {
                'original': 1.0,  # the un-augmented prediction is always included
                'noise': noise_w,
                'scale': scale_w,
                'shift': shift_w,
                'smooth': smooth_w,
            },
            args.eval_type,
        )
        for gru_w, noise_w, scale_w, shift_w, smooth_w in itertools.product(
            gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights
        )
    ]

    total_configs = len(configs)
    print(f"搜索空间: {total_configs} 个配置")
    print(f"使用 {args.max_workers} 个线程并行处理...")

    best_per = float('inf')
    best_config = None
    all_results = []
    completed_count = 0

    # A batch_size of 0 disables the detailed progress line instead of
    # raising ZeroDivisionError in the modulo below.
    report_every = abs(args.batch_size)

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit everything up front
        future_to_config = {
            executor.submit(evaluate_config_wrapper, config): config
            for config in configs
        }

        for future in as_completed(future_to_config):
            completed_count += 1

            # Only the evaluation itself is guarded, so an error in the
            # bookkeeping below can never be recorded as a failed
            # configuration or counted twice.
            try:
                result = future.result()
            except Exception as e:
                config = future_to_config[future]
                print(f"❌ 配置失败: {config[1]:.1f}, 错误: {e}")
                all_results.append({
                    'gru_weight': config[1],
                    'lstm_weight': 1.0 - config[1],
                    'tta_weights': config[2],
                    'per': float('inf'),
                    'error': str(e)
                })
                continue

            all_results.append(result)

            # Track the running best for live feedback
            if result['per'] < best_per:
                best_per = result['per']
                best_config = result
                tw = result['tta_weights']
                print(f"🎯 新最佳[{completed_count}/{total_configs}]: PER={best_per:.3f}% | "
                      f"GRU={result['gru_weight']:.1f} | "
                      f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")

            # Periodic progress: a detailed line every `report_every`
            # completions, otherwise a short tick every 50
            if report_every and completed_count % report_every == 0:
                progress = 100 * completed_count / total_configs
                print(f"📊 进度: {completed_count}/{total_configs} ({progress:.1f}%) | 当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

    print(f"\n✅ 所有任务完成!")

    # Re-derive the best result from the finite, error-free records
    valid_results = [r for r in all_results if 'error' not in r and r['per'] != float('inf')]
    if valid_results:
        best_config = min(valid_results, key=lambda x: x['per'])

    return all_results, best_config
|  | 
 | |||
def main():
    """Script entry point: obtain the prediction cache, run the search, save and print the results."""
    args = parse_arguments()
    separator = "=" * 50

    print("🚀 高效TTA-E参数搜索")
    print(separator)

    # Stage 1: reuse the cache file unless it is missing or a rebuild is forced
    reuse_cache = os.path.exists(args.cache_file) and not args.force_recache
    if reuse_cache:
        print(f"📁 加载现有缓存: {args.cache_file}")
        with open(args.cache_file, 'rb') as cache_fh:
            # NOTE(review): pickle.load can execute code embedded in the file;
            # only point --cache_file at caches you generated yourself.
            cache = pickle.load(cache_fh)
        print("✅ 缓存加载完成")
    else:
        cache = generate_base_predictions(args)

    # Stage 2: parameter search
    all_results, best_config = search_parameters(cache, args)

    # Persist everything.
    # NOTE(review): records with per=float('inf') are written by json.dump as
    # the non-standard token `Infinity`; strict JSON parsers will reject it.
    report = {
        'best_config': best_config,
        'all_results': all_results,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(args.output_file, 'w') as out_fh:
        json.dump(report, out_fh, indent=2)

    print("\n" + separator)
    print("🏆 搜索完成!")

    # Finite, error-free records; shared by the top-10 list and the statistics
    valid_results = [
        rec for rec in all_results
        if 'error' not in rec and rec['per'] != float('inf')
    ]

    if best_config is None:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查缓存数据和评估逻辑。")
    else:
        print(f"最佳配置: PER={best_config['per']:.3f}%")
        print(f"GRU权重: {best_config['gru_weight']:.1f}")
        print(f"TTA权重: {best_config['tta_weights']}")
        print(f"结果保存到: {args.output_file}")

        # Show the ten best configurations
        top_ten = sorted(valid_results, key=lambda rec: rec['per'])[:10]
        print(f"\n📊 Top-10配置:")
        for rank, rec in enumerate(top_ten, start=1):
            weights = rec['tta_weights']
            print(f"{rank:2d}. PER={rec['per']:6.3f}% | GRU={rec['gru_weight']:.1f} | "
                  f"TTA=({weights['noise']:.1f},{weights['scale']:.1f},{weights['shift']:.1f},{weights['smooth']:.1f})")

    # Search statistics
    n_total = len(all_results)
    n_valid = len(valid_results)
    print(f"\n📈 搜索统计:")
    print(f"  总配置数: {n_total}")
    print(f"  成功配置数: {n_valid}")
    print(f"  失败配置数: {n_total - n_valid}")

    if valid_results:
        valid_pers = [rec['per'] for rec in valid_results]
        print(f"  PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f"  平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")
|  | 
 | |||
# Run the search only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()