629 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			629 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | |||
|  | """
 | |||
|  | 差分进化算法优化TTA-E集成参数 | |||
|  | 使用SciPy的差分进化算法优化gru_weight和tta_weights参数,目标是最小化PER | |||
|  | """
 | |||
|  | 
 | |||
|  | import os | |||
|  | import sys | |||
|  | import torch | |||
|  | import numpy as np | |||
|  | import pandas as pd | |||
|  | from omegaconf import OmegaConf | |||
|  | import time | |||
|  | from tqdm import tqdm | |||
|  | import editdistance | |||
|  | import pickle | |||
|  | import multiprocessing as mp | |||
|  | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor | |||
|  | from scipy.optimize import differential_evolution | |||
|  | import warnings | |||
|  | warnings.filterwarnings("ignore") | |||
|  | 
 | |||
|  | # Add parent directories to path to import models | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training')) | |||
|  | sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm')) | |||
|  | 
 | |||
|  | from model_training.rnn_model import GRUDecoder | |||
|  | from model_training_lstm.rnn_model import LSTMDecoder | |||
|  | from model_training.evaluate_model_helpers import * | |||
|  | 
 | |||
class TTAEnsembleCache:
    """Cache of GRU and LSTM predictions for each test-time augmentation.

    Predictions are held in two in-memory dicts (one per model) keyed by
    ``"{session}_{trial}_{aug_type}"`` and can be persisted to, or restored
    from, one pickle file per model inside ``cache_dir``.
    """

    def __init__(self, cache_dir='./tta_cache'):
        """Create ``cache_dir`` if needed and start with two empty caches."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.gru_cache = {}
        self.lstm_cache = {}
        self.augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    def _get_cache_key(self, session, trial, aug_type):
        """Return the dict key identifying one (session, trial, augmentation)."""
        return f"{session}_{trial}_{aug_type}"

    def _get_cache_file(self, model_type):
        """Return the path of the pickle file that stores ``model_type``'s cache."""
        return os.path.join(self.cache_dir, f'{model_type}_predictions.pkl')

    def _select_cache(self, model_type):
        """Return the GRU dict for ``'gru'``; any other value selects the LSTM dict."""
        return self.gru_cache if model_type == 'gru' else self.lstm_cache

    def _dump_atomically(self, model_type, payload):
        """Pickle ``payload`` to a temp file, then rename it over the real file.

        Writing through a temporary file means an interrupted save cannot
        leave a truncated pickle behind under the real cache file name.
        """
        target = self._get_cache_file(model_type)
        scratch = target + '.tmp'
        with open(scratch, 'wb') as f:
            pickle.dump(payload, f)
        os.replace(scratch, target)

    def _load_one(self, model_type):
        """Return the cache dict stored on disk for ``model_type``, or ``None``.

        ``None`` is returned when the file is missing, truncated, not a valid
        pickle, or does not contain a dict.
        """
        try:
            # SECURITY: pickle can execute arbitrary code while loading; only
            # point cache_dir at files this program wrote itself.
            with open(self._get_cache_file(model_type), 'rb') as f:
                loaded = pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            return None
        return loaded if isinstance(loaded, dict) else None

    def save_cache(self):
        """Persist both caches to disk."""
        self._dump_atomically('gru', self.gru_cache)
        self._dump_atomically('lstm', self.lstm_cache)

    def load_cache(self):
        """Restore the caches from disk.

        Each file is loaded independently, so a usable file is still restored
        when the other one is missing or unreadable.

        Returns:
            True only when both cache files were loaded successfully.
        """
        gru_loaded = self._load_one('gru')
        lstm_loaded = self._load_one('lstm')
        if gru_loaded is not None:
            self.gru_cache = gru_loaded
        if lstm_loaded is not None:
            self.lstm_cache = lstm_loaded
        return gru_loaded is not None and lstm_loaded is not None

    def add_prediction(self, model_type, session, trial, aug_type, prediction):
        """Store ``prediction`` for the given model / session / trial / augmentation."""
        key = self._get_cache_key(session, trial, aug_type)
        self._select_cache(model_type)[key] = prediction

    def get_prediction(self, model_type, session, trial, aug_type):
        """Return the cached prediction, or ``None`` when it is not cached."""
        key = self._get_cache_key(session, trial, aug_type)
        return self._select_cache(model_type).get(key, None)

    def is_complete(self, sessions, trials_per_session):
        """Return True when both models have every expected prediction cached.

        Args:
            sessions: iterable of session names to check.
            trials_per_session: mapping from session name to its trial count.
        """
        for session in sessions:
            for trial in range(trials_per_session[session]):
                for aug_type in self.augmentation_types:
                    key = self._get_cache_key(session, trial, aug_type)
                    # An entry stored as None counts as missing.
                    if self.gru_cache.get(key) is None or self.lstm_cache.get(key) is None:
                        return False
        return True
|  | 
 | |||
|  | class TTAEDifferentialEvolutionOptimizer: | |||
|  |     """使用差分进化算法优化TTA-E参数的主类""" | |||
|  |      | |||
|  |     def __init__(self,  | |||
|  |                  gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline', | |||
|  |                  lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn', | |||
|  |                  data_dir='../data/hdf5_data_final', | |||
|  |                  csv_path='../data/t15_copyTaskData_description.csv', | |||
|  |                  gpu_number=0): | |||
|  |          | |||
|  |         self.gru_model_path = gru_model_path | |||
|  |         self.lstm_model_path = lstm_model_path | |||
|  |         self.data_dir = data_dir | |||
|  |         self.csv_path = csv_path | |||
|  |         self.gpu_number = gpu_number | |||
|  |          | |||
|  |         # 初始化设备 | |||
|  |         if torch.cuda.is_available() and gpu_number >= 0: | |||
|  |             self.device = torch.device(f'cuda:{gpu_number}') | |||
|  |             print(f'Using {self.device} for model inference.') | |||
|  |         else: | |||
|  |             self.device = torch.device('cpu') | |||
|  |             print('Using CPU for model inference.') | |||
|  |          | |||
|  |         # 初始化缓存 | |||
|  |         self.cache = TTAEnsembleCache() | |||
|  |          | |||
|  |         # TTA参数 | |||
|  |         self.tta_noise_std = 0.01 | |||
|  |         self.tta_smooth_range = 0.5 | |||
|  |         self.tta_scale_range = 0.05 | |||
|  |         self.tta_cut_max = 3 | |||
|  |          | |||
|  |         # 差分进化算法参数 | |||
|  |         self.population_size = 15  # 种群大小倍数 (实际种群 = 15 * 6维 = 90) | |||
|  |         self.max_iterations = 50   # 最大迭代次数 | |||
|  |         self.tolerance = 1e-6      # 收敛容忍度 | |||
|  |         self.mutation_factor = 0.7 # 变异因子 [0.5, 2.0] | |||
|  |         self.crossover_prob = 0.9  # 交叉概率 [0, 1] | |||
|  |          | |||
|  |         # 评估计数器 | |||
|  |         self.evaluation_count = 0 | |||
|  |         self.best_per_history = [] | |||
|  |          | |||
|  |         # 加载模型和数据 | |||
|  |         self._load_models() | |||
|  |         self._load_data() | |||
|  |          | |||
|  |     def _load_models(self): | |||
|  |         """加载GRU和LSTM模型""" | |||
|  |         print("Loading models...") | |||
|  |          | |||
|  |         # 加载模型参数 | |||
|  |         self.gru_model_args = OmegaConf.load(os.path.join(self.gru_model_path, 'checkpoint/args.yaml')) | |||
|  |         self.lstm_model_args = OmegaConf.load(os.path.join(self.lstm_model_path, 'checkpoint/args.yaml')) | |||
|  |          | |||
|  |         # 定义GRU模型 | |||
|  |         self.gru_model = GRUDecoder( | |||
|  |             neural_dim=self.gru_model_args['model']['n_input_features'], | |||
|  |             n_units=self.gru_model_args['model']['n_units'],  | |||
|  |             n_days=len(self.gru_model_args['dataset']['sessions']), | |||
|  |             n_classes=self.gru_model_args['dataset']['n_classes'], | |||
|  |             rnn_dropout=self.gru_model_args['model']['rnn_dropout'], | |||
|  |             input_dropout=self.gru_model_args['model']['input_network']['input_layer_dropout'], | |||
|  |             n_layers=self.gru_model_args['model']['n_layers'], | |||
|  |             patch_size=self.gru_model_args['model']['patch_size'], | |||
|  |             patch_stride=self.gru_model_args['model']['patch_stride'], | |||
|  |         ) | |||
|  |          | |||
|  |         # 加载GRU模型权重 | |||
|  |         gru_checkpoint = torch.load(os.path.join(self.gru_model_path, 'checkpoint/best_checkpoint'),  | |||
|  |                                    weights_only=False, map_location=self.device) | |||
|  |         # 处理模型键名 | |||
|  |         for key in list(gru_checkpoint['model_state_dict'].keys()): | |||
|  |             gru_checkpoint['model_state_dict'][key.replace("module.", "")] = gru_checkpoint['model_state_dict'].pop(key) | |||
|  |             gru_checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = gru_checkpoint['model_state_dict'].pop(key) | |||
|  |         self.gru_model.load_state_dict(gru_checkpoint['model_state_dict']) | |||
|  |          | |||
|  |         # 定义LSTM模型 | |||
|  |         self.lstm_model = LSTMDecoder( | |||
|  |             neural_dim=self.lstm_model_args['model']['n_input_features'], | |||
|  |             n_units=self.lstm_model_args['model']['n_units'],  | |||
|  |             n_days=len(self.lstm_model_args['dataset']['sessions']), | |||
|  |             n_classes=self.lstm_model_args['dataset']['n_classes'], | |||
|  |             rnn_dropout=self.lstm_model_args['model']['rnn_dropout'], | |||
|  |             input_dropout=self.lstm_model_args['model']['input_network']['input_layer_dropout'], | |||
|  |             n_layers=self.lstm_model_args['model']['n_layers'], | |||
|  |             patch_size=self.lstm_model_args['model']['patch_size'], | |||
|  |             patch_stride=self.lstm_model_args['model']['patch_stride'], | |||
|  |         ) | |||
|  |          | |||
|  |         # 加载LSTM模型权重 | |||
|  |         lstm_checkpoint = torch.load(os.path.join(self.lstm_model_path, 'checkpoint/best_checkpoint'),  | |||
|  |                                     weights_only=False, map_location=self.device) | |||
|  |         # 处理模型键名 | |||
|  |         for key in list(lstm_checkpoint['model_state_dict'].keys()): | |||
|  |             lstm_checkpoint['model_state_dict'][key.replace("module.", "")] = lstm_checkpoint['model_state_dict'].pop(key) | |||
|  |             lstm_checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = lstm_checkpoint['model_state_dict'].pop(key) | |||
|  |         self.lstm_model.load_state_dict(lstm_checkpoint['model_state_dict']) | |||
|  |          | |||
|  |         # 移动模型到设备并设置为评估模式 | |||
|  |         self.gru_model.to(self.device) | |||
|  |         self.lstm_model.to(self.device) | |||
|  |         self.gru_model.eval() | |||
|  |         self.lstm_model.eval() | |||
|  |          | |||
|  |         print("Models loaded successfully!") | |||
|  |      | |||
|  |     def _load_data(self): | |||
|  |         """加载验证数据集""" | |||
|  |         print("Loading validation data...") | |||
|  |          | |||
|  |         # 加载CSV文件 | |||
|  |         b2txt_csv_df = pd.read_csv(self.csv_path) | |||
|  |          | |||
|  |         # 加载验证数据 | |||
|  |         self.test_data = {} | |||
|  |         self.trials_per_session = {} | |||
|  |         total_test_trials = 0 | |||
|  |          | |||
|  |         for session in self.gru_model_args['dataset']['sessions']: | |||
|  |             files = [f for f in os.listdir(os.path.join(self.data_dir, session)) if f.endswith('.hdf5')] | |||
|  |             if 'data_val.hdf5' in files: | |||
|  |                 eval_file = os.path.join(self.data_dir, session, 'data_val.hdf5') | |||
|  |                 data = load_h5py_file(eval_file, b2txt_csv_df) | |||
|  |                 self.test_data[session] = data | |||
|  |                 self.trials_per_session[session] = len(data["neural_features"]) | |||
|  |                 total_test_trials += len(data["neural_features"]) | |||
|  |                 print(f'Loaded {len(data["neural_features"])} validation trials for session {session}.') | |||
|  |          | |||
|  |         print(f'Total number of validation trials: {total_test_trials}') | |||
|  |      | |||
|  |     def _apply_augmentation(self, x, aug_type): | |||
|  |         """应用数据增强""" | |||
|  |         x_augmented = x.clone() | |||
|  |          | |||
|  |         if aug_type == 'original': | |||
|  |             pass | |||
|  |         elif aug_type == 'noise': | |||
|  |             noise_scale = self.tta_noise_std * (0.5 + 0.5 * np.random.rand()) | |||
|  |             noise = torch.randn_like(x_augmented) * noise_scale | |||
|  |             x_augmented = x_augmented + noise | |||
|  |         elif aug_type == 'scale': | |||
|  |             scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * self.tta_scale_range | |||
|  |             x_augmented = x_augmented * scale_factor | |||
|  |         elif aug_type == 'shift' and self.tta_cut_max > 0: | |||
|  |             shift_amount = np.random.randint(1, min(self.tta_cut_max + 1, x_augmented.shape[1] // 8)) | |||
|  |             x_augmented = torch.cat([x_augmented[:, shift_amount:, :],  | |||
|  |                                    x_augmented[:, :shift_amount, :]], dim=1) | |||
|  |         elif aug_type == 'smooth': | |||
|  |             smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range | |||
|  |             varied_smooth_std = max(0.3, self.gru_model_args['dataset']['data_transforms']['smooth_kernel_std'] + smooth_variation) | |||
|  |          | |||
|  |         return x_augmented | |||
|  |      | |||
|  |     def _get_model_prediction(self, model, model_args, x_smoothed, input_layer): | |||
|  |         """获取单个模型的预测结果""" | |||
|  |         with torch.no_grad(): | |||
|  |             logits, _ = model( | |||
|  |                 x=x_smoothed, | |||
|  |                 day_idx=torch.tensor([input_layer], device=self.device), | |||
|  |                 states=None, | |||
|  |                 return_state=True, | |||
|  |             ) | |||
|  |             probs = torch.softmax(logits, dim=-1) | |||
|  |             return probs | |||
|  |      | |||
|  |     def generate_all_predictions(self): | |||
|  |         """生成所有模型在所有增强方式下的预测结果并缓存""" | |||
|  |         print("Generating all TTA predictions for caching...") | |||
|  |          | |||
|  |         # 尝试加载现有缓存 | |||
|  |         if self.cache.load_cache(): | |||
|  |             if self.cache.is_complete(self.test_data.keys(), self.trials_per_session): | |||
|  |                 print("Complete cache found, skipping prediction generation.") | |||
|  |                 return | |||
|  |             else: | |||
|  |                 print("Incomplete cache found, generating missing predictions...") | |||
|  |          | |||
|  |         total_trials = sum(self.trials_per_session.values()) | |||
|  |         total_predictions = total_trials * len(self.cache.augmentation_types) * 2  # 2 models | |||
|  |          | |||
|  |         with tqdm(total=total_predictions, desc='Generating cached predictions', unit='pred') as pbar: | |||
|  |             for session, data in self.test_data.items(): | |||
|  |                 input_layer = self.gru_model_args['dataset']['sessions'].index(session) | |||
|  |                  | |||
|  |                 for trial in range(len(data['neural_features'])): | |||
|  |                     # 获取神经输入 | |||
|  |                     neural_input = data['neural_features'][trial] | |||
|  |                     neural_input = np.expand_dims(neural_input, axis=0) | |||
|  |                     neural_input = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16) | |||
|  |                      | |||
|  |                     # 为每种增强方式生成预测 | |||
|  |                     for aug_type in self.cache.augmentation_types: | |||
|  |                         # 检查是否已缓存 | |||
|  |                         if (self.cache.get_prediction('gru', session, trial, aug_type) is not None and | |||
|  |                             self.cache.get_prediction('lstm', session, trial, aug_type) is not None): | |||
|  |                             pbar.update(2) | |||
|  |                             continue | |||
|  |                          | |||
|  |                         # 应用增强 | |||
|  |                         x_augmented = self._apply_augmentation(neural_input, aug_type) | |||
|  |                          | |||
|  |                         # 应用高斯平滑 | |||
|  |                         default_smooth_std = self.gru_model_args['dataset']['data_transforms']['smooth_kernel_std'] | |||
|  |                         default_smooth_size = self.gru_model_args['dataset']['data_transforms']['smooth_kernel_size'] | |||
|  |                          | |||
|  |                         if aug_type == 'smooth': | |||
|  |                             smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range | |||
|  |                             varied_smooth_std = max(0.3, default_smooth_std + smooth_variation) | |||
|  |                         else: | |||
|  |                             varied_smooth_std = default_smooth_std | |||
|  |                          | |||
|  |                         with torch.autocast(device_type="cuda", enabled=self.gru_model_args['use_amp'], dtype=torch.bfloat16): | |||
|  |                             x_smoothed = gauss_smooth( | |||
|  |                                 inputs=x_augmented,  | |||
|  |                                 device=self.device, | |||
|  |                                 smooth_kernel_std=varied_smooth_std, | |||
|  |                                 smooth_kernel_size=default_smooth_size, | |||
|  |                                 padding='valid', | |||
|  |                             ) | |||
|  |                              | |||
|  |                             # GRU预测 | |||
|  |                             if self.cache.get_prediction('gru', session, trial, aug_type) is None: | |||
|  |                                 gru_probs = self._get_model_prediction( | |||
|  |                                     self.gru_model, self.gru_model_args, x_smoothed, input_layer | |||
|  |                                 ) | |||
|  |                                 self.cache.add_prediction('gru', session, trial, aug_type,  | |||
|  |                                                         gru_probs.cpu().numpy()) | |||
|  |                                 pbar.update(1) | |||
|  |                              | |||
|  |                             # LSTM预测 | |||
|  |                             if self.cache.get_prediction('lstm', session, trial, aug_type) is None: | |||
|  |                                 lstm_probs = self._get_model_prediction( | |||
|  |                                     self.lstm_model, self.lstm_model_args, x_smoothed, input_layer | |||
|  |                                 ) | |||
|  |                                 self.cache.add_prediction('lstm', session, trial, aug_type,  | |||
|  |                                                          lstm_probs.cpu().numpy()) | |||
|  |                                 pbar.update(1) | |||
|  |          | |||
|  |         # 保存缓存 | |||
|  |         print("Saving cache to disk...") | |||
|  |         self.cache.save_cache() | |||
|  |         print("Cache generation completed!") | |||
|  |      | |||
|  |     def evaluate_parameters(self, gru_weight, tta_weights): | |||
|  |         """评估给定参数组合的PER性能""" | |||
|  |         lstm_weight = 1.0 - gru_weight | |||
|  |          | |||
|  |         # 将tta_weights转换为字典 | |||
|  |         tta_weights_dict = { | |||
|  |             'original': tta_weights[0], | |||
|  |             'noise': tta_weights[1], | |||
|  |             'scale': tta_weights[2], | |||
|  |             'shift': tta_weights[3], | |||
|  |             'smooth': tta_weights[4] | |||
|  |         } | |||
|  |          | |||
|  |         total_true_length = 0 | |||
|  |         total_edit_distance = 0 | |||
|  |          | |||
|  |         for session, data in self.test_data.items(): | |||
|  |             for trial in range(len(data['neural_features'])): | |||
|  |                 # 收集所有增强方式的预测结果 | |||
|  |                 all_gru_probs = [] | |||
|  |                 all_lstm_probs = [] | |||
|  |                 sample_weights = [] | |||
|  |                  | |||
|  |                 for aug_type in self.cache.augmentation_types: | |||
|  |                     if tta_weights_dict[aug_type] <= 0: | |||
|  |                         continue | |||
|  |                      | |||
|  |                     # 从缓存获取预测结果 | |||
|  |                     gru_probs = self.cache.get_prediction('gru', session, trial, aug_type) | |||
|  |                     lstm_probs = self.cache.get_prediction('lstm', session, trial, aug_type) | |||
|  |                      | |||
|  |                     if gru_probs is not None and lstm_probs is not None: | |||
|  |                         all_gru_probs.append(torch.tensor(gru_probs)) | |||
|  |                         all_lstm_probs.append(torch.tensor(lstm_probs)) | |||
|  |                         sample_weights.append(tta_weights_dict[aug_type]) | |||
|  |                  | |||
|  |                 if len(all_gru_probs) == 0: | |||
|  |                     continue | |||
|  |                  | |||
|  |                 # TTA融合 | |||
|  |                 if len(all_gru_probs) > 1: | |||
|  |                     min_length = min([probs.shape[1] for probs in all_gru_probs + all_lstm_probs]) | |||
|  |                      | |||
|  |                     # 截断到最小长度 | |||
|  |                     truncated_gru_probs = [] | |||
|  |                     truncated_lstm_probs = [] | |||
|  |                     for gru_probs, lstm_probs in zip(all_gru_probs, all_lstm_probs): | |||
|  |                         if gru_probs.shape[1] > min_length: | |||
|  |                             truncated_gru_probs.append(gru_probs[:, :min_length, :]) | |||
|  |                         else: | |||
|  |                             truncated_gru_probs.append(gru_probs) | |||
|  |                          | |||
|  |                         if lstm_probs.shape[1] > min_length: | |||
|  |                             truncated_lstm_probs.append(lstm_probs[:, :min_length, :]) | |||
|  |                         else: | |||
|  |                             truncated_lstm_probs.append(lstm_probs) | |||
|  |                      | |||
|  |                     # 加权平均 | |||
|  |                     sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32) | |||
|  |                     sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum() | |||
|  |                      | |||
|  |                     weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0]) | |||
|  |                     weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0]) | |||
|  |                      | |||
|  |                     for i, (gru_probs, lstm_probs, weight) in enumerate(zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor)): | |||
|  |                         weighted_gru_probs += weight * gru_probs | |||
|  |                         weighted_lstm_probs += weight * lstm_probs | |||
|  |                      | |||
|  |                     avg_gru_probs = weighted_gru_probs | |||
|  |                     avg_lstm_probs = weighted_lstm_probs | |||
|  |                 else: | |||
|  |                     avg_gru_probs = all_gru_probs[0] | |||
|  |                     avg_lstm_probs = all_lstm_probs[0] | |||
|  |                  | |||
|  |                 # 集成融合 (几何平均) | |||
|  |                 epsilon = 1e-8 | |||
|  |                 avg_gru_probs = avg_gru_probs + epsilon | |||
|  |                 avg_lstm_probs = avg_lstm_probs + epsilon | |||
|  |                  | |||
|  |                 log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +  | |||
|  |                                      lstm_weight * torch.log(avg_lstm_probs)) | |||
|  |                 ensemble_probs = torch.exp(log_ensemble_probs) | |||
|  |                 ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True) | |||
|  |                  | |||
|  |                 # 转换为预测序列 | |||
|  |                 logits = torch.log(ensemble_probs + epsilon) | |||
|  |                 pred_seq = np.argmax(logits[0].numpy(), axis=-1) | |||
|  |                  | |||
|  |                 # 移除空白和连续重复 | |||
|  |                 pred_seq = [int(p) for p in pred_seq if p != 0] | |||
|  |                 pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]] | |||
|  |                 pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq] | |||
|  |                  | |||
|  |                 # 获取真实序列 | |||
|  |                 true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]] | |||
|  |                 true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq] | |||
|  |                  | |||
|  |                 # 计算编辑距离 | |||
|  |                 ed = editdistance.eval(true_seq, pred_seq) | |||
|  |                 total_true_length += len(true_seq) | |||
|  |                 total_edit_distance += ed | |||
|  |          | |||
|  |         # 计算PER | |||
|  |         if total_true_length == 0: | |||
|  |             return 100.0  # 返回最大PER作为惩罚 | |||
|  |          | |||
|  |         per = 100 * total_edit_distance / total_true_length | |||
|  |         return per | |||
|  |      | |||
|  |     def objective_function(self, params): | |||
|  |         """差分进化的目标函数""" | |||
|  |         self.evaluation_count += 1 | |||
|  |          | |||
|  |         # 解码参数 | |||
|  |         gru_weight = params[0]  # 范围 [0, 1] | |||
|  |         tta_weights = params[1:6]  # 5个TTA权重 | |||
|  |          | |||
|  |         # 确保权重非负 | |||
|  |         tta_weights = np.maximum(tta_weights, 0) | |||
|  |          | |||
|  |         # 如果所有TTA权重都为0,返回最大PER作为惩罚 | |||
|  |         if np.sum(tta_weights) == 0: | |||
|  |             return 100.0 | |||
|  |          | |||
|  |         try: | |||
|  |             per = self.evaluate_parameters(gru_weight, tta_weights) | |||
|  |              | |||
|  |             # 记录历史最佳 | |||
|  |             if len(self.best_per_history) == 0 or per < min(self.best_per_history): | |||
|  |                 print(f"🎯 Eval {self.evaluation_count}: New best PER = {per:.4f}%") | |||
|  |                 print(f"   GRU weight = {gru_weight:.4f}, TTA weights = {tta_weights}") | |||
|  |              | |||
|  |             self.best_per_history.append(per) | |||
|  |              | |||
|  |             # 每10次评估输出进度 | |||
|  |             if self.evaluation_count % 10 == 0: | |||
|  |                 current_best = min(self.best_per_history) | |||
|  |                 print(f"📊 Progress: {self.evaluation_count} evaluations, Best PER = {current_best:.4f}%") | |||
|  |              | |||
|  |             return per  # 差分进化默认最小化目标函数 | |||
|  |              | |||
|  |         except Exception as e: | |||
|  |             print(f"Error in objective function evaluation: {e}") | |||
|  |             return 100.0 | |||
|  |      | |||
|  |     def optimize(self): | |||
|  |         """运行差分进化算法优化""" | |||
|  |         print("Starting Differential Evolution optimization...") | |||
|  |         print(f"Algorithm parameters:") | |||
|  |         print(f"  - Population size multiplier: {self.population_size}") | |||
|  |         print(f"  - Max iterations: {self.max_iterations}") | |||
|  |         print(f"  - Mutation factor: {self.mutation_factor}") | |||
|  |         print(f"  - Crossover probability: {self.crossover_prob}") | |||
|  |         print(f"  - Tolerance: {self.tolerance}") | |||
|  |          | |||
|  |         # 首先生成所有预测并缓存 | |||
|  |         self.generate_all_predictions() | |||
|  |          | |||
|  |         # 定义参数边界 | |||
|  |         # params = [gru_weight, original_weight, noise_weight, scale_weight, shift_weight, smooth_weight] | |||
|  |         bounds = [ | |||
|  |             (0.0, 1.0),  # gru_weight: [0, 1] | |||
|  |             (0.0, 5.0),  # original weight: [0, 5] | |||
|  |             (0.0, 5.0),  # noise weight: [0, 5] | |||
|  |             (0.0, 5.0),  # scale weight: [0, 5] | |||
|  |             (0.0, 5.0),  # shift weight: [0, 5] | |||
|  |             (0.0, 5.0),  # smooth weight: [0, 5] | |||
|  |         ] | |||
|  |          | |||
|  |         print(f"\nParameter bounds:") | |||
|  |         param_names = ['GRU weight', 'Original', 'Noise', 'Scale', 'Shift', 'Smooth'] | |||
|  |         for i, (name, (low, high)) in enumerate(zip(param_names, bounds)): | |||
|  |             print(f"  - {name}: [{low}, {high}]") | |||
|  |          | |||
|  |         # 运行差分进化优化 | |||
|  |         print(f"\nRunning differential evolution optimization...") | |||
|  |         start_time = time.time() | |||
|  |          | |||
|  |         result = differential_evolution( | |||
|  |             func=self.objective_function, | |||
|  |             bounds=bounds, | |||
|  |             popsize=self.population_size,    # 种群大小倍数 | |||
|  |             maxiter=self.max_iterations,     # 最大迭代次数 | |||
|  |             tol=self.tolerance,              # 收敛容忍度 | |||
|  |             mutation=self.mutation_factor,   # 变异因子 | |||
|  |             recombination=self.crossover_prob, # 交叉概率 | |||
|  |             seed=42,                         # 随机种子确保可复现 | |||
|  |             disp=True,                       # 显示优化过程 | |||
|  |             polish=True,                     # 最后用局部优化算法精炼 | |||
|  |             workers=1,                       # 单线程避免缓存冲突 | |||
|  |             updating='deferred',             # 延迟更新策略 | |||
|  |         ) | |||
|  |          | |||
|  |         end_time = time.time() | |||
|  |          | |||
|  |         # 获取最佳解 | |||
|  |         best_params = result.x | |||
|  |         best_per = result.fun | |||
|  |          | |||
|  |         print("\n" + "="*60) | |||
|  |         print("DIFFERENTIAL EVOLUTION OPTIMIZATION COMPLETED!") | |||
|  |         print("="*60) | |||
|  |         print(f"Optimization time: {end_time - start_time:.2f} seconds") | |||
|  |         print(f"Total function evaluations: {result.nfev}") | |||
|  |         print(f"Optimization success: {result.success}") | |||
|  |         print(f"Termination message: {result.message}") | |||
|  |          | |||
|  |         print(f"\nBest solution:") | |||
|  |         print(f"Best GRU weight: {best_params[0]:.4f}") | |||
|  |         print(f"Best LSTM weight: {1.0 - best_params[0]:.4f}") | |||
|  |         print(f"Best TTA weights:") | |||
|  |         aug_types = ['original', 'noise', 'scale', 'shift', 'smooth'] | |||
|  |         for i, aug_type in enumerate(aug_types): | |||
|  |             print(f"  - {aug_type}: {best_params[i+1]:.4f}") | |||
|  |         print(f"Best PER: {best_per:.4f}%") | |||
|  |          | |||
|  |         # 保存详细结果 | |||
|  |         optimization_result = { | |||
|  |             'best_params': best_params, | |||
|  |             'gru_weight': best_params[0], | |||
|  |             'lstm_weight': 1.0 - best_params[0], | |||
|  |             'tta_weights': { | |||
|  |                 'original': best_params[1], | |||
|  |                 'noise': best_params[2],  | |||
|  |                 'scale': best_params[3], | |||
|  |                 'shift': best_params[4], | |||
|  |                 'smooth': best_params[5] | |||
|  |             }, | |||
|  |             'best_per': best_per, | |||
|  |             'optimization_time': end_time - start_time, | |||
|  |             'function_evaluations': result.nfev, | |||
|  |             'success': result.success, | |||
|  |             'message': result.message, | |||
|  |             'per_history': self.best_per_history, | |||
|  |             'algorithm': 'differential_evolution', | |||
|  |             'algorithm_params': { | |||
|  |                 'popsize': self.population_size, | |||
|  |                 'maxiter': self.max_iterations, | |||
|  |                 'mutation': self.mutation_factor, | |||
|  |                 'recombination': self.crossover_prob, | |||
|  |                 'tolerance': self.tolerance | |||
|  |             } | |||
|  |         } | |||
|  |          | |||
|  |         # 保存到文件 | |||
|  |         timestamp = time.strftime("%Y%m%d_%H%M%S") | |||
|  |         result_file = f'de_optimization_result_{timestamp}.pkl' | |||
|  |         with open(result_file, 'wb') as f: | |||
|  |             pickle.dump(optimization_result, f) | |||
|  |          | |||
|  |         print(f"\nResults saved to: {result_file}") | |||
|  |          | |||
|  |         # 性能分析 | |||
|  |         print(f"\nPerformance Analysis:") | |||
|  |         print(f"  - Average evaluation time: {(end_time - start_time) / result.nfev:.3f} seconds") | |||
|  |         print(f"  - Evaluations per minute: {result.nfev / ((end_time - start_time) / 60):.1f}") | |||
|  |         if len(self.best_per_history) > 10: | |||
|  |             improvement = self.best_per_history[0] - min(self.best_per_history) | |||
|  |             print(f"  - Total PER improvement: {improvement:.4f}%") | |||
|  |          | |||
|  |         return optimization_result | |||
|  | 
 | |||
def main():
    """Build the optimizer with the project's default paths, run it, return its result."""
    print("TTA-E Differential Evolution Optimization")
    print("=" * 50)

    # Same locations as the optimizer's defaults, spelled out for clarity.
    settings = {
        'gru_model_path': '/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        'lstm_model_path': '/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        'data_dir': '../data/hdf5_data_final',
        'csv_path': '../data/t15_copyTaskData_description.csv',
        'gpu_number': 0,
    }
    optimizer = TTAEDifferentialEvolutionOptimizer(**settings)

    outcome = optimizer.optimize()

    print("\nDifferential Evolution optimization completed successfully!")
    return outcome
|  | 
 | |||
if __name__ == "__main__":
    # Default to the first GPU, but keep a CUDA_VISIBLE_DEVICES value the user
    # has already exported instead of silently overwriting it.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    # Run the optimisation.
    result = main()