#!/usr/bin/env python3
"""
TTA-E parameter search script.

First runs every base augmentation configuration once to collect model
predictions, then searches for the optimal parameter combination over the
cached results. This avoids repeated model inference and makes the search
efficient.
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle

# Add the repository root to the path so that model_training and
# model_training_lstm can be imported as packages.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *


def parse_arguments():
    parser = argparse.ArgumentParser(description='TTA-E Parameter Search for optimal configuration.')
    parser.add_argument('--gru_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                        help='Path to the pretrained GRU model directory.')
    parser.add_argument('--lstm_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                        help='Path to the pretrained LSTM model directory.')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                        help='Path to the dataset directory.')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                        help='Path to the CSV file with metadata.')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU number to use for model inference.')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                        help='Evaluation type.')
    # Search-space parameters
    parser.add_argument('--gru_weights', type=str, default='0.4,0.5,0.6,0.7,0.8,1.0',
                        help='Comma-separated GRU weights to search.')
    parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated noise weights to search.')
    parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated scale weights to search.')
    parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated shift weights to search.')
    parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated smooth weights to search.')
    # Fixed TTA parameters
    parser.add_argument('--tta_noise_std', type=float, default=0.01,
                        help='Standard deviation for TTA noise augmentation.')
    parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                        help='Range for TTA smoothing kernel variation.')
    parser.add_argument('--tta_scale_range', type=float, default=0.05,
                        help='Range for TTA amplitude scaling.')
    parser.add_argument('--tta_cut_max', type=int, default=3,
                        help='Maximum timesteps for TTA shift.')
    # Output control
    parser.add_argument('--cache_file', type=str, default='tta_predictions_cache.pkl',
                        help='File to cache model predictions.')
    parser.add_argument('--results_file', type=str, default='parameter_search_results.json',
                        help='File to save search results.')
    parser.add_argument('--force_recache', action='store_true',
                        help='Force re-computation of predictions cache.')
    return parser.parse_args()
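
# Example invocation (a sketch only; the script filename and weight grids
# below are assumptions based on this file's defaults, not a verified
# command line):
#
#   python tta_e_parameter_search.py \
#       --eval_type val \
#       --gru_weights 0.4,0.5,0.6 \
#       --tta_noise_weights 0.0,0.5,1.0 \
#       --force_recache
#
# Stage 1 writes the predictions cache; stage 2 reuses it, so subsequent
# searches over different weight grids skip model inference entirely.
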

def generate_all_base_configs():
    """Generate every base configuration to run (each augmentation runs on its own)."""
    base_configs = [
        {'name': 'original', 'tta_weights': {'original': 1.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'noise',    'tta_weights': {'original': 0.0, 'noise': 1.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'scale',    'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 1.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'shift',    'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 1.0, 'smooth': 0.0}},
        {'name': 'smooth',   'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 1.0}},
    ]
    return base_configs


def run_single_tta_prediction(x, input_layer, gru_model, lstm_model,
                              gru_model_args, lstm_model_args, device, aug_type,
                              tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max):
    """Run prediction for a single TTA augmentation."""
    x_augmented = x.clone()

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    varied_smooth_std = default_smooth_std
    if aug_type == 'original':
        pass
    elif aug_type == 'noise':
        # Additive Gaussian noise with a randomly jittered scale
        noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
        noise = torch.randn_like(x_augmented) * noise_scale
        x_augmented = x_augmented + noise
    elif aug_type == 'scale':
        # Random amplitude scaling in [1 - tta_scale_range, 1 + tta_scale_range]
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
        x_augmented = x_augmented * scale_factor
    elif aug_type == 'shift' and tta_cut_max > 0:
        # Circular shift along the time axis by a small random amount.
        # Guard the upper bound so randint's range is never empty for short trials.
        max_shift = min(tta_cut_max + 1, max(2, x_augmented.shape[1] // 8))
        shift_amount = np.random.randint(1, max_shift)
        x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                 x_augmented[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        # Perturb the Gaussian smoothing kernel's standard deviation
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
        varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)

    # Use autocast for efficiency
    with torch.autocast(device_type="cuda", enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
        # Apply Gaussian smoothing (varied std only for the 'smooth' augmentation)
        x_smoothed = gauss_smooth(
            inputs=x_augmented,
            device=device,
            smooth_kernel_std=varied_smooth_std,
            smooth_kernel_size=default_smooth_size,
            padding='valid',
        )

        with torch.no_grad():
            # Get GRU logits and convert to probabilities
            gru_logits, _ = gru_model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=device),
                states=None,
                return_state=True,
            )
            gru_probs = torch.softmax(gru_logits, dim=-1)

            # Get LSTM logits and convert to probabilities
            lstm_logits, _ = lstm_model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=device),
                states=None,
                return_state=True,
            )
            lstm_probs = torch.softmax(lstm_logits, dim=-1)

    return gru_probs.float().cpu().numpy(), lstm_probs.float().cpu().numpy()
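
# A minimal sketch (toy shapes, not part of the pipeline) of the circular time
# shift used in run_single_tta_prediction: the first shift_amount timesteps
# wrap around to the end, so the sequence length is preserved.
#
#   x = torch.arange(6).reshape(1, 6, 1)  # (batch, time, features)
#   shifted = torch.cat([x[:, 2:, :], x[:, :2, :]], dim=1)
#   # shifted.squeeze() -> tensor([2, 3, 4, 5, 0, 1])
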

def cache_model_predictions(args):
    """Stage 1: cache the model predictions for every base configuration."""
    print("=== Stage 1: caching model predictions ===")

    # Select the device
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')
    print(f'Using device: {device}')

    # Load models
    print("Loading models...")
    gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
    lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))

    # Define GRU model
    gru_model = GRUDecoder(
        neural_dim=gru_model_args['model']['n_input_features'],
        n_units=gru_model_args['model']['n_units'],
        n_days=len(gru_model_args['dataset']['sessions']),
        n_classes=gru_model_args['dataset']['n_classes'],
        rnn_dropout=gru_model_args['model']['rnn_dropout'],
        input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=gru_model_args['model']['n_layers'],
        patch_size=gru_model_args['model']['patch_size'],
        patch_stride=gru_model_args['model']['patch_stride'],
    )

    # Load GRU model weights, stripping DataParallel ("module.") and
    # torch.compile ("_orig_mod.") prefixes from the state-dict keys
    gru_checkpoint = torch.load(os.path.join(args.gru_model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=device)
    for key in list(gru_checkpoint['model_state_dict'].keys()):
        new_key = key.replace("module.", "").replace("_orig_mod.", "")
        gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
    gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

    # Define LSTM model
    lstm_model = LSTMDecoder(
        neural_dim=lstm_model_args['model']['n_input_features'],
        n_units=lstm_model_args['model']['n_units'],
        n_days=len(lstm_model_args['dataset']['sessions']),
        n_classes=lstm_model_args['dataset']['n_classes'],
        rnn_dropout=lstm_model_args['model']['rnn_dropout'],
        input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=lstm_model_args['model']['n_layers'],
        patch_size=lstm_model_args['model']['patch_size'],
        patch_stride=lstm_model_args['model']['patch_stride'],
    )

    # Load LSTM model weights, stripping the same prefixes
    lstm_checkpoint = torch.load(os.path.join(args.lstm_model_path, 'checkpoint/best_checkpoint'),
                                 weights_only=False, map_location=device)
    for key in list(lstm_checkpoint['model_state_dict'].keys()):
        new_key = key.replace("module.", "").replace("_orig_mod.", "")
        lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
    lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

    gru_model.to(device)
    lstm_model.to(device)
    gru_model.eval()
    lstm_model.eval()

    # Load data
    print("Loading data...")
    b2txt_csv_df = pd.read_csv(args.csv_path)

    test_data = {}
    total_trials = 0
    for session in gru_model_args['dataset']['sessions']:
        files = [f for f in os.listdir(os.path.join(args.data_dir, session)) if f.endswith('.hdf5')]
        if f'data_{args.eval_type}.hdf5' in files:
            eval_file = os.path.join(args.data_dir, session, f'data_{args.eval_type}.hdf5')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_trials += len(test_data[session]["neural_features"])
            print(f'Loaded {len(test_data[session]["neural_features"])} {args.eval_type} trials for session {session}.')
    print(f'Total trials: {total_trials}')

    # Generate all base predictions
    base_configs = generate_all_base_configs()
    predictions_cache = {}

    for config in base_configs:
        config_name = config['name']
        print(f"\nGenerating predictions for: {config_name}")
        predictions_cache[config_name] = {}

        with tqdm(total=total_trials, desc=f'Processing {config_name}', unit='trial') as pbar:
            for session, data in test_data.items():
                input_layer = gru_model_args['dataset']['sessions'].index(session)
                session_predictions = []

                for trial in range(len(data['neural_features'])):
                    neural_input = data['neural_features'][trial]
                    neural_input = np.expand_dims(neural_input, axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    gru_probs, lstm_probs = run_single_tta_prediction(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device, config_name,
                        args.tta_noise_std, args.tta_smooth_range, args.tta_scale_range, args.tta_cut_max
                    )

                    session_predictions.append({
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                            'seq_class_ids': data['seq_class_ids'][trial] if args.eval_type == 'val' else None,
                            'seq_len': data['seq_len'][trial] if args.eval_type == 'val' else None,
                            'sentence_label': data['sentence_label'][trial] if args.eval_type == 'val' else None,
                        }
                    })
                    pbar.update(1)

                predictions_cache[config_name][session] = session_predictions

    # Save the cache
    print(f"\nSaving predictions cache to {args.cache_file}...")
    with open(args.cache_file, 'wb') as f:
        pickle.dump(predictions_cache, f)
    print("✓ Predictions cache saved successfully!")

    return predictions_cache
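
# Sketch of the cache layout written above (for ad-hoc inspection; the key
# names mirror generate_all_base_configs):
#
#   with open('tta_predictions_cache.pkl', 'rb') as f:
#       cache = pickle.load(f)
#   print(cache.keys())  # dict_keys(['original', 'noise', 'scale', 'shift', 'smooth'])
#   # cache[aug][session] is a list of per-trial dicts holding 'gru_probs' and
#   # 'lstm_probs' (numpy arrays of shape (1, T, n_classes)) plus 'trial_info'.
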

def ensemble_and_evaluate(predictions_cache, gru_weight, tta_weights, eval_type='val'):
    """Ensemble the cached predictions and evaluate the resulting PER."""
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8

    total_trials = 0
    total_edit_distance = 0
    total_true_length = 0

    # Determine which augmentations are enabled
    enabled_augmentations = [aug_type for aug_type, weight in tta_weights.items() if weight > 0]
    if len(enabled_augmentations) == 0:
        return float('inf')  # invalid configuration

    # Normalize the TTA weights
    total_tta_weight = sum(weight for weight in tta_weights.values() if weight > 0)
    normalized_tta_weights = {k: v / total_tta_weight for k, v in tta_weights.items() if v > 0}

    for session in predictions_cache['original'].keys():
        session_predictions = predictions_cache['original'][session]

        for trial_idx in range(len(session_predictions)):
            trial_info = session_predictions[trial_idx]['trial_info']
            if eval_type == 'val' and trial_info['seq_class_ids'] is None:
                continue

            # Accumulate probabilities from every enabled augmentation
            weighted_gru_probs = None
            weighted_lstm_probs = None

            for aug_type in enabled_augmentations:
                weight = normalized_tta_weights[aug_type]
                gru_probs = torch.tensor(predictions_cache[aug_type][session][trial_idx]['gru_probs'])
                lstm_probs = torch.tensor(predictions_cache[aug_type][session][trial_idx]['lstm_probs'])

                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Handle differing sequence lengths across augmentations
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
                    weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]

            # Ensemble GRU and LSTM via a weighted geometric mean
            weighted_gru_probs = weighted_gru_probs + epsilon
            weighted_lstm_probs = weighted_lstm_probs + epsilon
            log_ensemble_probs = (gru_weight * torch.log(weighted_gru_probs) +
                                  lstm_weight * torch.log(weighted_lstm_probs))
            ensemble_probs = torch.exp(log_ensemble_probs)
            ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

            # Greedy CTC decode: collapse consecutive repeats first, then drop
            # the blank token (class 0)
            pred_seq = torch.argmax(ensemble_probs[0], dim=-1).numpy()
            pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
            pred_seq = [p for p in pred_seq if p != 0]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            if eval_type == 'val':
                # Compute the phoneme error rate (PER)
                true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
                true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
                ed = editdistance.eval(true_phonemes, pred_phonemes)
                total_edit_distance += ed
                total_true_length += len(true_phonemes)

            total_trials += 1

    if eval_type == 'val' and total_true_length > 0:
        per = 100 * total_edit_distance / total_true_length
        return per
    else:
        return 0.0  # test mode returns 0
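
# Worked toy example (made-up numbers) of the weighted geometric mean used in
# ensemble_and_evaluate, with gru_weight = 0.6 and lstm_weight = 0.4:
#
#   gru  = torch.tensor([0.7, 0.3])
#   lstm = torch.tensor([0.5, 0.5])
#   mixed = torch.exp(0.6 * torch.log(gru) + 0.4 * torch.log(lstm))
#   mixed = mixed / mixed.sum()  # -> approx tensor([0.6244, 0.3756])
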
print(f" Noise weights: {noise_weights}") print(f" Scale weights: {scale_weights}") print(f" Shift weights: {shift_weights}") print(f" Smooth weights: {smooth_weights}") # 生成所有参数组合 all_combinations = list(itertools.product(gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights)) total_combinations = len(all_combinations) print(f"Total combinations to evaluate: {total_combinations}") best_per = float('inf') best_config = None results = [] with tqdm(total=total_combinations, desc='Parameter search', unit='config') as pbar: for gru_w, noise_w, scale_w, shift_w, smooth_w in all_combinations: tta_weights = { 'original': 1.0, # 总是包含原始数据 'noise': noise_w, 'scale': scale_w, 'shift': shift_w, 'smooth': smooth_w } per = ensemble_and_evaluate(predictions_cache, gru_w, tta_weights, args.eval_type) config = { 'gru_weight': gru_w, 'lstm_weight': 1.0 - gru_w, 'tta_weights': tta_weights, 'per': per } results.append(config) if per < best_per: best_per = per best_config = config print(f"\n🎯 New best PER: {per:.3f}%") print(f" GRU weight: {gru_w:.1f}") print(f" TTA weights: {tta_weights}") pbar.update(1) return results, best_config def main(): args = parse_arguments() print("TTA-E Parameter Search") print("=" * 50) # 第一阶段:缓存预测结果 if args.force_recache or not os.path.exists(args.cache_file): predictions_cache = cache_model_predictions(args) else: print(f"Loading existing predictions cache from {args.cache_file}...") with open(args.cache_file, 'rb') as f: predictions_cache = pickle.load(f) print("✓ Cache loaded successfully!") # 第二阶段:参数搜索 results, best_config = search_optimal_parameters(predictions_cache, args) # 保存结果 print(f"\n=== 搜索完成 ===") print(f"Best configuration:") print(f" PER: {best_config['per']:.3f}%") print(f" GRU weight: {best_config['gru_weight']:.1f}") print(f" LSTM weight: {best_config['lstm_weight']:.1f}") print(f" TTA weights: {best_config['tta_weights']}") # 保存所有结果 search_results = { 'best_config': best_config, 'all_results': results, 'search_args': vars(args), 'timestamp': time.strftime("%Y-%m-%d %H:%M:%S") } with open(args.results_file, 'w') as f: json.dump(search_results, f, indent=2) print(f"\n✓ Results saved to {args.results_file}") # 显示前10个最佳配置 sorted_results = sorted(results, key=lambda x: x['per']) print(f"\nTop 10 configurations:") for i, config in enumerate(sorted_results[:10]): print(f"{i+1:2d}. PER={config['per']:6.3f}% | GRU={config['gru_weight']:.1f} | TTA={config['tta_weights']}") if __name__ == "__main__": main()