Files
b2txt25/TTA-E/DE_optimize.py
2025-10-06 15:17:44 +08:00

629 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
差分进化算法优化TTA-E集成参数
使用SciPy的差分进化算法优化gru_weight和tta_weights参数目标是最小化PER
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import pickle
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from scipy.optimize import differential_evolution
import warnings
warnings.filterwarnings("ignore")
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
class TTAEnsembleCache:
    """Cache of per-trial GRU and LSTM predictions for each TTA augmentation.

    Predictions live in two in-memory dicts (one per model) keyed by
    ``"{session}_{trial}_{aug_type}"``. They can be persisted to, and restored
    from, one pickle file per model inside ``cache_dir``.

    NOTE(review): the key does not encode the model checkpoint or the TTA
    hyper-parameters, so a cache directory produced by a different model is
    silently reused. Clear ``cache_dir`` when the models change.
    """

    def __init__(self, cache_dir='./tta_cache'):
        """Create ``cache_dir`` if needed and start with empty caches."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.gru_cache = {}
        self.lstm_cache = {}
        self.augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    def _get_cache_key(self, session, trial, aug_type):
        """Build the dict key for one (session, trial, augmentation) entry."""
        return f"{session}_{trial}_{aug_type}"

    def _get_cache_file(self, model_type):
        """Return the pickle path used to persist ``model_type``'s cache."""
        return os.path.join(self.cache_dir, f'{model_type}_predictions.pkl')

    def _cache_for(self, model_type):
        """Return the in-memory dict for ``model_type``.

        Anything other than 'gru' maps to the LSTM cache.
        """
        return self.gru_cache if model_type == 'gru' else self.lstm_cache

    def save_cache(self):
        """Persist both caches to disk.

        Each file is first written to a temporary path and then atomically
        moved into place. An interrupted save therefore never leaves a
        truncated pickle behind.
        """
        for model_type in ('gru', 'lstm'):
            target_path = self._get_cache_file(model_type)
            tmp_path = target_path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(self._cache_for(model_type), f)
            os.replace(tmp_path, target_path)

    def load_cache(self):
        """Load whichever cache files exist and are readable.

        Missing or unreadable (e.g. truncated) files are skipped and leave the
        corresponding in-memory cache untouched.

        Returns:
            True only if both the GRU and the LSTM cache were loaded.
        """
        loaded_all = True
        for model_type in ('gru', 'lstm'):
            # NOTE: pickle executes code on load; only point cache_dir at trusted files.
            try:
                with open(self._get_cache_file(model_type), 'rb') as f:
                    loaded = pickle.load(f)
            except FileNotFoundError:
                loaded_all = False
                continue
            except (EOFError, pickle.UnpicklingError) as e:
                print(f"Warning: ignoring unreadable {model_type} cache file: {e}")
                loaded_all = False
                continue
            if model_type == 'gru':
                self.gru_cache = loaded
            else:
                self.lstm_cache = loaded
        return loaded_all

    def add_prediction(self, model_type, session, trial, aug_type, prediction):
        """Store ``prediction`` for the given model/session/trial/augmentation."""
        self._cache_for(model_type)[self._get_cache_key(session, trial, aug_type)] = prediction

    def get_prediction(self, model_type, session, trial, aug_type):
        """Return the cached prediction, or None when it is absent."""
        return self._cache_for(model_type).get(self._get_cache_key(session, trial, aug_type), None)

    def is_complete(self, sessions, trials_per_session):
        """Check whether both caches hold every expected entry.

        Args:
            sessions: iterable of session names to check.
            trials_per_session: mapping session -> number of trials.

        Returns:
            True when every (session, trial, augmentation) entry is present and
            non-None in both the GRU and the LSTM cache.
        """
        for session in sessions:
            for trial in range(trials_per_session[session]):
                for aug_type in self.augmentation_types:
                    key = self._get_cache_key(session, trial, aug_type)
                    if self.gru_cache.get(key) is None or self.lstm_cache.get(key) is None:
                        return False
        return True
class TTAEDifferentialEvolutionOptimizer:
    """Tune TTA-E ensemble weights with SciPy's differential evolution.

    The search vector is ``[gru_weight, w_original, w_noise, w_scale, w_shift,
    w_smooth]``.
    - ``gru_weight`` mixes GRU and LSTM probabilities; the LSTM gets
      ``1 - gru_weight``.
    - The five remaining entries weight the test-time-augmentation variants.

    The objective is the phoneme error rate (PER, in percent) on the
    validation split. It is computed from model outputs that are generated
    once and then served from a ``TTAEnsembleCache``.
    """

    def __init__(self,
                 gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                 lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                 data_dir='../data/hdf5_data_final',
                 csv_path='../data/t15_copyTaskData_description.csv',
                 gpu_number=0):
        """Store paths, pick a device, and load both models plus the validation data.

        Args:
            gru_model_path: directory with ``checkpoint/args.yaml`` and
                ``checkpoint/best_checkpoint`` for the GRU decoder.
            lstm_model_path: same layout, for the LSTM decoder.
            data_dir: root directory with one sub-directory per session.
            csv_path: trial-description CSV passed to ``load_h5py_file``.
            gpu_number: CUDA device index; a negative value forces the CPU.
        """
        self.gru_model_path = gru_model_path
        self.lstm_model_path = lstm_model_path
        self.data_dir = data_dir
        self.csv_path = csv_path
        self.gpu_number = gpu_number

        if torch.cuda.is_available() and gpu_number >= 0:
            self.device = torch.device(f'cuda:{gpu_number}')
            print(f'Using {self.device} for model inference.')
        else:
            self.device = torch.device('cpu')
            print('Using CPU for model inference.')

        self.cache = TTAEnsembleCache()

        # TTA hyper-parameters used when generating the cached predictions.
        self.tta_noise_std = 0.01
        self.tta_smooth_range = 0.5
        self.tta_scale_range = 0.05
        self.tta_cut_max = 3

        # Differential-evolution settings.
        self.population_size = 15  # popsize multiplier: 15 * 6 parameters = 90 candidates
        self.max_iterations = 50
        self.tolerance = 1e-6
        self.mutation_factor = 0.7  # SciPy accepts values in [0, 2)
        self.crossover_prob = 0.9  # recombination probability in [0, 1]

        self.evaluation_count = 0
        # PER of every successful objective evaluation, in call order (not only improvements).
        self.best_per_history = []

        self._load_models()
        self._load_data()

    @staticmethod
    def _strip_state_dict_prefixes(state_dict):
        """Return a copy of ``state_dict`` with ``module.`` and ``_orig_mod.`` removed from keys.

        These wrapper prefixes are presumably added by DataParallel/DDP and
        ``torch.compile``. Each key is renamed exactly once, so keys carrying
        either or both prefixes are handled.
        """
        return {
            key.replace("module.", "").replace("_orig_mod.", ""): value
            for key, value in state_dict.items()
        }

    def _build_decoder(self, decoder_cls, model_args, model_path):
        """Instantiate ``decoder_cls`` from ``model_args``, load its best checkpoint and set eval mode."""
        model_cfg = model_args['model']
        model = decoder_cls(
            neural_dim=model_cfg['n_input_features'],
            n_units=model_cfg['n_units'],
            n_days=len(model_args['dataset']['sessions']),
            n_classes=model_args['dataset']['n_classes'],
            rnn_dropout=model_cfg['rnn_dropout'],
            input_dropout=model_cfg['input_network']['input_layer_dropout'],
            n_layers=model_cfg['n_layers'],
            patch_size=model_cfg['patch_size'],
            patch_stride=model_cfg['patch_stride'],
        )
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=self.device)
        model.load_state_dict(self._strip_state_dict_prefixes(checkpoint['model_state_dict']))
        model.to(self.device)
        model.eval()
        return model

    def _load_models(self):
        """Load the training configs and weights of the GRU and LSTM decoders."""
        print("Loading models...")
        self.gru_model_args = OmegaConf.load(os.path.join(self.gru_model_path, 'checkpoint/args.yaml'))
        self.lstm_model_args = OmegaConf.load(os.path.join(self.lstm_model_path, 'checkpoint/args.yaml'))
        self.gru_model = self._build_decoder(GRUDecoder, self.gru_model_args, self.gru_model_path)
        self.lstm_model = self._build_decoder(LSTMDecoder, self.lstm_model_args, self.lstm_model_path)
        print("Models loaded successfully!")

    def _load_data(self):
        """Load every session's ``data_val.hdf5`` into ``self.test_data``.

        Sessions come from the GRU training config. Sessions without a
        validation file, or without a directory at all, are skipped.
        ``self.trials_per_session`` records the trial count of each loaded
        session.
        """
        print("Loading validation data...")
        b2txt_csv_df = pd.read_csv(self.csv_path)
        self.test_data = {}
        self.trials_per_session = {}
        for session in self.gru_model_args['dataset']['sessions']:
            eval_file = os.path.join(self.data_dir, session, 'data_val.hdf5')
            # Check the file directly: os.listdir() would raise for a missing session directory.
            if not os.path.isfile(eval_file):
                continue
            data = load_h5py_file(eval_file, b2txt_csv_df)
            n_trials = len(data["neural_features"])
            self.test_data[session] = data
            self.trials_per_session[session] = n_trials
            print(f'Loaded {n_trials} validation trials for session {session}.')
        print(f'Total number of validation trials: {sum(self.trials_per_session.values())}')

    def _apply_augmentation(self, x, aug_type):
        """Return an augmented copy of ``x`` for one TTA variant; ``x`` itself is not modified.

        - 'original' / 'smooth' (and unknown types): an unchanged copy. The
          'smooth' variant perturbs the Gaussian kernel width later, in
          ``generate_all_predictions``.
        - 'noise': adds Gaussian noise whose std is drawn from
          [0.5, 1.0) * ``tta_noise_std``.
        - 'scale': multiplies by a factor drawn from
          [1 - tta_scale_range, 1 + tta_scale_range).
        - 'shift': circularly shifts along dim 1 by 1..``tta_cut_max`` steps.
          The shift is kept below ``x.shape[1] // 8``; inputs too short for any
          shift are returned unchanged.
        """
        x_augmented = x.clone()
        if aug_type == 'noise':
            noise_scale = self.tta_noise_std * (0.5 + 0.5 * np.random.rand())
            x_augmented = x_augmented + torch.randn_like(x_augmented) * noise_scale
        elif aug_type == 'scale':
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * self.tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and self.tta_cut_max > 0:
            shift_upper = min(self.tta_cut_max + 1, x_augmented.shape[1] // 8)
            # np.random.randint(1, shift_upper) requires shift_upper > 1 (i.e. at least 16 frames).
            if shift_upper > 1:
                shift_amount = np.random.randint(1, shift_upper)
                x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                         x_augmented[:, :shift_amount, :]], dim=1)
        return x_augmented

    def _get_model_prediction(self, model, model_args, x_smoothed, input_layer):
        """Return softmax class probabilities of ``model`` for one smoothed input.

        ``model_args`` is accepted for call-site symmetry but is not used.
        ``input_layer`` selects the session (day) index passed as ``day_idx``.
        """
        with torch.no_grad():
            logits, _ = model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=self.device),
                states=None,
                return_state=True,
            )
            probs = torch.softmax(logits, dim=-1)
        return probs

    def generate_all_predictions(self):
        """Run both models on every validation trial under every augmentation and cache the probabilities.

        Entries already present in the cache (including entries loaded from
        disk) are reused; only missing (model, session, trial, augmentation)
        entries are computed. The cache is written to disk once at the end.

        NOTE(review): augmentations draw from the unseeded global NumPy/torch
        RNGs, so a regenerated cache is not bit-for-bit reproducible.
        """
        print("Generating all TTA predictions for caching...")

        if self.cache.load_cache():
            if self.cache.is_complete(self.test_data.keys(), self.trials_per_session):
                print("Complete cache found, skipping prediction generation.")
                return
            print("Incomplete cache found, generating missing predictions...")

        data_transforms = self.gru_model_args['dataset']['data_transforms']
        default_smooth_std = data_transforms['smooth_kernel_std']
        default_smooth_size = data_transforms['smooth_kernel_size']
        session_names = self.gru_model_args['dataset']['sessions']
        models = (
            ('gru', self.gru_model, self.gru_model_args),
            ('lstm', self.lstm_model, self.lstm_model_args),
        )

        total_trials = sum(self.trials_per_session.values())
        total_predictions = total_trials * len(self.cache.augmentation_types) * len(models)

        with tqdm(total=total_predictions, desc='Generating cached predictions', unit='pred') as pbar:
            for session, data in self.test_data.items():
                # Both decoders receive the GRU's session index; assumes the LSTM
                # was trained with the same session order -- TODO confirm.
                input_layer = session_names.index(session)
                for trial, features in enumerate(data['neural_features']):
                    neural_input = torch.tensor(np.expand_dims(features, axis=0),
                                                device=self.device, dtype=torch.bfloat16)
                    for aug_type in self.cache.augmentation_types:
                        missing = [
                            (name, model, args) for name, model, args in models
                            if self.cache.get_prediction(name, session, trial, aug_type) is None
                        ]
                        # Count already-cached entries so the bar always reaches its total.
                        pbar.update(len(models) - len(missing))
                        if not missing:
                            continue

                        x_augmented = self._apply_augmentation(neural_input, aug_type)

                        # The 'smooth' variant perturbs the Gaussian kernel width instead of the features.
                        if aug_type == 'smooth':
                            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range
                            smooth_std = max(0.3, default_smooth_std + smooth_variation)
                        else:
                            smooth_std = default_smooth_std

                        # Autocast on the actual inference device: the input is bfloat16,
                        # so the CPU fallback also needs autocast to match fp32 weights.
                        with torch.autocast(device_type=self.device.type,
                                            enabled=self.gru_model_args['use_amp'],
                                            dtype=torch.bfloat16):
                            x_smoothed = gauss_smooth(
                                inputs=x_augmented,
                                device=self.device,
                                smooth_kernel_std=smooth_std,
                                smooth_kernel_size=default_smooth_size,
                                padding='valid',
                            )
                            for name, model, args in missing:
                                probs = self._get_model_prediction(model, args, x_smoothed, input_layer)
                                # .float() first: NumPy has no bfloat16 dtype.
                                self.cache.add_prediction(name, session, trial, aug_type,
                                                          probs.float().cpu().numpy())
                                pbar.update(1)

        print("Saving cache to disk...")
        self.cache.save_cache()
        print("Cache generation completed!")

    @staticmethod
    def _ensemble_frame_ids(gru_probs_list, lstm_probs_list, tta_weights, gru_weight, epsilon=1e-8):
        """Fuse cached probabilities for one trial and return the per-frame argmax class ids.

        Args:
            gru_probs_list, lstm_probs_list: one probability array per kept
                augmentation, laid out as (batch, frames, classes). Only batch
                row 0 is decoded. Lengths along the frame axis may differ.
            tta_weights: one positive weight per array pair, normalised to sum to 1.
            gru_weight: exponent of the GRU average in the weighted geometric
                mean; the LSTM average gets ``1 - gru_weight``.
            epsilon: added before taking logs so zero probabilities stay finite.

        Returns:
            1-D integer array of argmax class ids, one per frame of the shortest input.
        """
        # Truncate every array (both models, all augmentations) to a common length.
        # This is also needed with a single augmentation, since GRU and LSTM
        # outputs may differ in length.
        min_length = min(p.shape[1] for p in [*gru_probs_list, *lstm_probs_list])
        gru_stack = np.stack([np.asarray(p, dtype=np.float32)[:, :min_length, :] for p in gru_probs_list])
        lstm_stack = np.stack([np.asarray(p, dtype=np.float32)[:, :min_length, :] for p in lstm_probs_list])

        weights = np.asarray(tta_weights, dtype=np.float32)
        weights = weights / weights.sum()
        avg_gru = np.tensordot(weights, gru_stack, axes=1)
        avg_lstm = np.tensordot(weights, lstm_stack, axes=1)

        # Weighted geometric mean in log space. Renormalising would not change
        # the argmax, so the log-ensemble is decoded directly.
        log_ensemble = (gru_weight * np.log(avg_gru + epsilon)
                        + (1.0 - gru_weight) * np.log(avg_lstm + epsilon))
        return np.argmax(log_ensemble[0], axis=-1)

    @staticmethod
    def _ctc_greedy_decode(frame_ids, blank=0):
        """Collapse a frame-level argmax path into a token list (CTC greedy decoding).

        Consecutive repeats are merged first and blanks are dropped afterwards.
        A token repeated on both sides of a blank (e.g. ``[5, 0, 5]``) is
        therefore kept twice.
        """
        tokens = []
        previous = None
        for frame_id in frame_ids:
            frame_id = int(frame_id)
            if frame_id != blank and frame_id != previous:
                tokens.append(frame_id)
            previous = frame_id
        return tokens

    def evaluate_parameters(self, gru_weight, tta_weights):
        """Return the validation PER (%) for one ensemble configuration.

        Args:
            gru_weight: GRU share of the GRU/LSTM geometric mean, in [0, 1].
            tta_weights: five weights ordered like ``cache.augmentation_types``.
                Augmentations with weight <= 0 are ignored.

        Returns:
            100 * total edit distance / total reference length. Returns 100.0
            when no trial could be scored.
        """
        weights_by_aug = dict(zip(self.cache.augmentation_types, tta_weights))
        total_true_length = 0
        total_edit_distance = 0

        for session, data in self.test_data.items():
            for trial in range(len(data['neural_features'])):
                gru_list, lstm_list, kept_weights = [], [], []
                for aug_type, weight in weights_by_aug.items():
                    if weight <= 0:
                        continue
                    gru_probs = self.cache.get_prediction('gru', session, trial, aug_type)
                    lstm_probs = self.cache.get_prediction('lstm', session, trial, aug_type)
                    if gru_probs is None or lstm_probs is None:
                        continue
                    gru_list.append(gru_probs)
                    lstm_list.append(lstm_probs)
                    kept_weights.append(weight)
                if not kept_weights:
                    continue

                frame_ids = self._ensemble_frame_ids(gru_list, lstm_list, kept_weights, gru_weight)
                pred_seq = [LOGIT_TO_PHONEME[p] for p in self._ctc_greedy_decode(frame_ids)]

                true_ids = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
                true_seq = [LOGIT_TO_PHONEME[p] for p in true_ids]

                total_edit_distance += editdistance.eval(true_seq, pred_seq)
                total_true_length += len(true_seq)

        if total_true_length == 0:
            return 100.0  # worst-case PER as a penalty
        return 100 * total_edit_distance / total_true_length

    def objective_function(self, params):
        """Differential-evolution objective: PER (%) for ``params``; lower is better.

        ``params`` is ``[gru_weight, w_original, w_noise, w_scale, w_shift, w_smooth]``.
        An all-zero TTA weight vector and any evaluation error are both scored
        as 100.0.
        """
        self.evaluation_count += 1
        gru_weight = params[0]
        tta_weights = np.maximum(params[1:6], 0)
        if np.sum(tta_weights) == 0:
            return 100.0

        try:
            per = self.evaluate_parameters(gru_weight, tta_weights)
        except Exception as e:
            # Best effort: penalise this candidate instead of aborting the whole search.
            print(f"Error in objective function evaluation: {e}")
            return 100.0

        if not self.best_per_history or per < min(self.best_per_history):
            print(f"🎯 Eval {self.evaluation_count}: New best PER = {per:.4f}%")
            print(f" GRU weight = {gru_weight:.4f}, TTA weights = {tta_weights}")
        self.best_per_history.append(per)

        if self.evaluation_count % 10 == 0:
            current_best = min(self.best_per_history)
            print(f"📊 Progress: {self.evaluation_count} evaluations, Best PER = {current_best:.4f}%")
        return per

    def optimize(self):
        """Fill the prediction cache, run differential evolution and save the result.

        Returns:
            A dict with the best parameters, PER, run statistics and DE
            settings. The same dict is pickled to
            ``de_optimization_result_<timestamp>.pkl`` in the working directory.
        """
        print("Starting Differential Evolution optimization...")
        print(f"Algorithm parameters:")
        print(f" - Population size multiplier: {self.population_size}")
        print(f" - Max iterations: {self.max_iterations}")
        print(f" - Mutation factor: {self.mutation_factor}")
        print(f" - Crossover probability: {self.crossover_prob}")
        print(f" - Tolerance: {self.tolerance}")

        # Every objective evaluation reads from the cache, so fill it first.
        self.generate_all_predictions()

        # params = [gru_weight, original, noise, scale, shift, smooth]
        bounds = [
            (0.0, 1.0),  # gru_weight
            (0.0, 5.0),  # original
            (0.0, 5.0),  # noise
            (0.0, 5.0),  # scale
            (0.0, 5.0),  # shift
            (0.0, 5.0),  # smooth
        ]
        print(f"\nParameter bounds:")
        param_names = ['GRU weight', 'Original', 'Noise', 'Scale', 'Shift', 'Smooth']
        for name, (low, high) in zip(param_names, bounds):
            print(f" - {name}: [{low}, {high}]")

        print(f"\nRunning differential evolution optimization...")
        start_time = time.time()
        result = differential_evolution(
            func=self.objective_function,
            bounds=bounds,
            popsize=self.population_size,
            maxiter=self.max_iterations,
            tol=self.tolerance,
            mutation=self.mutation_factor,
            recombination=self.crossover_prob,
            seed=42,             # reproducible search (the objective is deterministic given the cache)
            disp=True,
            polish=True,
            workers=1,           # single process: the objective reads the shared in-memory cache
            updating='deferred',
        )
        end_time = time.time()
        elapsed = end_time - start_time

        best_params = result.x
        best_per = result.fun
        aug_types = self.cache.augmentation_types

        print("\n" + "=" * 60)
        print("DIFFERENTIAL EVOLUTION OPTIMIZATION COMPLETED!")
        print("=" * 60)
        print(f"Optimization time: {elapsed:.2f} seconds")
        print(f"Total function evaluations: {result.nfev}")
        print(f"Optimization success: {result.success}")
        print(f"Termination message: {result.message}")
        print(f"\nBest solution:")
        print(f"Best GRU weight: {best_params[0]:.4f}")
        print(f"Best LSTM weight: {1.0 - best_params[0]:.4f}")
        print(f"Best TTA weights:")
        for aug_type, weight in zip(aug_types, best_params[1:6]):
            print(f" - {aug_type}: {weight:.4f}")
        print(f"Best PER: {best_per:.4f}%")

        optimization_result = {
            'best_params': best_params,
            'gru_weight': best_params[0],
            'lstm_weight': 1.0 - best_params[0],
            'tta_weights': dict(zip(aug_types, best_params[1:6])),
            'best_per': best_per,
            'optimization_time': elapsed,
            'function_evaluations': result.nfev,
            'success': result.success,
            'message': result.message,
            'per_history': self.best_per_history,
            'algorithm': 'differential_evolution',
            'algorithm_params': {
                'popsize': self.population_size,
                'maxiter': self.max_iterations,
                'mutation': self.mutation_factor,
                'recombination': self.crossover_prob,
                'tolerance': self.tolerance,
            },
        }

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = f'de_optimization_result_{timestamp}.pkl'
        with open(result_file, 'wb') as f:
            pickle.dump(optimization_result, f)
        print(f"\nResults saved to: {result_file}")

        print(f"\nPerformance Analysis:")
        if result.nfev > 0 and elapsed > 0:
            print(f" - Average evaluation time: {elapsed / result.nfev:.3f} seconds")
            print(f" - Evaluations per minute: {result.nfev / (elapsed / 60):.1f}")
        if len(self.best_per_history) > 10:
            improvement = self.best_per_history[0] - min(self.best_per_history)
            print(f" - Total PER improvement: {improvement:.4f}%")

        return optimization_result
def main():
    """Script entry point: build the optimizer with the default experiment paths and run it.

    Returns:
        The result dictionary produced by ``TTAEDifferentialEvolutionOptimizer.optimize``.
    """
    print("TTA-E Differential Evolution Optimization")
    print("=" * 50)

    settings = dict(
        gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        data_dir='../data/hdf5_data_final',
        csv_path='../data/t15_copyTaskData_description.csv',
        gpu_number=0,
    )
    optimization_result = TTAEDifferentialEvolutionOptimizer(**settings).optimize()

    print("\nDifferential Evolution optimization completed successfully!")
    return optimization_result
if __name__ == "__main__":
    # Default to GPU 0 while respecting a CUDA_VISIBLE_DEVICES the caller already set.
    # This must run before CUDA is first initialised.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    result = main()