# Files
# b2txt25/TTA-E/efficient_search copy 2.py
# 2025-10-06 15:17:44 +08:00
#
# 386 lines
# 18 KiB
# Python

#!/usr/bin/env python3
"""
高效版TTA-E参数搜索
先缓存所有基础预测,然后快速搜索参数组合
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
def parse_arguments(argv=None):
    """Parse command-line options for the cached TTA-E parameter search.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace with model/data paths, the comma-separated search
        grids (parsed later by ``search_parameters``), the TTA augmentation
        magnitudes and the cache/output file locations.
    """
    parser = argparse.ArgumentParser(description='高效TTA-E参数搜索')

    # Model and data paths
    parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline')
    parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'])
    parser.add_argument('--gpu_number', type=int, default=0)

    # Search space: comma-separated candidate weights for each dimension
    parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')

    # TTA augmentation magnitudes (used only when the cache is generated)
    parser.add_argument('--tta_noise_std', type=float, default=0.01)
    parser.add_argument('--tta_smooth_range', type=float, default=0.5)
    parser.add_argument('--tta_scale_range', type=float, default=0.05)
    parser.add_argument('--tta_cut_max', type=int, default=3)

    # Cache / output control
    parser.add_argument('--cache_file', type=str, default='tta_cache.pkl')
    parser.add_argument('--force_recache', action='store_true')
    parser.add_argument('--output_file', type=str, default='search_results.json')
    return parser.parse_args(argv)
def _clean_state_dict_keys(state_dict):
    """Strip ``module.`` and ``_orig_mod.`` prefixes from checkpoint keys, in place.

    These prefixes are typically added by DataParallel/DDP wrappers
    (``module.``) and ``torch.compile`` (``_orig_mod.``). The final name is
    computed once and each key is popped exactly once. The former two-step
    rename popped the original key twice, so it raised ``KeyError`` for any
    key that contained ``module.``.
    """
    for key in list(state_dict.keys()):
        new_key = key.replace("module.", "").replace("_orig_mod.", "")
        state_dict[new_key] = state_dict.pop(key)


def _build_decoder(model_cls, model_args, model_path, device):
    """Build ``model_cls`` from its training config, load ``checkpoint/best_checkpoint``, return it in eval mode on ``device``."""
    model = model_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
    _clean_state_dict_keys(checkpoint['model_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model


def _load_eval_data(args, sessions):
    """Load ``data_<eval_type>.hdf5`` for every session directory that contains it.

    Returns:
        ``(test_data, total_trials)``: ``test_data`` maps session name to
        whatever ``load_h5py_file`` returns, and ``total_trials`` is the sum of
        ``len(data["neural_features"])`` over the loaded sessions.
    """
    b2txt_csv_df = pd.read_csv(args.csv_path)
    eval_name = f'data_{args.eval_type}.hdf5'
    test_data = {}
    total_trials = 0
    for session in sessions:
        session_dir = os.path.join(args.data_dir, session)
        # Sessions without this split are skipped; a missing directory still raises.
        if eval_name in os.listdir(session_dir):
            data = load_h5py_file(os.path.join(session_dir, eval_name), b2txt_csv_df)
            test_data[session] = data
            total_trials += len(data["neural_features"])
    return test_data, total_trials


def _augment(x, aug_type, args, base_smooth_std):
    """Return ``(x_aug, smooth_std)`` for one randomly parameterised TTA view of ``x``.

    ``x`` is a 3-D tensor with batch size 1 in front (presumably
    ``(1, time, features)``; ``shift`` rolls along dim 1). ``smooth_std`` is
    the Gaussian kernel std to use afterwards. It differs from
    ``base_smooth_std`` only for ``aug_type == 'smooth'``.
    """
    x_aug = x.clone()
    smooth_std = base_smooth_std
    if aug_type == 'noise':
        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
        x_aug = x_aug + torch.randn_like(x_aug) * noise_scale
    elif aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
        x_aug = x_aug * scale_factor
    elif aug_type == 'shift' and args.tta_cut_max > 0:
        high = min(args.tta_cut_max + 1, x_aug.shape[1] // 8)
        # Trials shorter than 16 steps leave an empty range (randint(1, <=1) raises), so they stay unshifted.
        if high > 1:
            shift_amount = np.random.randint(1, high)
            x_aug = torch.cat([x_aug[:, shift_amount:, :], x_aug[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
        smooth_std = max(0.3, base_smooth_std + smooth_variation)
    return x_aug, smooth_std


def generate_base_predictions(args):
    """Run both decoders on every eval trial under each TTA variant and cache the probabilities.

    For each augmentation type (``original``, ``noise``, ``scale``, ``shift``,
    ``smooth``), one randomly parameterised view of every trial is built. It
    is Gaussian-smoothed using the GRU config's ``data_transforms`` and then
    fed to the GRU and LSTM decoders. Their softmax outputs are stored so that
    ``evaluate_config`` can re-weight them without rerunning the models.

    Args:
        args: Parsed CLI namespace (see ``parse_arguments``). Reads the model
            and data paths, ``eval_type``, ``gpu_number``, the ``tta_*``
            magnitudes and ``cache_file``.

    Returns:
        dict: ``cache[aug_type][session]`` is a list with one entry per trial.
        Each entry is ``{'gru_probs', 'lstm_probs', 'trial_info'}``, where the
        probabilities are float32 NumPy arrays with a leading batch dim of 1.
        ``trial_info`` holds ``session``, ``block_num`` and ``trial_num``. For
        ``eval_type == 'val'`` it also holds ``seq_class_ids``, ``seq_len``
        and ``sentence_label``.

    Side effects:
        Pickles the cache to ``args.cache_file``. The augmentations draw from
        the unseeded NumPy/torch global RNGs, so regenerated caches differ
        between runs.
    """
    print("🔄 生成基础预测缓存...")

    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')

    gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
    lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))
    gru_model = _build_decoder(GRUDecoder, gru_model_args, args.gru_model_path, device)
    lstm_model = _build_decoder(LSTMDecoder, lstm_model_args, args.lstm_model_path, device)

    sessions = list(gru_model_args['dataset']['sessions'])
    test_data, total_trials = _load_eval_data(args, sessions)
    print(f"总试验数: {total_trials}")

    transforms = gru_model_args['dataset']['data_transforms']
    base_smooth_std = transforms['smooth_kernel_std']
    smooth_kernel_size = transforms['smooth_kernel_size']
    use_amp = gru_model_args['use_amp']

    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
    cache = {}
    for aug_type in augmentation_types:
        print(f"处理增强类型: {aug_type}")
        cache[aug_type] = {}
        with tqdm(total=total_trials, desc=aug_type, unit='trial') as pbar:
            for session, data in test_data.items():
                # NOTE(review): the LSTM also gets the GRU config's session index;
                # this assumes both models were trained on the same session list.
                day_idx = torch.tensor([sessions.index(session)], device=device)
                session_trials = []
                cache[aug_type][session] = session_trials
                for trial in range(len(data['neural_features'])):
                    neural_input = torch.tensor(
                        np.expand_dims(data['neural_features'][trial], axis=0),
                        device=device, dtype=torch.bfloat16,
                    )
                    x_aug, smooth_std = _augment(neural_input, aug_type, args, base_smooth_std)

                    # Autocast on the device actually in use (was hard-coded to "cuda",
                    # which torch disables with a warning on CPU-only machines).
                    with torch.autocast(device_type=device.type, enabled=use_amp, dtype=torch.bfloat16), \
                            torch.no_grad():
                        x_smoothed = gauss_smooth(
                            inputs=x_aug, device=device,
                            smooth_kernel_std=smooth_std,
                            smooth_kernel_size=smooth_kernel_size,
                            padding='valid',
                        )
                        gru_logits, _ = gru_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)
                        lstm_logits, _ = lstm_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)
                        gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
                        lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()

                    # Store the predictions together with the trial's identifying/label info.
                    trial_info = {
                        'session': session,
                        'block_num': data['block_num'][trial],
                        'trial_num': data['trial_num'][trial],
                    }
                    if args.eval_type == 'val':
                        trial_info.update({
                            'seq_class_ids': data['seq_class_ids'][trial],
                            'seq_len': data['seq_len'][trial],
                            'sentence_label': data['sentence_label'][trial],
                        })
                    session_trials.append({
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': trial_info,
                    })
                    pbar.update(1)

    with open(args.cache_file, 'wb') as f:
        pickle.dump(cache, f)
    print(f"✅ 缓存已保存到 {args.cache_file}")
    return cache
def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
    """Score one ensemble configuration against the cached predictions.

    For every labelled trial:

    1. The per-augmentation probabilities are averaged using the normalised
       positive ``tta_weights``, separately for the GRU and the LSTM.
    2. The two models are combined as a weighted geometric mean
       (``gru_weight`` vs ``1 - gru_weight`` in log space).
    3. The result is greedily CTC-decoded and compared with the label
       sequence by edit distance.

    Args:
        cache: Output of ``generate_base_predictions``, i.e.
            ``cache[aug_type][session][trial_idx]`` -> dict with
            ``gru_probs``, ``lstm_probs`` and ``trial_info``.
        gru_weight: Weight of the GRU log-probabilities. The LSTM gets
            ``1 - gru_weight``.
        tta_weights: Mapping of augmentation name to weight. Non-positive
            weights disable that augmentation.
        eval_type: ``'val'`` scores against labels. Any other value has no
            labels to score.

    Returns:
        For ``'val'``, the phoneme error rate in percent (0.0 when no labelled
        trial was found). ``float('inf')`` when no augmentation has a positive
        weight. 0.0 for any other ``eval_type``.
    """
    enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
    if not enabled_augmentations:
        return float('inf')
    if eval_type != 'val':
        # Unlabelled data always scored 0.0; skip the ensembling/decoding whose
        # result was previously computed and then thrown away.
        return 0.0

    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8
    total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
    norm_tta_weights = {k: tta_weights[k] / total_tta_weight for k in enabled_augmentations}

    total_edit_distance = 0
    total_true_length = 0
    for session, original_trials in cache['original'].items():
        for trial_idx, original_trial in enumerate(original_trials):
            trial_info = original_trial['trial_info']
            if 'seq_class_ids' not in trial_info:
                continue

            # Weighted TTA average for each model.
            weighted_gru_probs = None
            weighted_lstm_probs = None
            for aug_type in enabled_augmentations:
                weight = norm_tta_weights[aug_type]
                aug_trial = cache[aug_type][session][trial_idx]
                # from_numpy shares memory instead of copying; the ops below never modify in place.
                gru_probs = torch.from_numpy(aug_trial['gru_probs'])
                lstm_probs = torch.from_numpy(aug_trial['lstm_probs'])
                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Augmented views may differ in length; truncate to the shortest.
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
                    weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]

            # GRU+LSTM log-linear ensemble (epsilon avoids log(0)).
            log_ensemble_probs = (gru_weight * torch.log(weighted_gru_probs + epsilon) +
                                  lstm_weight * torch.log(weighted_lstm_probs + epsilon))
            ensemble_probs = torch.exp(log_ensemble_probs)
            ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

            # Greedy CTC decode: collapse repeats FIRST, then drop blanks (id 0).
            # Dropping blanks first (previous order) merged genuine repeats such as [p, 0, p] -> [p].
            best_path = torch.argmax(ensemble_probs[0], dim=-1).tolist()
            pred_seq = [p for i, p in enumerate(best_path)
                        if p != 0 and (i == 0 or p != best_path[i - 1])]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
            total_edit_distance += editdistance.eval(true_phonemes, pred_phonemes)
            total_true_length += len(true_phonemes)

    if total_true_length > 0:
        return 100 * total_edit_distance / total_true_length
    return 0.0
def _parse_weight_list(spec):
    """Parse a comma-separated list of floats, ignoring whitespace-only/empty entries (e.g. a trailing comma)."""
    return [float(item) for item in spec.split(',') if item.strip()]


def search_parameters(cache, args):
    """Exhaustively grid-search the GRU/LSTM and TTA ensemble weights.

    Every combination of ``args.gru_weights`` x ``args.tta_noise_weights`` x
    ``args.tta_scale_weights`` x ``args.tta_shift_weights`` x
    ``args.tta_smooth_weights`` is scored with ``evaluate_config``. The
    ``original`` (un-augmented) view is always included with weight 1.0, so
    the other TTA weights are relative to it.

    Args:
        cache: Output of ``generate_base_predictions``.
        args: Parsed CLI namespace providing the grids and ``eval_type``.

    Returns:
        ``(all_results, best_config)``. ``all_results`` holds one dict per
        configuration (``gru_weight``, ``lstm_weight``, ``tta_weights``,
        ``per``) in grid order. ``best_config`` is the first configuration
        with the lowest PER, or ``None`` if no configuration scored below
        ``inf`` (e.g. an empty grid).
    """
    print("🔍 开始参数搜索...")

    gru_weights = _parse_weight_list(args.gru_weights)
    noise_weights = _parse_weight_list(args.tta_noise_weights)
    scale_weights = _parse_weight_list(args.tta_scale_weights)
    shift_weights = _parse_weight_list(args.tta_shift_weights)
    smooth_weights = _parse_weight_list(args.tta_smooth_weights)

    total_configs = len(gru_weights) * len(noise_weights) * len(scale_weights) * len(shift_weights) * len(smooth_weights)
    print(f"搜索空间: {total_configs} 个配置")

    best_per = float('inf')
    best_config = None
    all_results = []
    # itertools.product iterates in the same order as the equivalent nested loops.
    grid = itertools.product(gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights)
    for config_count, (gru_w, noise_w, scale_w, shift_w, smooth_w) in enumerate(grid, start=1):
        tta_weights = {
            'original': 1.0,  # always include the un-augmented view
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        }
        per = evaluate_config(cache, gru_w, tta_weights, args.eval_type)
        result = {
            'gru_weight': gru_w,
            'lstm_weight': 1.0 - gru_w,
            'tta_weights': tta_weights,
            'per': per
        }
        all_results.append(result)
        if per < best_per:
            best_per = per
            best_config = result
            print(f"🎯 新最佳: PER={per:.3f}% | GRU={gru_w:.1f} | TTA=({noise_w},{scale_w},{shift_w},{smooth_w})")
        if config_count % 50 == 0:
            print(f"进度: {config_count}/{total_configs} ({100*config_count/total_configs:.1f}%)")
    return all_results, best_config
def main():
    """Script entry point: build or load the prediction cache, grid-search the weights, then save and print the results."""
    args = parse_arguments()
    print("🚀 高效TTA-E参数搜索")
    print("=" * 50)

    # Stage 1: generate the cache, or reuse the one on disk.
    if args.force_recache or not os.path.exists(args.cache_file):
        cache = generate_base_predictions(args)
    else:
        print(f"📁 加载现有缓存: {args.cache_file}")
        # NOTE: pickle.load can execute arbitrary code; only point --cache_file
        # at caches produced by this script.
        with open(args.cache_file, 'rb') as f:
            cache = pickle.load(f)
        print("✅ 缓存加载完成")

    # Stage 2: parameter search.
    all_results, best_config = search_parameters(cache, args)

    results = {
        'best_config': best_config,
        'all_results': all_results,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print("\n" + "=" * 50)
    print("🏆 搜索完成!")
    if best_config is None:
        # Empty search grid: indexing best_config below would raise TypeError.
        print("⚠️ 未找到有效配置")
        print(f"结果保存到: {args.output_file}")
        return
    print(f"最佳配置: PER={best_config['per']:.3f}%")
    print(f"GRU权重: {best_config['gru_weight']:.1f}")
    print(f"TTA权重: {best_config['tta_weights']}")
    print(f"结果保存到: {args.output_file}")

    # Show the ten lowest-PER configurations.
    sorted_results = sorted(all_results, key=lambda x: x['per'])[:10]
    print("\n📊 Top-10配置:")
    for i, result in enumerate(sorted_results):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")


if __name__ == "__main__":
    main()