#!/usr/bin/env python3
"""
Efficient TTA-E parameter search.

Phase 1 caches the base predictions for every augmentation type; phase 2 then
searches parameter combinations quickly against that cache.
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed

# GPU acceleration library
try:
    import cupy as cp
    GPU_AVAILABLE = True
    print("✅ CuPy available - GPU acceleration enabled")

    # cp.asnumpy copies CuPy arrays back to host memory and passes NumPy
    # arrays through unchanged, so it is safe for both input types.
    def to_cpu(arr):
        return cp.asnumpy(arr)
except ImportError:
    print("⚠️ CuPy not available - falling back to CPU numpy")
    import numpy as cp
    GPU_AVAILABLE = False

    def to_cpu(arr):
        return arr

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME


def parse_arguments():
    parser = argparse.ArgumentParser(description='Efficient TTA-E parameter search')

    # Model and data paths
    parser.add_argument('--gru_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline')
    parser.add_argument('--lstm_model_path', type=str,
                        default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'])
    parser.add_argument('--gpu_number', type=int, default=0)

    # Search space
    parser.add_argument('--gru_weights', type=str,
                        default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_noise_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_scale_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_shift_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_smooth_weights', type=str,
                        default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')

    # TTA parameters
    parser.add_argument('--tta_noise_std', type=float, default=0.01)
    parser.add_argument('--tta_smooth_range', type=float, default=0.5)
    parser.add_argument('--tta_scale_range', type=float, default=0.05)
    parser.add_argument('--tta_cut_max', type=int, default=3)

    # Cache control
    parser.add_argument('--cache_file', type=str, default='tta_cache.pkl')
    parser.add_argument('--force_recache', action='store_true')
    parser.add_argument('--output_file', type=str, default='search_results.json')

    # Parallel processing
    parser.add_argument('--max_workers', type=int, default=25,
                        help='Maximum number of parallel worker threads')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Number of configurations per progress-report batch')

    return parser.parse_args()
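# Example invocation (a sketch; paths assume the default repo layout above, and
# the script file name is hypothetical). The full default grid is
# 10 * 11^4 = 146,410 configurations, so a reduced grid is a sensible first pass:
#
#   python efficient_tta_search.py --eval_type val --gpu_number 0 \
#       --gru_weights 0.4,0.5,0.6 \
#       --tta_noise_weights 0.0,0.5,1.0 \
#       --tta_scale_weights 0.0,0.5,1.0 \
#       --tta_shift_weights 0.0,0.5,1.0 \
#       --tta_smooth_weights 0.0,0.5,1.0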
def generate_base_predictions(args):
    """Generate and cache predictions for every base augmentation type."""
    print("🔄 Generating base prediction cache...")

    # Select device
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')

    # Load model configs
    gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
    lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))

    gru_model = GRUDecoder(
        neural_dim=gru_model_args['model']['n_input_features'],
        n_units=gru_model_args['model']['n_units'],
        n_days=len(gru_model_args['dataset']['sessions']),
        n_classes=gru_model_args['dataset']['n_classes'],
        rnn_dropout=gru_model_args['model']['rnn_dropout'],
        input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=gru_model_args['model']['n_layers'],
        patch_size=gru_model_args['model']['patch_size'],
        patch_stride=gru_model_args['model']['patch_stride'],
    )
    lstm_model = LSTMDecoder(
        neural_dim=lstm_model_args['model']['n_input_features'],
        n_units=lstm_model_args['model']['n_units'],
        n_days=len(lstm_model_args['dataset']['sessions']),
        n_classes=lstm_model_args['dataset']['n_classes'],
        rnn_dropout=lstm_model_args['model']['rnn_dropout'],
        input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=lstm_model_args['model']['n_layers'],
        patch_size=lstm_model_args['model']['patch_size'],
        patch_stride=lstm_model_args['model']['patch_stride'],
    )

    # Load checkpoints
    gru_checkpoint = torch.load(os.path.join(args.gru_model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=device)
    lstm_checkpoint = torch.load(os.path.join(args.lstm_model_path, 'checkpoint/best_checkpoint'),
                                 weights_only=False, map_location=device)

    # Strip DataParallel / torch.compile prefixes from state-dict keys.
    # (Chaining two pops on the same key raises KeyError once the first
    # replacement renames it, so both prefixes are stripped in a single pass.)
    for checkpoint in [gru_checkpoint, lstm_checkpoint]:
        state_dict = checkpoint['model_state_dict']
        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "").replace("_orig_mod.", "")
            state_dict[new_key] = state_dict.pop(key)

    gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
    lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

    gru_model.to(device)
    lstm_model.to(device)
    gru_model.eval()
    lstm_model.eval()

    # Load data
    b2txt_csv_df = pd.read_csv(args.csv_path)
    test_data = {}
    total_trials = 0
    for session in gru_model_args['dataset']['sessions']:
        files = [f for f in os.listdir(os.path.join(args.data_dir, session)) if f.endswith('.hdf5')]
        if f'data_{args.eval_type}.hdf5' in files:
            eval_file = os.path.join(args.data_dir, session, f'data_{args.eval_type}.hdf5')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_trials += len(test_data[session]["neural_features"])
    print(f"Total trials: {total_trials}")

    # Generate predictions for every base augmentation type
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
    cache = {}

    for aug_type in augmentation_types:
        print(f"Processing augmentation type: {aug_type}")
        cache[aug_type] = {}

        with tqdm(total=total_trials, desc=f'{aug_type}', unit='trial') as pbar:
            for session, data in test_data.items():
                input_layer = gru_model_args['dataset']['sessions'].index(session)
                cache[aug_type][session] = []

                for trial in range(len(data['neural_features'])):
                    neural_input = data['neural_features'][trial]
                    neural_input = np.expand_dims(neural_input, axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    # Apply the augmentation
                    x_aug = neural_input.clone()
                    if aug_type == 'noise':
                        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
                        x_aug = x_aug + torch.randn_like(x_aug) * noise_scale
                    elif aug_type == 'scale':
                        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
                        x_aug = x_aug * scale_factor
                    elif aug_type == 'shift' and args.tta_cut_max > 0:
                        # Guard against randint(1, high) with high <= 1, which
                        # raises ValueError on very short trials.
                        high = min(args.tta_cut_max + 1, x_aug.shape[1] // 8)
                        if high > 1:
                            shift_amount = np.random.randint(1, high)
                            x_aug = torch.cat([x_aug[:, shift_amount:, :], x_aug[:, :shift_amount, :]], dim=1)
                    elif aug_type == 'smooth':
                        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
                        varied_smooth_std = max(0.3, gru_model_args['dataset']['data_transforms']['smooth_kernel_std'] + smooth_variation)

                    # Apply Gaussian smoothing (device_type follows the selected
                    # device, so the script also runs on CPU-only machines)
                    with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
                        if aug_type == 'smooth':
                            x_smoothed = gauss_smooth(
                                inputs=x_aug,
                                device=device,
                                smooth_kernel_std=varied_smooth_std,
                                smooth_kernel_size=gru_model_args['dataset']['data_transforms']['smooth_kernel_size'],
                                padding='valid',
                            )
                        else:
                            x_smoothed = gauss_smooth(
                                inputs=x_aug,
                                device=device,
                                smooth_kernel_std=gru_model_args['dataset']['data_transforms']['smooth_kernel_std'],
                                smooth_kernel_size=gru_model_args['dataset']['data_transforms']['smooth_kernel_size'],
                                padding='valid',
                            )

                        with torch.no_grad():
                            gru_logits, _ = gru_model(x=x_smoothed,
                                                      day_idx=torch.tensor([input_layer], device=device),
                                                      states=None, return_state=True)
                            lstm_logits, _ = lstm_model(x=x_smoothed,
                                                        day_idx=torch.tensor([input_layer], device=device),
                                                        states=None, return_state=True)

                    gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
                    lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()

                    # Store predictions together with ground-truth trial info
                    trial_data = {
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                        }
                    }
                    if args.eval_type == 'val':
                        trial_data['trial_info'].update({
                            'seq_class_ids': data['seq_class_ids'][trial],
                            'seq_len': data['seq_len'][trial],
                            'sentence_label': data['sentence_label'][trial],
                        })
                    cache[aug_type][session].append(trial_data)
                    pbar.update(1)

    # Save the cache
    with open(args.cache_file, 'wb') as f:
        pickle.dump(cache, f)
    print(f"✅ Cache saved to {args.cache_file}")

    return cache
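# The GRU/LSTM fusion in evaluate_config() below is a weighted geometric mean
# computed in log space:  p_ens ∝ p_gru**w_gru * p_lstm**(1 - w_gru).
# A minimal self-contained sketch of that operation (illustrative only; the
# search itself uses the inlined GPU/CPU version below):
def _weighted_geometric_mean(p_a, p_b, w, epsilon=1e-8):
    """Fuse two probability arrays via a renormalized weighted geometric mean."""
    log_p = w * np.log(p_a + epsilon) + (1.0 - w) * np.log(p_b + epsilon)
    p = np.exp(log_p)
    return p / p.sum(axis=-1, keepdims=True)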
def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
    """Evaluate one parameter configuration (GPU-accelerated when CuPy is available)."""
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8

    total_edit_distance = 0
    total_true_length = 0

    # Normalize the TTA weights over the enabled augmentations
    enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
    if len(enabled_augmentations) == 0:
        return float('inf')
    total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
    norm_tta_weights = {k: tta_weights[k] / total_tta_weight for k in enabled_augmentations}

    # Iterate over all trials
    for session in cache['original'].keys():
        for trial_idx in range(len(cache['original'][session])):
            trial_info = cache['original'][session][trial_idx]['trial_info']
            if eval_type == 'val' and 'seq_class_ids' not in trial_info:
                continue

            # Accumulate all enabled augmentations, on GPU where available
            weighted_gru_probs = None
            weighted_lstm_probs = None

            for aug_type in enabled_augmentations:
                weight = norm_tta_weights[aug_type]

                if GPU_AVAILABLE:
                    gru_probs = cp.asarray(cache[aug_type][session][trial_idx]['gru_probs'])
                    lstm_probs = cp.asarray(cache[aug_type][session][trial_idx]['lstm_probs'])
                else:
                    gru_probs = cache[aug_type][session][trial_idx]['gru_probs']
                    lstm_probs = cache[aug_type][session][trial_idx]['lstm_probs']

                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Augmentations such as 'shift' can change the sequence
                    # length, so truncate to the shortest before accumulating
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
                    weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]

            # GRU+LSTM ensemble: weighted geometric mean in log space
            weighted_gru_probs = weighted_gru_probs + epsilon
            weighted_lstm_probs = weighted_lstm_probs + epsilon

            if GPU_AVAILABLE:
                log_ensemble_probs = (gru_weight * cp.log(weighted_gru_probs) +
                                      lstm_weight * cp.log(weighted_lstm_probs))
                ensemble_probs = cp.exp(log_ensemble_probs)
                ensemble_probs = ensemble_probs / ensemble_probs.sum(axis=-1, keepdims=True)
                # Greedy decode on GPU, then move back to CPU for post-processing
                pred_seq = cp.argmax(ensemble_probs[0], axis=-1)
                pred_seq = to_cpu(pred_seq)
            else:
                log_ensemble_probs = (gru_weight * np.log(weighted_gru_probs) +
                                      lstm_weight * np.log(weighted_lstm_probs))
                ensemble_probs = np.exp(log_ensemble_probs)
                ensemble_probs = ensemble_probs / ensemble_probs.sum(axis=-1, keepdims=True)
                pred_seq = np.argmax(ensemble_probs[0], axis=-1)

            # CTC greedy post-processing on the CPU (plain Python lists):
            # collapse consecutive repeats first, then drop the blank (class 0).
            # Removing blanks first would wrongly merge identical phonemes that
            # are separated only by a blank.
            pred_seq = [int(p) for p in pred_seq]
            pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
            pred_seq = [p for p in pred_seq if p != 0]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            if eval_type == 'val':
                true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
                true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
                ed = editdistance.eval(true_phonemes, pred_phonemes)
                total_edit_distance += ed
                total_true_length += len(true_phonemes)

    if eval_type == 'val' and total_true_length > 0:
        return 100 * total_edit_distance / total_true_length
    return 0.0
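# PER above is the aggregate phoneme error rate: total edit distance divided by
# total reference length, in percent. A tiny worked example (hypothetical values):
#   ref = ['HH', 'AH', 'L', 'OW'], hyp = ['HH', 'L', 'OW']
#   editdistance.eval(ref, hyp) == 1  ->  PER = 100 * 1 / 4 = 25.0%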
def evaluate_config_wrapper(args_tuple):
    """Wrapper so a single tuple can be submitted to the thread pool."""
    cache, gru_weight, tta_weights, eval_type = args_tuple
    per = evaluate_config(cache, gru_weight, tta_weights, eval_type)
    return {
        'gru_weight': gru_weight,
        'lstm_weight': 1.0 - gru_weight,
        'tta_weights': tta_weights.copy(),
        'per': per
    }


def search_parameters(cache, args):
    """Search for the optimal parameters."""
    print("🔍 Starting parameter search...")

    # Parse the search space
    gru_weights = [float(x) for x in args.gru_weights.split(',')]
    noise_weights = [float(x) for x in args.tta_noise_weights.split(',')]
    scale_weights = [float(x) for x in args.tta_scale_weights.split(',')]
    shift_weights = [float(x) for x in args.tta_shift_weights.split(',')]
    smooth_weights = [float(x) for x in args.tta_smooth_weights.split(',')]

    # Generate every configuration combination
    configs = []
    for gru_w in gru_weights:
        for noise_w in noise_weights:
            for scale_w in scale_weights:
                for shift_w in shift_weights:
                    for smooth_w in smooth_weights:
                        tta_weights = {
                            'original': 1.0,  # always include the unaugmented data
                            'noise': noise_w,
                            'scale': scale_w,
                            'shift': shift_w,
                            'smooth': smooth_w
                        }
                        configs.append((cache, gru_w, tta_weights, args.eval_type))

    total_configs = len(configs)
    print(f"Search space: {total_configs} configurations")
    print(f"Processing in parallel with {args.max_workers} threads...")

    best_per = float('inf')
    best_config = None
    all_results = []
    completed_count = 0

    # Process in parallel with a thread pool
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit all tasks
        future_to_config = {
            executor.submit(evaluate_config_wrapper, config): config
            for config in configs
        }

        # Handle tasks as they complete
        for future in as_completed(future_to_config):
            try:
                result = future.result()
                all_results.append(result)
                completed_count += 1

                # Track the running best
                if result['per'] < best_per:
                    best_per = result['per']
                    best_config = result
                    tw = result['tta_weights']
                    print(f"🎯 New best [{completed_count}/{total_configs}]: PER={best_per:.3f}% | "
                          f"GRU={result['gru_weight']:.1f} | "
                          f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")

                # Periodic progress report
                if completed_count % args.batch_size == 0:
                    progress = 100 * completed_count / total_configs
                    print(f"📊 Progress: {completed_count}/{total_configs} ({progress:.1f}%) | current best PER: {best_per:.3f}%")
                elif completed_count % 50 == 0:
                    # Lighter-weight progress marker, printed more often
                    print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

            except Exception as e:
                completed_count += 1
                config = future_to_config[future]
                print(f"❌ Configuration failed: GRU weight {config[1]:.1f}, error: {e}")
                # Record the failure as well
                all_results.append({
                    'gru_weight': config[1],
                    'lstm_weight': 1.0 - config[1],
                    'tta_weights': config[2],
                    'per': float('inf'),
                    'error': str(e)
                })

    print(f"\n✅ All tasks finished!")

    # Re-derive the true best from all results (defensive against edge cases)
    valid_results = [r for r in all_results if 'error' not in r and r['per'] != float('inf')]
    if valid_results:
        best_config = min(valid_results, key=lambda x: x['per'])

    return all_results, best_config
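# Design note: threads (rather than processes) are used so every worker shares
# the in-memory prediction cache without pickling it per task. evaluate_config()
# spends most of its time in NumPy/CuPy array operations, which release the GIL,
# so threads still overlap usefully here; a CPU-bound pure-Python workload would
# need a ProcessPoolExecutor instead.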
def main():
    args = parse_arguments()

    print("🚀 Efficient TTA-E parameter search")
    print("=" * 50)

    # Phase 1: generate or load the cache
    if args.force_recache or not os.path.exists(args.cache_file):
        cache = generate_base_predictions(args)
    else:
        print(f"📁 Loading existing cache: {args.cache_file}")
        with open(args.cache_file, 'rb') as f:
            cache = pickle.load(f)
        print("✅ Cache loaded")

    # Phase 2: parameter search
    all_results, best_config = search_parameters(cache, args)

    # Save results
    results = {
        'best_config': best_config,
        'all_results': all_results,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(args.output_file, 'w') as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 50)
    print("🏆 Search complete!")
    if best_config is not None:
        print(f"Best configuration: PER={best_config['per']:.3f}%")
        print(f"GRU weight: {best_config['gru_weight']:.1f}")
        print(f"TTA weights: {best_config['tta_weights']}")
        print(f"Results saved to: {args.output_file}")

        # Show the top-10 configurations
        sorted_results = sorted([r for r in all_results if 'error' not in r and r['per'] != float('inf')],
                                key=lambda x: x['per'])[:10]
        print(f"\n📊 Top-10 configurations:")
        for i, result in enumerate(sorted_results):
            tw = result['tta_weights']
            print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
                  f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")
    else:
        print("❌ No valid configuration found! Every configuration failed.")
        print("Check the cached data and the evaluation logic.")

    # Search statistics
    valid_results = [r for r in all_results if 'error' not in r and r['per'] != float('inf')]
    print(f"\n📈 Search statistics:")
    print(f"  Total configurations: {len(all_results)}")
    print(f"  Successful: {len(valid_results)}")
    print(f"  Failed: {len(all_results) - len(valid_results)}")
    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f"  PER range: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f"  Mean PER: {sum(valid_pers)/len(valid_pers):.3f}%")


if __name__ == "__main__":
    main()
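# To inspect saved results afterwards (a sketch; note that failed configurations
# are serialized with "per": Infinity, which Python's json.load() parses back
# to float('inf')):
#
#   import json
#   with open('search_results.json') as f:
#       results = json.load(f)
#   print(results['best_config'])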