386 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			386 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| """
 | |
| 高效版TTA-E参数搜索
 | |
| 先缓存所有基础预测,然后快速搜索参数组合
 | |
| """
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| import torch
 | |
| import numpy as np
 | |
| import pandas as pd
 | |
| from omegaconf import OmegaConf
 | |
| import time
 | |
| from tqdm import tqdm
 | |
| import editdistance
 | |
| import argparse
 | |
| import itertools
 | |
| import json
 | |
| import pickle
 | |
| 
 | |
| # Add parent directories to path to import models
 | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
 | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
 | |
| 
 | |
| from model_training.rnn_model import GRUDecoder
 | |
| from model_training_lstm.rnn_model import LSTMDecoder
 | |
| from model_training.evaluate_model_helpers import *
 | |
| 
 | |
def parse_arguments():
    """Define and parse the command line of the TTA-E weight search.

    Returns:
        argparse.Namespace holding the model/data locations, the comma-separated
        candidate grids for the GRU weight and each TTA augmentation weight, the
        augmentation magnitudes, and the cache/output file controls.
    """
    cli = argparse.ArgumentParser(description='高效TTA-E参数搜索')
    add = cli.add_argument

    # Model checkpoints and dataset locations (all plain string paths).
    for flag, default_path in (
        ('--gru_model_path', '/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline'),
        ('--lstm_model_path', '/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn'),
        ('--data_dir', '../data/hdf5_data_final'),
        ('--csv_path', '../data/t15_copyTaskData_description.csv'),
    ):
        add(flag, type=str, default=default_path)
    add('--eval_type', type=str, default='val', choices=['val', 'test'])
    add('--gpu_number', type=int, default=0)

    # Search grids: comma-separated candidate values, parsed later by the search.
    add('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    tta_grid = '0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0'
    for aug_name in ('noise', 'scale', 'shift', 'smooth'):
        add(f'--tta_{aug_name}_weights', type=str, default=tta_grid)

    # Magnitudes of the random test-time augmentations.
    add('--tta_noise_std', type=float, default=0.01)
    add('--tta_smooth_range', type=float, default=0.5)
    add('--tta_scale_range', type=float, default=0.05)
    add('--tta_cut_max', type=int, default=3)

    # Prediction cache and result file.
    add('--cache_file', type=str, default='tta_cache.pkl')
    add('--force_recache', action='store_true')
    add('--output_file', type=str, default='search_results.json')

    return cli.parse_args()


def _build_decoder(decoder_cls, model_args):
    """Instantiate ``decoder_cls`` with the architecture recorded in a training ``args.yaml``."""
    model_cfg = model_args['model']
    dataset_cfg = model_args['dataset']
    return decoder_cls(
        neural_dim=model_cfg['n_input_features'],
        n_units=model_cfg['n_units'],
        n_days=len(dataset_cfg['sessions']),
        n_classes=dataset_cfg['n_classes'],
        rnn_dropout=model_cfg['rnn_dropout'],
        input_dropout=model_cfg['input_network']['input_layer_dropout'],
        n_layers=model_cfg['n_layers'],
        patch_size=model_cfg['patch_size'],
        patch_stride=model_cfg['patch_stride'],
    )


def _strip_wrapper_prefixes(state_dict):
    """Remove ``module.`` / ``_orig_mod.`` prefixes from state-dict keys, in place.

    These prefixes are typically added by (Distributed)DataParallel and torch.compile
    wrappers. Each key is popped exactly once, so keys carrying either or both
    prefixes are handled (the previous two-step rename popped a renamed key twice
    and raised KeyError whenever ``module.`` was present).
    """
    for key in list(state_dict.keys()):
        clean_key = key.replace("module.", "").replace("_orig_mod.", "")
        state_dict[clean_key] = state_dict.pop(key)


def _load_decoder(decoder_cls, model_path, device):
    """Build a decoder from ``<model_path>/checkpoint/args.yaml`` and load its best checkpoint.

    Returns:
        ``(model, model_args)`` with the model moved to ``device`` and in eval mode.
    """
    model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
    model = _build_decoder(decoder_cls, model_args)
    # weights_only=False unpickles arbitrary objects: only point this at trusted checkpoints.
    checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
    _strip_wrapper_prefixes(checkpoint['model_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, model_args


def _load_eval_split(args, sessions):
    """Load ``data_<eval_type>.hdf5`` for every session directory that contains it.

    Returns:
        ``(test_data, total_trials)`` where ``test_data`` maps session -> data dict
        returned by ``load_h5py_file``. Sessions without the file are skipped.
    """
    b2txt_csv_df = pd.read_csv(args.csv_path)
    eval_filename = f'data_{args.eval_type}.hdf5'
    test_data = {}
    total_trials = 0
    for session in sessions:
        session_dir = os.path.join(args.data_dir, session)
        if eval_filename not in os.listdir(session_dir):
            continue
        test_data[session] = load_h5py_file(os.path.join(session_dir, eval_filename), b2txt_csv_df)
        total_trials += len(test_data[session]["neural_features"])
    return test_data, total_trials


def _augment(x, aug_type, args):
    """Return the test-time-augmented version of ``x`` for ``aug_type``.

    ``'original'`` and ``'smooth'`` come back unchanged; the smoothing variation
    is applied by the caller via the Gaussian kernel width. ``'shift'`` rotates
    ``x`` along dim 1 (presumably time) by 1..tta_cut_max steps. It is skipped
    when ``tta_cut_max <= 0`` or the sequence is too short for a shift of at least 1.
    """
    if aug_type == 'noise':
        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
        return x + torch.randn_like(x) * noise_scale
    if aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
        return x * scale_factor
    if aug_type == 'shift' and args.tta_cut_max > 0:
        shift_upper = min(args.tta_cut_max + 1, x.shape[1] // 8)
        # np.random.randint(1, high) requires high > 1; short trials used to crash here.
        if shift_upper > 1:
            shift_amount = np.random.randint(1, shift_upper)
            # Equivalent to cat([x[:, k:], x[:, :k]], dim=1): rotate left by k steps.
            return torch.roll(x, shifts=-shift_amount, dims=1)
    return x


def generate_base_predictions(args):
    """Run both decoders on every eval trial under each base augmentation and cache the outputs.

    For each augmentation type ('original', 'noise', 'scale', 'shift', 'smooth') and
    every trial, the (augmented) input is Gaussian-smoothed and fed to the GRU and
    LSTM decoders. Their softmax outputs are stored, together with trial metadata
    and, for ``eval_type == 'val'``, the reference labels.

    Args:
        args: parsed CLI namespace (model paths, data paths, TTA magnitudes, cache path).

    Returns:
        ``cache[aug_type][session]`` -> list of per-trial dicts with keys
        ``'gru_probs'``, ``'lstm_probs'`` (float32 numpy arrays) and ``'trial_info'``.
        The same object is pickled to ``args.cache_file``.
    """
    print("🔄 生成基础预测缓存...")

    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')

    gru_model, gru_model_args = _load_decoder(GRUDecoder, args.gru_model_path, device)
    lstm_model, _ = _load_decoder(LSTMDecoder, args.lstm_model_path, device)

    sessions = gru_model_args['dataset']['sessions']
    test_data, total_trials = _load_eval_split(args, sessions)
    print(f"总试验数: {total_trials}")

    transforms = gru_model_args['dataset']['data_transforms']
    base_smooth_std = transforms['smooth_kernel_std']
    smooth_kernel_size = transforms['smooth_kernel_size']
    use_amp = gru_model_args['use_amp']

    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
    cache = {}

    for aug_type in augmentation_types:
        print(f"处理增强类型: {aug_type}")
        cache[aug_type] = {}

        with tqdm(total=total_trials, desc=f'{aug_type}', unit='trial') as pbar:
            for session, data in test_data.items():
                # NOTE(review): the GRU's session list also selects the LSTM's day layer;
                # this assumes both models were trained with the same session order — confirm.
                day_idx = torch.tensor([sessions.index(session)], device=device)
                session_trials = []
                cache[aug_type][session] = session_trials

                for trial in range(len(data['neural_features'])):
                    neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    x_aug = _augment(neural_input.clone(), aug_type, args)

                    if aug_type == 'smooth':
                        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
                        smooth_std = max(0.3, base_smooth_std + smooth_variation)
                    else:
                        smooth_std = base_smooth_std

                    # autocast must target the device actually in use (was hard-coded to "cuda").
                    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp,
                                                         dtype=torch.bfloat16):
                        x_smoothed = gauss_smooth(
                            inputs=x_aug, device=device,
                            smooth_kernel_std=smooth_std,
                            smooth_kernel_size=smooth_kernel_size,
                            padding='valid',
                        )
                        gru_logits, _ = gru_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)
                        lstm_logits, _ = lstm_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)

                        gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
                        lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()

                    trial_data = {
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                        }
                    }
                    if args.eval_type == 'val':
                        trial_data['trial_info'].update({
                            'seq_class_ids': data['seq_class_ids'][trial],
                            'seq_len': data['seq_len'][trial],
                            'sentence_label': data['sentence_label'][trial],
                        })

                    session_trials.append(trial_data)
                    pbar.update(1)

    with open(args.cache_file, 'wb') as f:
        pickle.dump(cache, f)

    print(f"✅ 缓存已保存到 {args.cache_file}")
    return cache


def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
    """Score one (GRU/LSTM weight, TTA weights) configuration on cached predictions.

    For every labelled trial, each model's softmax outputs are first averaged
    across the enabled augmentations using the normalised TTA weights. The two
    models are then fused by a weighted geometric mean (``gru_weight`` for the
    GRU, ``1 - gru_weight`` for the LSTM). The result is greedy-CTC decoded and
    compared with the reference phonemes by edit distance.

    Args:
        cache: mapping ``aug_type -> session -> list[trial dict]`` as built by
            ``generate_base_predictions``; must contain an ``'original'`` entry.
        gru_weight: weight of the GRU in the log-space fusion.
        tta_weights: mapping ``aug_type -> weight``; non-positive weights disable
            that augmentation, positive ones are renormalised to sum to 1.
        eval_type: ``'val'`` computes a PER; other splits have no labels.

    Returns:
        Phoneme error rate in percent. Special cases:
        - ``inf`` if no augmentation has a positive weight;
        - ``0.0`` when ``eval_type != 'val'``;
        - ``0.0`` when no cached trial carries labels.
    """
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8  # keeps log() finite where a probability is exactly zero

    enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
    if not enabled_augmentations:
        return float('inf')
    if eval_type != 'val':
        # No reference labels: nothing to score, so skip the decoding work entirely.
        return 0.0

    total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
    aug_weights = [tta_weights[k] / total_tta_weight for k in enabled_augmentations]

    total_edit_distance = 0
    total_true_length = 0

    for session, original_trials in cache['original'].items():
        for trial_idx, original_trial in enumerate(original_trials):
            trial_info = original_trial['trial_info']
            if 'seq_class_ids' not in trial_info:
                continue

            entries = [cache[aug][session][trial_idx] for aug in enabled_augmentations]
            gru_outputs = [np.asarray(e['gru_probs']) for e in entries]
            lstm_outputs = [np.asarray(e['lstm_probs']) for e in entries]
            # Truncate everything to the shortest time axis so outputs can be combined even
            # if augmentations or the two models produced different lengths.
            n_frames = min(p.shape[1] for p in gru_outputs + lstm_outputs)

            mixed_gru = sum(w * p[:, :n_frames, :] for w, p in zip(aug_weights, gru_outputs))
            mixed_lstm = sum(w * p[:, :n_frames, :] for w, p in zip(aug_weights, lstm_outputs))

            # Weighted geometric mean of the two models. Renormalising per frame would not
            # move the argmax, so greedy decoding works on the log scores directly.
            log_scores = (gru_weight * np.log(mixed_gru + epsilon)
                          + lstm_weight * np.log(mixed_lstm + epsilon))
            frame_ids = np.argmax(log_scores[0], axis=-1)

            # Greedy CTC collapse: merge consecutive repeats FIRST, then drop blanks (id 0).
            # Dropping blanks first would wrongly merge "A <blank> A" into a single A.
            keep = np.ones(frame_ids.shape[0], dtype=bool)
            keep[1:] = frame_ids[1:] != frame_ids[:-1]
            keep &= frame_ids != 0
            pred_seq = frame_ids[keep].tolist()
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

            total_edit_distance += editdistance.eval(true_phonemes, pred_phonemes)
            total_true_length += len(true_phonemes)

    if total_true_length > 0:
        return 100 * total_edit_distance / total_true_length
    return 0.0


def search_parameters(cache, args):
    """Exhaustively grid-search the GRU/LSTM and TTA weights over cached predictions.

    The candidate values come from the comma-separated strings ``args.gru_weights``
    and ``args.tta_{noise,scale,shift,smooth}_weights``. The un-augmented
    ('original') prediction is always included with weight 1.0.

    Returns:
        ``(all_results, best_config)``: one result dict per configuration in grid
        order, and the first configuration reaching the lowest PER (``None`` if no
        configuration scores below ``inf``).
    """
    print("🔍 开始参数搜索...")

    def to_grid(spec):
        return [float(token) for token in spec.split(',')]

    grids = [
        to_grid(spec)
        for spec in (args.gru_weights, args.tta_noise_weights, args.tta_scale_weights,
                     args.tta_shift_weights, args.tta_smooth_weights)
    ]

    total_configs = 1
    for grid in grids:
        total_configs *= len(grid)
    print(f"搜索空间: {total_configs} 个配置")

    best_per = float('inf')
    best_config = None
    all_results = []

    # itertools.product iterates in the same order as the equivalent nested loops.
    for config_count, combo in enumerate(itertools.product(*grids), start=1):
        gru_w, noise_w, scale_w, shift_w, smooth_w = combo
        tta_weights = {
            'original': 1.0,  # the un-augmented prediction is always part of the ensemble
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        }
        per = evaluate_config(cache, gru_w, tta_weights, args.eval_type)

        result = {
            'gru_weight': gru_w,
            'lstm_weight': 1.0 - gru_w,
            'tta_weights': tta_weights,
            'per': per
        }
        all_results.append(result)

        if per < best_per:
            best_per, best_config = per, result
            print(f"🎯 新最佳: PER={per:.3f}% | GRU={gru_w:.1f} | TTA=({noise_w},{scale_w},{shift_w},{smooth_w})")

        if config_count % 50 == 0:
            print(f"进度: {config_count}/{total_configs} ({100*config_count/total_configs:.1f}%)")

    return all_results, best_config


def _cache_has_labels(cache):
    """Return True if any cached 'original' trial carries reference labels (``seq_class_ids``)."""
    return any(
        'seq_class_ids' in trial['trial_info']
        for trials in cache.get('original', {}).values()
        for trial in trials
    )


def main():
    """Script entry point: build or load the prediction cache, grid-search the weights, report.

    The best config, every evaluated config, the CLI arguments and a timestamp are
    written as JSON to ``args.output_file``; the top-10 configurations are printed.
    """
    args = parse_arguments()

    print("🚀 高效TTA-E参数搜索")
    print("=" * 50)

    # Stage 1: generate or load the base-prediction cache.
    # NOTE: the cache does not record which models/TTA settings produced it; pass
    # --force_recache after changing any of them.
    if args.force_recache or not os.path.exists(args.cache_file):
        cache = generate_base_predictions(args)
    else:
        print(f"📁 加载现有缓存: {args.cache_file}")
        # pickle.load can execute arbitrary code: only load cache files this script produced.
        with open(args.cache_file, 'rb') as f:
            cache = pickle.load(f)
        print("✅ 缓存加载完成")
        # A cache built from an unlabelled split (e.g. --eval_type test) would make every
        # config score PER 0.0 under --eval_type val; refuse rather than report a bogus optimum.
        if args.eval_type == 'val' and not _cache_has_labels(cache):
            raise SystemExit(
                f"❌ 缓存 {args.cache_file} 不含验证集标签(可能由 --eval_type test 生成),"
                f"请使用 --force_recache 重新生成"
            )

    # Stage 2: parameter search.
    all_results, best_config = search_parameters(cache, args)

    results = {
        'best_config': best_config,
        'all_results': all_results,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 50)
    print("🏆 搜索完成!")
    if best_config is None:
        # Only reachable if no configuration scored below inf; avoid a TypeError below.
        print("⚠️ 没有找到有效配置")
        print(f"结果保存到: {args.output_file}")
        return
    print(f"最佳配置: PER={best_config['per']:.3f}%")
    print(f"GRU权重: {best_config['gru_weight']:.1f}")
    print(f"TTA权重: {best_config['tta_weights']}")
    print(f"结果保存到: {args.output_file}")

    # Show the ten configurations with the lowest PER.
    sorted_results = sorted(all_results, key=lambda x: x['per'])[:10]
    print(f"\n📊 Top-10配置:")
    for i, result in enumerate(sorted_results):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")


if __name__ == "__main__":
    # main() returns None, so a successful run exits with status 0.
    sys.exit(main())
