#!/usr/bin/env python3
"""
Efficient TTA-E parameter search.

All base predictions (one per augmentation type, for both decoders) are
computed once and cached on disk; the weight combinations are then searched
quickly on top of the cached probabilities.
"""

import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
# NOTE(review): load_h5py_file, gauss_smooth and LOGIT_TO_PHONEME used below are
# presumably provided by this star import - confirm against the helpers module.
from model_training.evaluate_model_helpers import *


def parse_arguments():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='高效TTA-E参数搜索')

    # Model and data paths
    parser.add_argument('--gru_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline')
    parser.add_argument('--lstm_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'])
    parser.add_argument('--gpu_number', type=int, default=0)

    # Search space (comma-separated lists of candidate weights)
    parser.add_argument('--gru_weights', type=str,
                        default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_noise_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_scale_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_shift_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_smooth_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')

    # Augmentation parameters
    parser.add_argument('--tta_noise_std', type=float, default=0.01)
    parser.add_argument('--tta_smooth_range', type=float, default=0.5)
    parser.add_argument('--tta_scale_range', type=float, default=0.05)
    parser.add_argument('--tta_cut_max', type=int, default=3)

    # Cache control
    parser.add_argument('--cache_file', type=str, default='tta_cache.pkl')
    parser.add_argument('--force_recache', action='store_true')
    parser.add_argument('--output_file', type=str, default='search_results.json')

    return parser.parse_args()


def _strip_wrapper_prefixes(state_dict):
    """Remove "module." and "_orig_mod." from every key of ``state_dict`` in place.

    Both substitutions are applied to the key in a single rename. Renaming in
    two separate pop/insert steps raises KeyError on the second pop whenever
    the first step already changed the key.
    """
    for key in list(state_dict.keys()):
        clean_key = key.replace("module.", "").replace("_orig_mod.", "")
        state_dict[clean_key] = state_dict.pop(key)


def _load_decoder(model_cls, model_path, device):
    """Instantiate ``model_cls`` from ``<model_path>/checkpoint`` and load its weights.

    Returns:
        (model, model_args): the model moved to ``device`` in eval mode, and
        the OmegaConf config read from ``checkpoint/args.yaml``.
    """
    model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))

    model = model_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )

    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(
        os.path.join(model_path, 'checkpoint/best_checkpoint'),
        weights_only=False,
        map_location=device,
    )
    _strip_wrapper_prefixes(checkpoint['model_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()
    return model, model_args


def _augment(neural_input, aug_type, args, base_smooth_std):
    """Apply one augmentation to a single trial.

    Args:
        neural_input: input tensor; dimension 1 is treated as the time axis
            by the 'shift' augmentation.
        aug_type: one of 'original', 'noise', 'scale', 'shift', 'smooth'.
        args: parsed CLI arguments (reads the ``tta_*`` parameters).
        base_smooth_std: smoothing-kernel std from the model config.

    Returns:
        (x_aug, smooth_std): the augmented copy of the input and the kernel
        std to use for the subsequent Gaussian smoothing. Only 'smooth'
        changes the std; every other type returns ``base_smooth_std``.
    """
    x_aug = neural_input.clone()
    smooth_std = base_smooth_std

    if aug_type == 'noise':
        # Noise level is drawn in [0.5, 1.0) * tta_noise_std.
        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
        x_aug = x_aug + torch.randn_like(x_aug) * noise_scale
    elif aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
        x_aug = x_aug * scale_factor
    elif aug_type == 'shift' and args.tta_cut_max > 0:
        # randint's upper bound is exclusive and must exceed the lower bound;
        # trials shorter than 16 steps give an upper bound <= 1, so they are
        # left unshifted instead of raising ValueError.
        shift_upper = min(args.tta_cut_max + 1, x_aug.shape[1] // 8)
        if shift_upper > 1:
            shift_amount = np.random.randint(1, shift_upper)
            # Circular shift along the time axis.
            x_aug = torch.cat(
                [x_aug[:, shift_amount:, :], x_aug[:, :shift_amount, :]], dim=1
            )
    elif aug_type == 'smooth':
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
        smooth_std = max(0.3, base_smooth_std + smooth_variation)

    return x_aug, smooth_std


def generate_base_predictions(args):
    """Compute and cache GRU/LSTM probabilities for every augmentation type.

    The result is pickled to ``args.cache_file`` and also returned. Layout:
    ``cache[aug_type][session]`` is a list with one dict per trial holding
    'gru_probs', 'lstm_probs' (numpy arrays) and 'trial_info'. For
    ``eval_type == 'val'`` the trial info also carries the label fields.
    """
    print("🔄 生成基础预测缓存...")

    # Select the device
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')

    # Load both decoders
    gru_model, gru_model_args = _load_decoder(GRUDecoder, args.gru_model_path, device)
    lstm_model, _ = _load_decoder(LSTMDecoder, args.lstm_model_path, device)

    # Load the evaluation data, session by session
    b2txt_csv_df = pd.read_csv(args.csv_path)
    test_data = {}
    total_trials = 0
    eval_filename = f'data_{args.eval_type}.hdf5'
    for session in gru_model_args['dataset']['sessions']:
        session_dir = os.path.join(args.data_dir, session)
        files = [f for f in os.listdir(session_dir) if f.endswith('.hdf5')]
        if eval_filename in files:
            eval_file = os.path.join(session_dir, eval_filename)
            test_data[session] = load_h5py_file(eval_file, b2txt_csv_df)
            total_trials += len(test_data[session]["neural_features"])

    print(f"总试验数: {total_trials}")

    data_transforms = gru_model_args['dataset']['data_transforms']
    base_smooth_std = data_transforms['smooth_kernel_std']

    # Predictions for every base augmentation type
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
    cache = {}

    for aug_type in augmentation_types:
        print(f"处理增强类型: {aug_type}")
        cache[aug_type] = {}

        with tqdm(total=total_trials, desc=f'{aug_type}', unit='trial') as pbar:
            for session, data in test_data.items():
                # Both models are indexed with the session position taken from
                # the GRU config.
                input_layer = gru_model_args['dataset']['sessions'].index(session)
                day_idx = torch.tensor([input_layer], device=device)
                cache[aug_type][session] = []

                for trial in range(len(data['neural_features'])):
                    neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
                    neural_input = torch.tensor(
                        neural_input, device=device, dtype=torch.bfloat16
                    )

                    # Apply the augmentation
                    x_aug, smooth_std = _augment(
                        neural_input, aug_type, args, base_smooth_std
                    )

                    # Gaussian smoothing followed by both forward passes
                    with torch.autocast(device_type="cuda",
                                        enabled=gru_model_args['use_amp'],
                                        dtype=torch.bfloat16):
                        x_smoothed = gauss_smooth(
                            inputs=x_aug,
                            device=device,
                            smooth_kernel_std=smooth_std,
                            smooth_kernel_size=data_transforms['smooth_kernel_size'],
                            padding='valid',
                        )

                        with torch.no_grad():
                            gru_logits, _ = gru_model(
                                x=x_smoothed, day_idx=day_idx,
                                states=None, return_state=True,
                            )
                            lstm_logits, _ = lstm_model(
                                x=x_smoothed, day_idx=day_idx,
                                states=None, return_state=True,
                            )

                    gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
                    lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()

                    # Keep the predictions together with the trial identity
                    # (and the ground-truth labels when they are available).
                    trial_data = {
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                        }
                    }

                    if args.eval_type == 'val':
                        trial_data['trial_info'].update({
                            'seq_class_ids': data['seq_class_ids'][trial],
                            'seq_len': data['seq_len'][trial],
                            'sentence_label': data['sentence_label'][trial],
                        })

                    cache[aug_type][session].append(trial_data)
                    pbar.update(1)

    # Persist the cache
    with open(args.cache_file, 'wb') as f:
        pickle.dump(cache, f)

    print(f"✅ 缓存已保存到 {args.cache_file}")
    return cache


def _greedy_decode(frame_ids):
    """Turn per-frame class ids into a label sequence.

    Class 0 is dropped first, then consecutive duplicates are collapsed.

    NOTE(review): canonical CTC greedy decoding collapses repeats *before*
    removing blanks; this order also merges a label repeated across a blank.
    The order is kept unchanged so PER values stay comparable with earlier
    runs - confirm whether that is intended.
    """
    labels = [int(p) for p in frame_ids if p != 0]
    return [
        label for i, label in enumerate(labels)
        if i == 0 or label != labels[i - 1]
    ]


def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
    """Evaluate one (GRU weight, TTA weights) configuration on the cache.

    Args:
        cache: structure produced by ``generate_base_predictions``.
        gru_weight: weight of the GRU in the log-domain ensemble; the LSTM
            gets ``1 - gru_weight``.
        tta_weights: mapping augmentation type -> weight. Types with weight
            <= 0 are disabled; the rest are normalised to sum to 1.
        eval_type: 'val' computes an error rate against the cached labels.

    Returns:
        Phoneme error rate in percent for 'val'. ``float('inf')`` when no
        augmentation is enabled. 0.0 when nothing could be scored (including
        every non-'val' run, which has no labels).
    """
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8

    total_edit_distance = 0
    total_true_length = 0

    # Normalise the TTA weights over the enabled augmentations
    enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
    if len(enabled_augmentations) == 0:
        return float('inf')

    total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
    norm_tta_weights = {k: tta_weights[k] / total_tta_weight for k in enabled_augmentations}

    # Iterate over all trials
    for session, original_trials in cache['original'].items():
        for trial_idx, original_entry in enumerate(original_trials):
            trial_info = original_entry['trial_info']

            if eval_type == 'val' and 'seq_class_ids' not in trial_info:
                continue

            # Weighted average of the probabilities over enabled augmentations
            weighted_gru_probs = None
            weighted_lstm_probs = None

            for aug_type in enabled_augmentations:
                weight = norm_tta_weights[aug_type]
                entry = cache[aug_type][session][trial_idx]
                gru_probs = torch.tensor(entry['gru_probs'])
                lstm_probs = torch.tensor(entry['lstm_probs'])

                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Truncate to the common length if the sequences differ
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = (weighted_gru_probs[:, :min_len, :]
                                          + weight * gru_probs[:, :min_len, :])
                    weighted_lstm_probs = (weighted_lstm_probs[:, :min_len, :]
                                           + weight * lstm_probs[:, :min_len, :])

            # GRU+LSTM ensemble: weighted sum of log-probabilities; epsilon
            # keeps log() finite for zero probabilities.
            weighted_gru_probs = weighted_gru_probs + epsilon
            weighted_lstm_probs = weighted_lstm_probs + epsilon

            log_ensemble_probs = (gru_weight * torch.log(weighted_gru_probs)
                                  + lstm_weight * torch.log(weighted_lstm_probs))
            ensemble_probs = torch.exp(log_ensemble_probs)
            ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

            # Decode
            pred_seq = _greedy_decode(torch.argmax(ensemble_probs[0], dim=-1).numpy())
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            if eval_type == 'val':
                true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
                true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

                total_edit_distance += editdistance.eval(true_phonemes, pred_phonemes)
                total_true_length += len(true_phonemes)

    if eval_type == 'val' and total_true_length > 0:
        return 100 * total_edit_distance / total_true_length
    return 0.0


def _parse_float_list(text):
    """Parse a comma-separated string such as '0.1,0.2' into a list of floats."""
    return [float(x) for x in text.split(',')]


def search_parameters(cache, args):
    """Grid-search the ensemble and TTA weights.

    Every combination of the candidate lists given on the command line is
    scored with ``evaluate_config``. The 'original' augmentation always has
    weight 1.0.

    Returns:
        (all_results, best_config): the list of result dicts for every
        configuration, and the first one that reached the lowest PER.
    """
    print("🔍 开始参数搜索...")

    # Parse the search space
    gru_weights = _parse_float_list(args.gru_weights)
    noise_weights = _parse_float_list(args.tta_noise_weights)
    scale_weights = _parse_float_list(args.tta_scale_weights)
    shift_weights = _parse_float_list(args.tta_shift_weights)
    smooth_weights = _parse_float_list(args.tta_smooth_weights)

    total_configs = (len(gru_weights) * len(noise_weights) * len(scale_weights)
                     * len(shift_weights) * len(smooth_weights))
    print(f"搜索空间: {total_configs} 个配置")

    best_per = float('inf')
    best_config = None
    all_results = []

    # product() iterates with the last list varying fastest, i.e. in the same
    # order as five nested loops over these lists.
    grid = itertools.product(
        gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights
    )
    for config_count, (gru_w, noise_w, scale_w, shift_w, smooth_w) in enumerate(grid, start=1):
        tta_weights = {
            'original': 1.0,  # the unaugmented input is always included
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        }

        per = evaluate_config(cache, gru_w, tta_weights, args.eval_type)

        result = {
            'gru_weight': gru_w,
            'lstm_weight': 1.0 - gru_w,
            'tta_weights': tta_weights,
            'per': per
        }
        all_results.append(result)

        if per < best_per:
            best_per = per
            best_config = result
            print(f"🎯 新最佳: PER={per:.3f}% | GRU={gru_w:.1f} | TTA=({noise_w},{scale_w},{shift_w},{smooth_w})")

        if config_count % 50 == 0:
            print(f"进度: {config_count}/{total_configs} ({100*config_count/total_configs:.1f}%)")

    return all_results, best_config


def main():
    """Entry point: build or load the prediction cache, search, report."""
    args = parse_arguments()

    print("🚀 高效TTA-E参数搜索")
    print("=" * 50)

    # Stage 1: generate or load the cache
    if args.force_recache or not os.path.exists(args.cache_file):
        cache = generate_base_predictions(args)
    else:
        print(f"📁 加载现有缓存: {args.cache_file}")
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # cache files produced by this script on a trusted machine.
        with open(args.cache_file, 'rb') as f:
            cache = pickle.load(f)
        print("✅ 缓存加载完成")

    # Stage 2: parameter search
    all_results, best_config = search_parameters(cache, args)

    # Save the results
    results = {
        'best_config': best_config,
        'all_results': all_results,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 50)
    print("🏆 搜索完成!")
    print(f"最佳配置: PER={best_config['per']:.3f}%")
    print(f"GRU权重: {best_config['gru_weight']:.1f}")
    print(f"TTA权重: {best_config['tta_weights']}")
    print(f"结果保存到: {args.output_file}")

    # Show the ten best configurations
    sorted_results = sorted(all_results, key=lambda x: x['per'])[:10]
    print(f"\n📊 Top-10配置:")
    for i, result in enumerate(sorted_results):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")


if __name__ == "__main__":
    main()