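#!/usr/bin/env python3
"""
遗传算法优化TTA-E集成参数
Genetic-algorithm optimization of the TTA-E ensemble parameters.

使用PyGAD优化gru_weight和tta_weights参数,目标是最小化PER
Uses PyGAD to optimize the gru_weight and tta_weights parameters, with the
goal of minimizing the phoneme error rate (PER).
"""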
|
||
|
||
import os
|
||
import sys
|
||
import torch
|
||
import numpy as np
|
||
import pandas as pd
|
||
from omegaconf import OmegaConf
|
||
import time
|
||
from tqdm import tqdm
|
||
import editdistance
|
||
import pygad
|
||
import pickle
|
||
import multiprocessing as mp
|
||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||
import hashlib
|
||
from functools import lru_cache
|
||
import cupy as cp
|
||
import warnings
|
||
# Suppress all warnings globally for this script.
warnings.filterwarnings("ignore")

# Append the sibling training directories (relative to this file) to the
# module search path, in this order: model_training, model_training_lstm.
sys.path.extend(
    os.path.join(os.path.dirname(__file__), '..', sub_dir)
    for sub_dir in ('model_training', 'model_training_lstm')
)
|
||
|
||
from model_training.rnn_model import GRUDecoder
|
||
from model_training_lstm.rnn_model import LSTMDecoder
|
||
from model_training.evaluate_model_helpers import *
|
||
|
||
class TTAEnsembleCache:
    """Cache of GRU and LSTM predictions under each test-time augmentation.

    Predictions live in two dicts (``gru_cache`` / ``lstm_cache``) keyed by
    ``"{session}_{trial}_{aug_type}"``. They are persisted with ``pickle`` as
    ``gru_predictions.pkl`` and ``lstm_predictions.pkl`` inside ``cache_dir``.

    NOTE: ``load_cache`` unpickles whatever is in ``cache_dir``; only point it
    at a directory you trust.
    """

    _MODEL_TYPES = ('gru', 'lstm')

    def __init__(self, cache_dir='./tta_cache'):
        """Create ``cache_dir`` if needed and start with empty caches."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.gru_cache = {}
        self.lstm_cache = {}
        self.augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    def _get_cache_key(self, session, trial, aug_type):
        """Build the dict key for one (session, trial, augmentation) entry."""
        return f"{session}_{trial}_{aug_type}"

    def _get_cache_file(self, model_type):
        """Return the pickle path used to persist ``model_type``'s cache."""
        return os.path.join(self.cache_dir, f'{model_type}_predictions.pkl')

    def _select_cache(self, model_type):
        """Return the dict for ``model_type``.

        Raises:
            ValueError: if ``model_type`` is neither ``'gru'`` nor ``'lstm'``.
        """
        if model_type == 'gru':
            return self.gru_cache
        if model_type == 'lstm':
            return self.lstm_cache
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'gru' or 'lstm'")

    def save_cache(self):
        """Write both caches to disk.

        Each file is first written to a ``.tmp`` sibling and then atomically
        renamed over the target. An interrupted save therefore never leaves a
        truncated cache file behind.
        """
        for model_type in self._MODEL_TYPES:
            path = self._get_cache_file(model_type)
            tmp_path = path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(self._select_cache(model_type), f)
            os.replace(tmp_path, path)

    def load_cache(self):
        """Load both caches from disk, GRU first.

        Returns:
            True if both files were loaded. False as soon as one is missing or
            unreadable (truncated/corrupt). In that case the caches loaded
            before the failure are kept, so only the rest needs regenerating.
        """
        for model_type in self._MODEL_TYPES:
            path = self._get_cache_file(model_type)
            try:
                with open(path, 'rb') as f:
                    loaded = pickle.load(f)
            except FileNotFoundError:
                return False
            except (EOFError, pickle.UnpicklingError) as err:
                # Treat a damaged cache like a missing one so it is regenerated.
                print(f"Ignoring unreadable cache file {path}: {err}")
                return False
            if model_type == 'gru':
                self.gru_cache = loaded
            else:
                self.lstm_cache = loaded
        return True

    def add_prediction(self, model_type, session, trial, aug_type, prediction):
        """Store ``prediction`` for one model/session/trial/augmentation."""
        self._select_cache(model_type)[self._get_cache_key(session, trial, aug_type)] = prediction

    def get_prediction(self, model_type, session, trial, aug_type):
        """Return the cached prediction, or None if it has not been stored."""
        return self._select_cache(model_type).get(self._get_cache_key(session, trial, aug_type), None)

    def is_complete(self, sessions, trials_per_session):
        """Return True if both models have an entry for every session/trial/augmentation.

        Args:
            sessions: Iterable of session names.
            trials_per_session: Mapping session -> number of trials
                (trial indices ``0 .. n-1`` are checked).
        """
        return all(
            self.get_prediction(model_type, session, trial, aug_type) is not None
            for session in sessions
            for trial in range(trials_per_session[session])
            for aug_type in self.augmentation_types
            for model_type in self._MODEL_TYPES
        )
|
||
|
||
class TTAEGeneticOptimizer:
    """Tune the TTA-E ensemble parameters with a genetic algorithm (PyGAD).

    A candidate solution has six genes:

    * ``solution[0]``: ``gru_weight`` in [0, 1]; the LSTM weight is
      ``1 - gru_weight``.
    * ``solution[1:6]``: weights in [0, 5] for the ``original``, ``noise``,
      ``scale``, ``shift`` and ``smooth`` test-time augmentations.

    Both models' predictions for every (session, trial, augmentation) are
    computed once and cached by ``generate_all_predictions``. Each fitness
    evaluation then only re-fuses cached probabilities and measures the
    phoneme error rate (PER) on the validation split.
    """

    def __init__(self,
                 gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                 lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                 data_dir='../data/hdf5_data_final',
                 csv_path='../data/t15_copyTaskData_description.csv',
                 gpu_number=0):
        """Select the device, then load both models and the validation data.

        Args:
            gru_model_path: Directory holding ``checkpoint/args.yaml`` and
                ``checkpoint/best_checkpoint`` for the GRU decoder.
            lstm_model_path: Same layout, for the LSTM decoder.
            data_dir: Root directory with one sub-directory per session. A
                session is used when it contains ``data_val.hdf5``.
            csv_path: CSV read with pandas and passed to ``load_h5py_file``.
            gpu_number: CUDA device index. A negative value, or no CUDA,
                selects the CPU.
        """
        self.gru_model_path = gru_model_path
        self.lstm_model_path = lstm_model_path
        self.data_dir = data_dir
        self.csv_path = csv_path
        self.gpu_number = gpu_number

        if torch.cuda.is_available() and gpu_number >= 0:
            self.device = torch.device(f'cuda:{gpu_number}')
            print(f'Using {self.device} for model inference.')
        else:
            self.device = torch.device('cpu')
            print('Using CPU for model inference.')

        self.cache = TTAEnsembleCache()

        # Test-time augmentation strengths.
        self.tta_noise_std = 0.01      # base std of the additive Gaussian noise
        self.tta_smooth_range = 0.5    # +/- range added to the smoothing kernel std
        self.tta_scale_range = 0.05    # +/- relative amplitude scaling
        self.tta_cut_max = 3           # maximum circular shift along dim 1

        # Genetic-algorithm settings.
        self.population_size = 20
        self.num_generations = 20
        self.num_parents_mating = 5
        self.mutation_percent_genes = 20

        self._load_models()
        self._load_data()

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _strip_state_dict_prefixes(state_dict):
        """Remove ``module.`` and ``_orig_mod.`` from checkpoint keys, in place.

        Both substrings are stripped in a single rename, so each key is popped
        exactly once. Every key is re-inserted, which preserves the key order.

        Returns:
            The same ``state_dict`` object, for convenience.
        """
        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "").replace("_orig_mod.", "")
            state_dict[new_key] = state_dict.pop(key)
        return state_dict

    def _load_decoder(self, decoder_cls, model_path, model_args):
        """Build ``decoder_cls`` from ``model_args`` and load its best checkpoint.

        The model is moved to ``self.device`` and put in eval mode.
        ``weights_only=False`` unpickles arbitrary objects, so checkpoints must
        come from a trusted source.
        """
        model = decoder_cls(
            neural_dim=model_args['model']['n_input_features'],
            n_units=model_args['model']['n_units'],
            n_days=len(model_args['dataset']['sessions']),
            n_classes=model_args['dataset']['n_classes'],
            rnn_dropout=model_args['model']['rnn_dropout'],
            input_dropout=model_args['model']['input_network']['input_layer_dropout'],
            n_layers=model_args['model']['n_layers'],
            patch_size=model_args['model']['patch_size'],
            patch_stride=model_args['model']['patch_stride'],
        )
        checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=self.device)
        model.load_state_dict(self._strip_state_dict_prefixes(checkpoint['model_state_dict']))
        model.to(self.device)
        model.eval()
        return model

    def _load_models(self):
        """Load the args and weights of the GRU and LSTM decoders."""
        print("Loading models...")

        self.gru_model_args = OmegaConf.load(os.path.join(self.gru_model_path, 'checkpoint/args.yaml'))
        self.lstm_model_args = OmegaConf.load(os.path.join(self.lstm_model_path, 'checkpoint/args.yaml'))

        self.gru_model = self._load_decoder(GRUDecoder, self.gru_model_path, self.gru_model_args)
        self.lstm_model = self._load_decoder(LSTMDecoder, self.lstm_model_path, self.lstm_model_args)

        print("Models loaded successfully!")

    def _load_data(self):
        """Load ``data_val.hdf5`` for every GRU training session that has one.

        Fills ``self.test_data`` (session -> loaded data) and
        ``self.trials_per_session`` (session -> number of trials).
        """
        print("Loading validation data...")

        b2txt_csv_df = pd.read_csv(self.csv_path)

        self.test_data = {}
        self.trials_per_session = {}
        total_test_trials = 0

        for session in self.gru_model_args['dataset']['sessions']:
            eval_file = os.path.join(self.data_dir, session, 'data_val.hdf5')
            # Sessions without a validation file (or missing on disk) are skipped.
            if not os.path.isfile(eval_file):
                continue
            data = load_h5py_file(eval_file, b2txt_csv_df)
            n_trials = len(data["neural_features"])
            self.test_data[session] = data
            self.trials_per_session[session] = n_trials
            total_test_trials += n_trials
            print(f'Loaded {n_trials} validation trials for session {session}.')

        print(f'Total number of validation trials: {total_test_trials}')

    # ---------------------------------------------------------------- inference

    def _apply_augmentation(self, x, aug_type):
        """Return an augmented copy of ``x``; the input tensor is not modified.

        ``'noise'`` adds Gaussian noise and ``'scale'`` multiplies by a random
        factor. ``'shift'`` rotates circularly along dim 1 (assumed to be
        time; TODO confirm). ``'original'``, ``'smooth'`` (whose variation is
        applied through the smoothing kernel instead) and unknown types return
        a plain copy.
        """
        x_augmented = x.clone()

        if aug_type == 'noise':
            noise_scale = self.tta_noise_std * (0.5 + 0.5 * np.random.rand())
            x_augmented = x_augmented + torch.randn_like(x_augmented) * noise_scale
        elif aug_type == 'scale':
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * self.tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and self.tta_cut_max > 0:
            max_shift = min(self.tta_cut_max + 1, x_augmented.shape[1] // 8)
            # randint(1, high) needs high > 1; very short trials stay unshifted.
            if max_shift > 1:
                shift_amount = np.random.randint(1, max_shift)
                x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                         x_augmented[:, :shift_amount, :]], dim=1)

        return x_augmented

    def _smoothing_std(self, aug_type):
        """Gaussian-smoothing std for ``aug_type``.

        Returns the configured std, or a randomly perturbed one (floored at
        0.3) for ``'smooth'``.
        """
        default_smooth_std = self.gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
        if aug_type != 'smooth':
            return default_smooth_std
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range
        return max(0.3, default_smooth_std + smooth_variation)

    def _get_model_prediction(self, model, model_args, x_smoothed, input_layer):
        """Run ``model`` once and return float32 softmax probabilities over classes.

        ``model_args`` is unused; it is kept for call compatibility.
        ``input_layer`` is the day index passed to the model as ``day_idx``.
        """
        with torch.no_grad():
            logits, _ = model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=self.device),
                states=None,
                return_state=True,
            )
            # .float() so that .numpy() later works even for bf16 logits.
            return torch.softmax(logits.float(), dim=-1)

    def generate_all_predictions(self):
        """Compute and cache both models' probabilities for every trial and augmentation.

        An existing on-disk cache is loaded first. If it is complete nothing is
        recomputed; otherwise only the missing entries are generated. The cache
        is then written back to disk.
        """
        print("Generating all TTA predictions for caching...")

        if self.cache.load_cache():
            if self.cache.is_complete(self.test_data.keys(), self.trials_per_session):
                print("Complete cache found, skipping prediction generation.")
                return
            print("Incomplete cache found, generating missing predictions...")

        aug_types = self.cache.augmentation_types
        # NOTE: both models share the GRU's smoothing and AMP configuration.
        smooth_kernel_size = self.gru_model_args['dataset']['data_transforms']['smooth_kernel_size']
        use_amp = self.gru_model_args['use_amp']
        total_predictions = sum(self.trials_per_session.values()) * len(aug_types) * 2  # 2 models

        with tqdm(total=total_predictions, desc='Generating cached predictions', unit='pred') as pbar:
            for session, data in self.test_data.items():
                # Each model picks its day-specific input layer from its own session list.
                gru_day = self.gru_model_args['dataset']['sessions'].index(session)
                lstm_day = self.lstm_model_args['dataset']['sessions'].index(session)

                for trial, features in enumerate(data['neural_features']):
                    neural_input = torch.tensor(np.expand_dims(features, axis=0),
                                                device=self.device, dtype=torch.bfloat16)

                    for aug_type in aug_types:
                        need_gru = self.cache.get_prediction('gru', session, trial, aug_type) is None
                        need_lstm = self.cache.get_prediction('lstm', session, trial, aug_type) is None

                        if need_gru or need_lstm:
                            x_augmented = self._apply_augmentation(neural_input, aug_type)
                            smooth_std = self._smoothing_std(aug_type)

                            # Smoothing and inference share one autocast region. The input
                            # is bfloat16; model parameters are presumably float32.
                            with torch.autocast(device_type=self.device.type, enabled=use_amp,
                                                dtype=torch.bfloat16):
                                x_smoothed = gauss_smooth(
                                    inputs=x_augmented,
                                    device=self.device,
                                    smooth_kernel_std=smooth_std,
                                    smooth_kernel_size=smooth_kernel_size,
                                    padding='valid',
                                )
                                if need_gru:
                                    gru_probs = self._get_model_prediction(
                                        self.gru_model, self.gru_model_args, x_smoothed, gru_day)
                                    self.cache.add_prediction('gru', session, trial, aug_type,
                                                              gru_probs.cpu().numpy())
                                if need_lstm:
                                    lstm_probs = self._get_model_prediction(
                                        self.lstm_model, self.lstm_model_args, x_smoothed, lstm_day)
                                    self.cache.add_prediction('lstm', session, trial, aug_type,
                                                              lstm_probs.cpu().numpy())

                        # One tick per model, whether computed now or already cached.
                        pbar.update(2)

        print("Saving cache to disk...")
        self.cache.save_cache()
        print("Cache generation completed!")

    # --------------------------------------------------------------- evaluation

    @staticmethod
    def _weighted_tta_average(probs_list, weights, length):
        """Weighted mean of ``(1, T, C)`` tensors, each truncated to ``length`` frames.

        ``weights`` must be positive; they are normalised to sum to 1.
        """
        weights_tensor = torch.tensor(weights, dtype=torch.float32)
        weights_tensor = weights_tensor / weights_tensor.sum()
        fused = torch.zeros_like(probs_list[0][:, :length, :])
        for probs, weight in zip(probs_list, weights_tensor):
            fused += weight * probs[:, :length, :]
        return fused

    @staticmethod
    def _ensemble_argmax(avg_gru_probs, avg_lstm_probs, gru_weight, lstm_weight, epsilon=1e-8):
        """Per-frame argmax of the weighted geometric mean of the two models.

        Works in log space. Renormalising the geometric mean would not change
        the per-frame argmax, so it is skipped. Returns a list of class ids
        for batch element 0.
        """
        log_ensemble = (float(gru_weight) * torch.log(avg_gru_probs + epsilon) +
                        float(lstm_weight) * torch.log(avg_lstm_probs + epsilon))
        return np.argmax(log_ensemble[0].numpy(), axis=-1).tolist()

    @staticmethod
    def _ctc_greedy_decode(frame_ids, blank=0):
        """Collapse consecutive repeats, then drop blanks (standard CTC greedy decoding).

        Repeats are collapsed *before* blanks are removed. A label repeated
        across a blank (``a, blank, a``) therefore decodes to two labels.
        """
        decoded = []
        previous = None
        for token in frame_ids:
            token = int(token)
            if token != previous and token != blank:
                decoded.append(token)
            previous = token
        return decoded

    def evaluate_parameters(self, gru_weight, tta_weights):
        """Return the validation PER (in percent) for one parameter setting.

        Args:
            gru_weight: Weight of the GRU in the geometric-mean ensemble. The
                LSTM gets ``1 - gru_weight``.
            tta_weights: Five weights, ordered like
                ``self.cache.augmentation_types``. Augmentations with weight
                <= 0 are ignored.

        Returns:
            ``100 * total_edit_distance / total_reference_length``, or 100.0 if
            no trial could be scored.
        """
        lstm_weight = 1.0 - gru_weight
        active = [(aug_type, weight)
                  for aug_type, weight in zip(self.cache.augmentation_types, tta_weights)
                  if weight > 0]

        total_true_length = 0
        total_edit_distance = 0

        for session, data in self.test_data.items():
            for trial in range(len(data['neural_features'])):
                gru_list, lstm_list, weights = [], [], []
                for aug_type, weight in active:
                    gru_probs = self.cache.get_prediction('gru', session, trial, aug_type)
                    lstm_probs = self.cache.get_prediction('lstm', session, trial, aug_type)
                    if gru_probs is None or lstm_probs is None:
                        continue
                    gru_list.append(torch.tensor(gru_probs))
                    lstm_list.append(torch.tensor(lstm_probs))
                    weights.append(weight)

                if not gru_list:
                    continue

                # Align all tensors (both models, all augmentations) on the time
                # axis. The two models may emit different lengths even when only
                # one augmentation is active.
                min_length = min(probs.shape[1] for probs in gru_list + lstm_list)
                avg_gru_probs = self._weighted_tta_average(gru_list, weights, min_length)
                avg_lstm_probs = self._weighted_tta_average(lstm_list, weights, min_length)

                frame_ids = self._ensemble_argmax(avg_gru_probs, avg_lstm_probs, gru_weight, lstm_weight)
                pred_seq = [LOGIT_TO_PHONEME[p] for p in self._ctc_greedy_decode(frame_ids)]

                true_ids = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
                true_seq = [LOGIT_TO_PHONEME[p] for p in true_ids]

                total_edit_distance += editdistance.eval(true_seq, pred_seq)
                total_true_length += len(true_seq)

        if total_true_length == 0:
            return 100.0  # maximum PER as a penalty

        return 100 * total_edit_distance / total_true_length

    def fitness_function(self, ga_instance, solution, solution_idx):
        """PyGAD fitness: ``-PER`` of ``solution``, or -100.0 on invalid input or error.

        It may be called concurrently from PyGAD's worker threads; it only
        reads the prediction cache.
        """
        gru_weight = float(solution[0])  # gene_space keeps this in [0, 1]
        tta_weights = np.maximum(np.asarray(solution[1:6], dtype=float), 0)

        # With every TTA weight at zero nothing would be fused.
        if np.sum(tta_weights) == 0:
            return -100.0

        try:
            per = self.evaluate_parameters(gru_weight, tta_weights)
        except Exception as e:
            # Best-effort: one failing candidate must not abort the whole run.
            print(f"Error in fitness evaluation: {e}")
            return -100.0
        # Fitness = -PER because PyGAD maximises fitness and we minimise PER.
        return -per

    def on_generation(self, ga_instance):
        """PyGAD per-generation callback: print the current best solution."""
        # Reuse the fitness PyGAD already computed instead of re-evaluating.
        solution, solution_fitness, _ = ga_instance.best_solution(
            pop_fitness=ga_instance.last_generation_fitness)
        print(f"Generation {ga_instance.generations_completed}")
        print(f"Best solution: GRU weight={solution[0]:.3f}, TTA weights={solution[1:6]}")
        print(f"Best fitness (negative PER): {solution_fitness:.3f}")
        print(f"Best PER: {-solution_fitness:.3f}%")
        print("-" * 50)

    def optimize(self):
        """Build the prediction cache, run the GA, and save the best parameters.

        Returns:
            A dict with the best weights, best PER, wall-clock optimisation
            time and GA settings. The same dict is pickled to
            ``ga_optimization_result_<timestamp>.pkl`` in the working directory.
        """
        print("Starting genetic algorithm optimization...")

        self.generate_all_predictions()

        aug_types = self.cache.augmentation_types
        # gru_weight in [0, 1]; each TTA weight in [0, 5].
        gene_space = [{'low': 0.0, 'high': 1.0}]
        gene_space += [{'low': 0.0, 'high': 5.0} for _ in aug_types]

        ga_instance = pygad.GA(
            num_generations=self.num_generations,
            num_parents_mating=self.num_parents_mating,
            fitness_func=self.fitness_function,
            sol_per_pop=self.population_size,
            num_genes=len(gene_space),  # 1 gru_weight + 5 tta_weights
            gene_space=gene_space,
            mutation_percent_genes=self.mutation_percent_genes,
            parent_selection_type="sss",  # steady-state selection
            keep_parents=2,  # must be <= num_parents_mating (5)
            crossover_type="single_point",
            mutation_type="random",
            on_generation=self.on_generation,
            # Threads share the in-memory prediction cache; at least one worker.
            parallel_processing=['thread', max(1, mp.cpu_count() // 2)],
            save_solutions=True,
        )

        print(f"Running optimization with {self.population_size} population size for {self.num_generations} generations...")
        start_time = time.time()
        ga_instance.run()
        end_time = time.time()

        solution, solution_fitness, _ = ga_instance.best_solution(
            pop_fitness=ga_instance.last_generation_fitness)

        print("\n" + "=" * 60)
        print("OPTIMIZATION COMPLETED!")
        print("=" * 60)
        print(f"Optimization time: {end_time - start_time:.2f} seconds")
        print(f"Best GRU weight: {solution[0]:.4f}")
        print(f"Best LSTM weight: {1.0 - solution[0]:.4f}")
        print("Best TTA weights:")
        for i, aug_type in enumerate(aug_types):
            print(f"  - {aug_type}: {solution[i + 1]:.4f}")
        print(f"Best PER: {-solution_fitness:.4f}%")

        result = {
            'gru_weight': float(solution[0]),
            'lstm_weight': float(1.0 - solution[0]),
            'tta_weights': {aug_type: float(solution[i + 1]) for i, aug_type in enumerate(aug_types)},
            'best_per': float(-solution_fitness),
            'optimization_time': end_time - start_time,
            'generations': self.num_generations,
            'population_size': self.population_size,
        }

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = f'ga_optimization_result_{timestamp}.pkl'
        with open(result_file, 'wb') as f:
            pickle.dump(result, f)

        print(f"Results saved to: {result_file}")
        return result
|
||
|
||
def main():
    """Script entry point: build the optimizer with the fixed paths below and run it.

    Returns:
        The result dict produced by ``TTAEGeneticOptimizer.optimize``.
    """
    print("TTA-E Genetic Algorithm Optimization")
    print("=" * 50)

    optimizer_settings = dict(
        gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        data_dir='../data/hdf5_data_final',
        csv_path='../data/t15_copyTaskData_description.csv',
        gpu_number=0,
    )
    outcome = TTAEGeneticOptimizer(**optimizer_settings).optimize()

    print("\nOptimization completed successfully!")
    return outcome
|
||
|
||
if __name__ == "__main__":
    # Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES the user already set.
    # This must run before CUDA is first initialised (torch/cupy do so lazily).
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

    result = main()
|