#!/usr/bin/env python3
"""
高效版TTA-E参数搜索
先缓存所有基础预测然后快速搜索参数组合
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed
# Optional GPU acceleration via CuPy; falls back to NumPy below
try:
import cupy as cp
GPU_AVAILABLE = True
print("✅ CuPy available - GPU acceleration enabled")
    # Helper: move an array back to host memory (cp.asnumpy is a no-op for NumPy arrays)
    def to_cpu(arr):
        return cp.asnumpy(arr)
except ImportError:
print("⚠️ CuPy not available - falling back to CPU numpy")
import numpy as cp
GPU_AVAILABLE = False
    # Helper: identity, since everything already lives on the CPU
    def to_cpu(arr):
        return arr
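# Note: after this try/except, `cp` is bound to either CuPy or NumPy, and
# GPU_AVAILABLE tells the evaluation code below which backend it got.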
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
from model_training.evaluate_model_helpers import LOGIT_TO_PHONEME
def parse_arguments():
    parser = argparse.ArgumentParser(description='Efficient TTA-E parameter search')
    # Model and data paths
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'])
parser.add_argument('--gpu_number', type=int, default=0)
    # Search space (comma-separated candidate weights)
parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0')
    # TTA augmentation parameters
parser.add_argument('--tta_noise_std', type=float, default=0.01)
parser.add_argument('--tta_smooth_range', type=float, default=0.5)
parser.add_argument('--tta_scale_range', type=float, default=0.05)
parser.add_argument('--tta_cut_max', type=int, default=3)
    # Cache control
parser.add_argument('--cache_file', type=str, default='tta_cache.pkl')
parser.add_argument('--force_recache', action='store_true')
parser.add_argument('--output_file', type=str, default='search_results.json')
    # Parallelism
    parser.add_argument('--max_workers', type=int, default=25, help='maximum number of parallel worker threads')
    parser.add_argument('--batch_size', type=int, default=100, help='number of completed configs between detailed progress reports')
return parser.parse_args()
def generate_base_predictions(args):
"""生成所有基础增强的预测结果"""
print("🔄 生成基础预测缓存...")
# 设置设备
if torch.cuda.is_available() and args.gpu_number >= 0:
device = torch.device(f'cuda:{args.gpu_number}')
else:
device = torch.device('cpu')
    # Load model configs and instantiate both decoders
gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))
gru_model = GRUDecoder(
neural_dim=gru_model_args['model']['n_input_features'],
n_units=gru_model_args['model']['n_units'],
n_days=len(gru_model_args['dataset']['sessions']),
n_classes=gru_model_args['dataset']['n_classes'],
rnn_dropout=gru_model_args['model']['rnn_dropout'],
input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
n_layers=gru_model_args['model']['n_layers'],
patch_size=gru_model_args['model']['patch_size'],
patch_stride=gru_model_args['model']['patch_stride'],
)
lstm_model = LSTMDecoder(
neural_dim=lstm_model_args['model']['n_input_features'],
n_units=lstm_model_args['model']['n_units'],
n_days=len(lstm_model_args['dataset']['sessions']),
n_classes=lstm_model_args['dataset']['n_classes'],
rnn_dropout=lstm_model_args['model']['rnn_dropout'],
input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
n_layers=lstm_model_args['model']['n_layers'],
patch_size=lstm_model_args['model']['patch_size'],
patch_stride=lstm_model_args['model']['patch_stride'],
)
    # Load the checkpoint weights
gru_checkpoint = torch.load(os.path.join(args.gru_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
lstm_checkpoint = torch.load(os.path.join(args.lstm_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
    # Strip wrapper prefixes left by DataParallel ("module.") and torch.compile ("_orig_mod.").
    # Chaining both replacements on one pop avoids the KeyError the old double-pop
    # caused once a "module." prefix was actually stripped.
    for checkpoint in [gru_checkpoint, lstm_checkpoint]:
        for key in list(checkpoint['model_state_dict'].keys()):
            new_key = key.replace("module.", "").replace("_orig_mod.", "")
            checkpoint['model_state_dict'][new_key] = checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])
gru_model.to(device)
lstm_model.to(device)
gru_model.eval()
lstm_model.eval()
    # Load the evaluation data
b2txt_csv_df = pd.read_csv(args.csv_path)
test_data = {}
total_trials = 0
for session in gru_model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(args.data_dir, session)) if f.endswith('.hdf5')]
if f'data_{args.eval_type}.hdf5' in files:
eval_file = os.path.join(args.data_dir, session, f'data_{args.eval_type}.hdf5')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_trials += len(test_data[session]["neural_features"])
print(f"总试验数: {total_trials}")
# 生成所有基础增强类型的预测
augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
cache = {}
for aug_type in augmentation_types:
print(f"处理增强类型: {aug_type}")
cache[aug_type] = {}
with tqdm(total=total_trials, desc=f'{aug_type}', unit='trial') as pbar:
for session, data in test_data.items():
input_layer = gru_model_args['dataset']['sessions'].index(session)
cache[aug_type][session] = []
for trial in range(len(data['neural_features'])):
neural_input = data['neural_features'][trial]
neural_input = np.expand_dims(neural_input, axis=0)
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
                    # Apply the augmentation
                    x_aug = neural_input.clone()
                    if aug_type == 'noise':
                        noise_scale = args.tta_noise_std * (0.5 + 0.5 * np.random.rand())
                        x_aug = x_aug + torch.randn_like(x_aug) * noise_scale
                    elif aug_type == 'scale':
                        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * args.tta_scale_range
                        x_aug = x_aug * scale_factor
                    elif aug_type == 'shift' and args.tta_cut_max > 0:
                        # Circular shift along time, capped at 1/8 of the trial length;
                        # the guard avoids randint(1, 1) crashing on very short trials
                        max_shift = min(args.tta_cut_max, x_aug.shape[1] // 8 - 1)
                        if max_shift >= 1:
                            shift_amount = np.random.randint(1, max_shift + 1)
                            x_aug = torch.cat([x_aug[:, shift_amount:, :], x_aug[:, :shift_amount, :]], dim=1)
                    elif aug_type == 'smooth':
                        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * args.tta_smooth_range
                        varied_smooth_std = max(0.3, gru_model_args['dataset']['data_transforms']['smooth_kernel_std'] + smooth_variation)
                    # Gaussian smoothing: the 'smooth' augmentation uses the varied kernel std,
                    # every other type uses the training default
                    smooth_std = (varied_smooth_std if aug_type == 'smooth'
                                  else gru_model_args['dataset']['data_transforms']['smooth_kernel_std'])
                    with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
                        x_smoothed = gauss_smooth(
                            inputs=x_aug, device=device,
                            smooth_kernel_std=smooth_std,
                            smooth_kernel_size=gru_model_args['dataset']['data_transforms']['smooth_kernel_size'],
                            padding='valid',
                        )
with torch.no_grad():
gru_logits, _ = gru_model(x=x_smoothed, day_idx=torch.tensor([input_layer], device=device), states=None, return_state=True)
lstm_logits, _ = lstm_model(x=x_smoothed, day_idx=torch.tensor([input_layer], device=device), states=None, return_state=True)
gru_probs = torch.softmax(gru_logits, dim=-1).float().cpu().numpy()
lstm_probs = torch.softmax(lstm_logits, dim=-1).float().cpu().numpy()
                    # Store predictions plus the trial's ground-truth metadata
trial_data = {
'gru_probs': gru_probs,
'lstm_probs': lstm_probs,
'trial_info': {
'session': session,
'block_num': data['block_num'][trial],
'trial_num': data['trial_num'][trial],
}
}
if args.eval_type == 'val':
trial_data['trial_info'].update({
'seq_class_ids': data['seq_class_ids'][trial],
'seq_len': data['seq_len'][trial],
'sentence_label': data['sentence_label'][trial],
})
cache[aug_type][session].append(trial_data)
pbar.update(1)
    # Persist the cache to disk
with open(args.cache_file, 'wb') as f:
pickle.dump(cache, f)
print(f"✅ 缓存已保存到 {args.cache_file}")
return cache
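# Cache layout (written above, consumed by evaluate_config below):
#   cache[aug_type][session] -> list of per-trial dicts:
#     {'gru_probs':  softmax probs, shape (1, T, n_classes),
#      'lstm_probs': softmax probs, shape (1, T, n_classes),
#      'trial_info': session/block/trial ids, plus labels when eval_type == 'val'}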
def evaluate_config(cache, gru_weight, tta_weights, eval_type='val'):
"""评估特定参数配置 - GPU加速版本"""
lstm_weight = 1.0 - gru_weight
epsilon = 1e-8
total_edit_distance = 0
total_true_length = 0
    # Normalize the TTA weights over the enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
if len(enabled_augmentations) == 0:
return float('inf')
total_tta_weight = sum(tta_weights[k] for k in enabled_augmentations)
norm_tta_weights = {k: tta_weights[k]/total_tta_weight for k in enabled_augmentations}
    # Iterate over every trial
for session in cache['original'].keys():
for trial_idx in range(len(cache['original'][session])):
trial_info = cache['original'][session][trial_idx]['trial_info']
if eval_type == 'val' and 'seq_class_ids' not in trial_info:
continue
            # Ensemble all enabled augmentations (GPU-accelerated when available)
weighted_gru_probs = None
weighted_lstm_probs = None
for aug_type in enabled_augmentations:
weight = norm_tta_weights[aug_type]
                # Run the arithmetic on GPU arrays when possible
if GPU_AVAILABLE:
gru_probs = cp.asarray(cache[aug_type][session][trial_idx]['gru_probs'])
lstm_probs = cp.asarray(cache[aug_type][session][trial_idx]['lstm_probs'])
else:
gru_probs = cache[aug_type][session][trial_idx]['gru_probs']
lstm_probs = cache[aug_type][session][trial_idx]['lstm_probs']
if weighted_gru_probs is None:
weighted_gru_probs = weight * gru_probs
weighted_lstm_probs = weight * lstm_probs
else:
                    # Handle differing sequence lengths by truncating to the shorter one
min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]
            # GRU+LSTM ensemble: weighted geometric mean, computed in log space
weighted_gru_probs = weighted_gru_probs + epsilon
weighted_lstm_probs = weighted_lstm_probs + epsilon
if GPU_AVAILABLE:
log_ensemble_probs = (gru_weight * cp.log(weighted_gru_probs) +
lstm_weight * cp.log(weighted_lstm_probs))
ensemble_probs = cp.exp(log_ensemble_probs)
ensemble_probs = ensemble_probs / ensemble_probs.sum(axis=-1, keepdims=True)
                # Decode: argmax on the GPU
                pred_seq = cp.argmax(ensemble_probs[0], axis=-1)
                pred_seq = to_cpu(pred_seq)  # back to CPU for the list-based post-processing
else:
log_ensemble_probs = (gru_weight * np.log(weighted_gru_probs) +
lstm_weight * np.log(weighted_lstm_probs))
ensemble_probs = np.exp(log_ensemble_probs)
ensemble_probs = ensemble_probs / ensemble_probs.sum(axis=-1, keepdims=True)
                # Decode
pred_seq = np.argmax(ensemble_probs[0], axis=-1)
            # Post-processing stays on CPU (Python list ops). CTC-style greedy decode:
            # collapse consecutive repeats first, then drop the blank token (id 0) --
            # removing blanks first would wrongly merge repeated phonemes across a blank.
            pred_seq = [int(p) for p in pred_seq]
            pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
            pred_seq = [p for p in pred_seq if p != 0]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]
if eval_type == 'val':
true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
ed = editdistance.eval(true_phonemes, pred_phonemes)
total_edit_distance += ed
total_true_length += len(true_phonemes)
if eval_type == 'val' and total_true_length > 0:
return 100 * total_edit_distance / total_true_length
return 0.0
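# Tiny worked example of the log-space ensemble above (illustrative numbers,
# not from the source): with gru_weight = lstm_weight = 0.5, gru p = [0.8, 0.2]
# and lstm p = [0.6, 0.4],
#   exp(0.5*log(p_gru) + 0.5*log(p_lstm)) = [sqrt(0.48), sqrt(0.08)] ~= [0.693, 0.283],
# which renormalizes to ~[0.710, 0.290]: a weighted geometric mean that
# down-weights any class either model considers unlikely.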
def evaluate_config_wrapper(args_tuple):
"""包装函数用于多线程处理"""
cache, gru_weight, tta_weights, eval_type = args_tuple
per = evaluate_config(cache, gru_weight, tta_weights, eval_type)
return {
'gru_weight': gru_weight,
'lstm_weight': 1.0 - gru_weight,
'tta_weights': tta_weights.copy(),
'per': per
}
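# Note: threads (rather than processes) let every worker share the cache without
# re-pickling it per worker; the heavy NumPy/CuPy array math releases the GIL,
# so the evaluations still overlap in practice.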
def search_parameters(cache, args):
"""搜索最优参数"""
print("🔍 开始参数搜索...")
# 解析搜索空间
gru_weights = [float(x) for x in args.gru_weights.split(',')]
noise_weights = [float(x) for x in args.tta_noise_weights.split(',')]
scale_weights = [float(x) for x in args.tta_scale_weights.split(',')]
shift_weights = [float(x) for x in args.tta_shift_weights.split(',')]
smooth_weights = [float(x) for x in args.tta_smooth_weights.split(',')]
    # Generate every configuration combination
configs = []
for gru_w in gru_weights:
for noise_w in noise_weights:
for scale_w in scale_weights:
for shift_w in shift_weights:
for smooth_w in smooth_weights:
tta_weights = {
                            'original': 1.0,  # always include the un-augmented data
'noise': noise_w,
'scale': scale_w,
'shift': shift_w,
'smooth': smooth_w
}
configs.append((cache, gru_w, tta_weights, args.eval_type))
total_configs = len(configs)
print(f"搜索空间: {total_configs} 个配置")
print(f"使用 {args.max_workers} 个线程并行处理...")
best_per = float('inf')
best_config = None
all_results = []
completed_count = 0
    # Evaluate configurations in a thread pool
with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit every task up front
future_to_config = {
executor.submit(evaluate_config_wrapper, config): config
for config in configs
}
        # Consume results as tasks complete
for future in as_completed(future_to_config):
try:
result = future.result()
all_results.append(result)
completed_count += 1
                # Track the running best
if result['per'] < best_per:
best_per = result['per']
best_config = result
tw = result['tta_weights']
print(f"🎯 新最佳[{completed_count}/{total_configs}]: PER={best_per:.3f}% | "
f"GRU={result['gru_weight']:.1f} | "
f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")
                # Periodic detailed progress report
if completed_count % args.batch_size == 0:
progress = 100 * completed_count / total_configs
print(f"📊 进度: {completed_count}/{total_configs} ({progress:.1f}%) | 当前最优PER: {best_per:.3f}%")
elif completed_count % 50 == 0: # 更频繁的简单进度
print(f"... {completed_count}/{total_configs} ...", end='', flush=True)
except Exception as e:
completed_count += 1
config = future_to_config[future]
print(f"❌ 配置失败: {config[1]:.1f}, 错误: {e}")
# 添加失败的结果记录
all_results.append({
'gru_weight': config[1],
'lstm_weight': 1.0 - config[1],
'tta_weights': config[2],
'per': float('inf'),
'error': str(e)
})
print(f"\n✅ 所有任务完成!")
# 找到真正的最佳结果(防止异常情况)
valid_results = [r for r in all_results if 'error' not in r and r['per'] != float('inf')]
if valid_results:
best_config = min(valid_results, key=lambda x: x['per'])
return all_results, best_config
def main():
args = parse_arguments()
print("🚀 高效TTA-E参数搜索")
print("=" * 50)
    # Stage 1: generate or load the prediction cache
if args.force_recache or not os.path.exists(args.cache_file):
cache = generate_base_predictions(args)
else:
print(f"📁 加载现有缓存: {args.cache_file}")
with open(args.cache_file, 'rb') as f:
cache = pickle.load(f)
print("✅ 缓存加载完成")
# 第二阶段:参数搜索
all_results, best_config = search_parameters(cache, args)
    # Save the results
results = {
'best_config': best_config,
'all_results': all_results,
'args': vars(args),
'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
}
with open(args.output_file, 'w') as f:
json.dump(results, f, indent=2)
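    # Note: failed configs carry per=float('inf'), which json.dump serializes as
    # the non-standard token Infinity; strict JSON parsers may reject that file.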
print("\n" + "=" * 50)
print("🏆 搜索完成!")
if best_config is not None:
print(f"最佳配置: PER={best_config['per']:.3f}%")
print(f"GRU权重: {best_config['gru_weight']:.1f}")
print(f"TTA权重: {best_config['tta_weights']}")
print(f"结果保存到: {args.output_file}")
# 显示top-10
sorted_results = sorted([r for r in all_results if 'error' not in r and r['per'] != float('inf')],
key=lambda x: x['per'])[:10]
print(f"\n📊 Top-10配置:")
for i, result in enumerate(sorted_results):
tw = result['tta_weights']
print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")
else:
print("❌ 未找到有效的配置结果!所有配置都失败了。")
print("请检查缓存数据和评估逻辑。")
# 搜索统计
valid_results = [r for r in all_results if 'error' not in r and r['per'] != float('inf')]
print(f"\n📈 搜索统计:")
print(f" 总配置数: {len(all_results)}")
print(f" 成功配置数: {len(valid_results)}")
print(f" 失败配置数: {len(all_results) - len(valid_results)}")
if valid_results:
valid_pers = [r['per'] for r in valid_results]
print(f" PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
print(f" 平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")
if __name__ == "__main__":
main()