#!/usr/bin/env python3
"""
TTA-E parameter search script.

First runs every base augmentation configuration to collect model predictions,
then searches over parameter combinations using the cached outputs. This avoids
repeated model inference and makes the search far more efficient.
"""
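
# Example invocation (the script filename is illustrative; all paths default to
# the values defined in parse_arguments below):
#
#   python tta_parameter_search.py --eval_type val --gpu_number 0 \
#       --gru_weights 0.4,0.5,0.6 --tta_noise_weights 0.0,0.5,1.0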

import os
import sys
import argparse
import itertools
import json
import pickle
import time
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import pandas as pd
import torch
import editdistance
from omegaconf import OmegaConf
from tqdm import tqdm

# Add the repository root to sys.path so the package-style imports below resolve
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *

def parse_arguments():
    parser = argparse.ArgumentParser(description='TTA-E Parameter Search for optimal configuration.')
    parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                        help='Path to the pretrained GRU model directory.')
    parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                        help='Path to the pretrained LSTM model directory.')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                        help='Path to the dataset directory.')
    parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                        help='Path to the CSV file with metadata.')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU number to use for model inference.')
    parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                        help='Evaluation type.')

    # Search-space parameters
    parser.add_argument('--gru_weights', type=str, default='0.4,0.5,0.6,0.7,0.8,1.0',
                        help='Comma-separated GRU weights to search.')
    parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated noise weights to search.')
    parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated scale weights to search.')
    parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated shift weights to search.')
    parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.5,1.0',
                        help='Comma-separated smooth weights to search.')

    # Fixed TTA parameters
    parser.add_argument('--tta_noise_std', type=float, default=0.01,
                        help='Standard deviation for TTA noise augmentation.')
    parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                        help='Range for TTA smoothing kernel variation.')
    parser.add_argument('--tta_scale_range', type=float, default=0.05,
                        help='Range for TTA amplitude scaling.')
    parser.add_argument('--tta_cut_max', type=int, default=3,
                        help='Maximum timesteps for TTA shift.')

    # Output control
    parser.add_argument('--cache_file', type=str, default='tta_predictions_cache.pkl',
                        help='File to cache model predictions.')
    parser.add_argument('--results_file', type=str, default='parameter_search_results.json',
                        help='File to save search results.')
    parser.add_argument('--force_recache', action='store_true',
                        help='Force re-computation of predictions cache.')

    return parser.parse_args()

def generate_all_base_configs():
    """Generate all base configurations to run (each augmentation evaluated in isolation)."""
    base_configs = [
        {'name': 'original', 'tta_weights': {'original': 1.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'noise', 'tta_weights': {'original': 0.0, 'noise': 1.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'scale', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 1.0, 'shift': 0.0, 'smooth': 0.0}},
        {'name': 'shift', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 1.0, 'smooth': 0.0}},
        {'name': 'smooth', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 1.0}},
    ]
    return base_configs
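
# Each augmentation stream is run (and cached) exactly once per trial; arbitrary
# weighted mixtures of the five streams are then formed in ensemble_and_evaluate
# without touching the models again.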

def run_single_tta_prediction(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                              device, aug_type, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max):
    """Run prediction for a single TTA augmentation."""
    x_augmented = x.clone()

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    if aug_type == 'original':
        pass
    elif aug_type == 'noise':
        noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
        noise = torch.randn_like(x_augmented) * noise_scale
        x_augmented = x_augmented + noise
    elif aug_type == 'scale':
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
        x_augmented = x_augmented * scale_factor
    elif aug_type == 'shift' and tta_cut_max > 0:
        # Guard the upper bound so np.random.randint never sees high <= low on
        # very short sequences.
        high = max(2, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
        shift_amount = np.random.randint(1, high)
        x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                 x_augmented[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
        varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)

    # Use autocast for efficiency (device.type keeps this valid on CPU as well)
    with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):

        # Apply Gaussian smoothing
        if aug_type == 'smooth':
            x_smoothed = gauss_smooth(
                inputs=x_augmented,
                device=device,
                smooth_kernel_std=varied_smooth_std,
                smooth_kernel_size=default_smooth_size,
                padding='valid',
            )
        else:
            x_smoothed = gauss_smooth(
                inputs=x_augmented,
                device=device,
                smooth_kernel_std=default_smooth_std,
                smooth_kernel_size=default_smooth_size,
                padding='valid',
            )

        with torch.no_grad():
            # Get GRU logits and convert to probabilities
            gru_logits, _ = gru_model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=device),
                states=None,
                return_state=True,
            )
            gru_probs = torch.softmax(gru_logits, dim=-1)

            # Get LSTM logits and convert to probabilities
            lstm_logits, _ = lstm_model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=device),
                states=None,
                return_state=True,
            )
            lstm_probs = torch.softmax(lstm_logits, dim=-1)

    return gru_probs.float().cpu().numpy(), lstm_probs.float().cpu().numpy()
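
# Note: the noise/scale/shift/smooth branches above each draw fresh random
# parameters, so the cached predictions reflect a single random draw per trial
# rather than an average over draws.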

def cache_model_predictions(args):
    """Cache the predictions from every base model configuration."""
    print("=== Stage 1: caching model predictions ===")

    # Set up the device
    if torch.cuda.is_available() and args.gpu_number >= 0:
        device = torch.device(f'cuda:{args.gpu_number}')
    else:
        device = torch.device('cpu')
    print(f'Using device: {device}')

    # Load the models
    print("Loading models...")
    gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
    lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))

    # Define GRU model
    gru_model = GRUDecoder(
        neural_dim=gru_model_args['model']['n_input_features'],
        n_units=gru_model_args['model']['n_units'],
        n_days=len(gru_model_args['dataset']['sessions']),
        n_classes=gru_model_args['dataset']['n_classes'],
        rnn_dropout=gru_model_args['model']['rnn_dropout'],
        input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=gru_model_args['model']['n_layers'],
        patch_size=gru_model_args['model']['patch_size'],
        patch_stride=gru_model_args['model']['patch_stride'],
    )

    # Load GRU model weights, stripping DataParallel ("module.") and
    # torch.compile ("_orig_mod.") prefixes in a single pass so no key is
    # popped twice.
    gru_checkpoint = torch.load(os.path.join(args.gru_model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=device)
    for key in list(gru_checkpoint['model_state_dict'].keys()):
        new_key = key.replace("module.", "").replace("_orig_mod.", "")
        gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
    gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

    # Define LSTM model
    lstm_model = LSTMDecoder(
        neural_dim=lstm_model_args['model']['n_input_features'],
        n_units=lstm_model_args['model']['n_units'],
        n_days=len(lstm_model_args['dataset']['sessions']),
        n_classes=lstm_model_args['dataset']['n_classes'],
        rnn_dropout=lstm_model_args['model']['rnn_dropout'],
        input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
        n_layers=lstm_model_args['model']['n_layers'],
        patch_size=lstm_model_args['model']['patch_size'],
        patch_stride=lstm_model_args['model']['patch_stride'],
    )

    # Load LSTM model weights (same prefix stripping as above)
    lstm_checkpoint = torch.load(os.path.join(args.lstm_model_path, 'checkpoint/best_checkpoint'),
                                 weights_only=False, map_location=device)
    for key in list(lstm_checkpoint['model_state_dict'].keys()):
        new_key = key.replace("module.", "").replace("_orig_mod.", "")
        lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
    lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

    gru_model.to(device)
    lstm_model.to(device)
    gru_model.eval()
    lstm_model.eval()

    # Load the data
    print("Loading data...")
    b2txt_csv_df = pd.read_csv(args.csv_path)
    test_data = {}
    total_trials = 0

    for session in gru_model_args['dataset']['sessions']:
        files = [f for f in os.listdir(os.path.join(args.data_dir, session)) if f.endswith('.hdf5')]
        if f'data_{args.eval_type}.hdf5' in files:
            eval_file = os.path.join(args.data_dir, session, f'data_{args.eval_type}.hdf5')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_trials += len(test_data[session]["neural_features"])
            print(f'Loaded {len(test_data[session]["neural_features"])} {args.eval_type} trials for session {session}.')

    print(f'Total trials: {total_trials}')

    # Generate all base predictions
    base_configs = generate_all_base_configs()
    predictions_cache = {}

    for config in base_configs:
        config_name = config['name']
        print(f"\nGenerating predictions for: {config_name}")

        predictions_cache[config_name] = {}

        with tqdm(total=total_trials, desc=f'Processing {config_name}', unit='trial') as pbar:
            for session, data in test_data.items():
                input_layer = gru_model_args['dataset']['sessions'].index(session)
                session_predictions = []

                for trial in range(len(data['neural_features'])):
                    neural_input = data['neural_features'][trial]
                    neural_input = np.expand_dims(neural_input, axis=0)
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

                    gru_probs, lstm_probs = run_single_tta_prediction(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device, config_name,
                        args.tta_noise_std, args.tta_smooth_range,
                        args.tta_scale_range, args.tta_cut_max
                    )

                    session_predictions.append({
                        'gru_probs': gru_probs,
                        'lstm_probs': lstm_probs,
                        'trial_info': {
                            'session': session,
                            'block_num': data['block_num'][trial],
                            'trial_num': data['trial_num'][trial],
                            'seq_class_ids': data['seq_class_ids'][trial] if args.eval_type == 'val' else None,
                            'seq_len': data['seq_len'][trial] if args.eval_type == 'val' else None,
                            'sentence_label': data['sentence_label'][trial] if args.eval_type == 'val' else None,
                        }
                    })

                    pbar.update(1)

                predictions_cache[config_name][session] = session_predictions

    # Save the cache
    print(f"\nSaving predictions cache to {args.cache_file}...")
    with open(args.cache_file, 'wb') as f:
        pickle.dump(predictions_cache, f)

    print("✓ Predictions cache saved successfully!")
    return predictions_cache
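
# Cache layout consumed below (shapes assumed from the single-trial batching in
# cache_model_predictions: batch of 1, T time steps after smoothing/patching,
# C phoneme classes):
#   predictions_cache[aug_name][session] = [
#       {'gru_probs': ndarray (1, T, C), 'lstm_probs': ndarray (1, T, C),
#        'trial_info': {...}},
#       ...
#   ]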

def ensemble_and_evaluate(predictions_cache, gru_weight, tta_weights, eval_type='val'):
    """Ensemble the cached predictions and evaluate the configuration."""
    lstm_weight = 1.0 - gru_weight
    epsilon = 1e-8

    total_trials = 0
    total_edit_distance = 0
    total_true_length = 0

    # Check which augmentations are enabled
    enabled_augmentations = [aug_type for aug_type, weight in tta_weights.items() if weight > 0]
    if len(enabled_augmentations) == 0:
        return float('inf')  # invalid configuration

    # Normalize the TTA weights
    total_tta_weight = sum(weight for weight in tta_weights.values() if weight > 0)
    normalized_tta_weights = {k: v / total_tta_weight for k, v in tta_weights.items() if v > 0}

    for session in predictions_cache['original'].keys():
        session_predictions = predictions_cache['original'][session]

        for trial_idx in range(len(session_predictions)):
            trial_info = session_predictions[trial_idx]['trial_info']

            if eval_type == 'val' and trial_info['seq_class_ids'] is None:
                continue

            # Accumulate the probabilities of all enabled augmentations
            weighted_gru_probs = None
            weighted_lstm_probs = None

            for aug_type in enabled_augmentations:
                weight = normalized_tta_weights[aug_type]
                gru_probs = predictions_cache[aug_type][session][trial_idx]['gru_probs']
                lstm_probs = predictions_cache[aug_type][session][trial_idx]['lstm_probs']

                gru_probs = torch.tensor(gru_probs)
                lstm_probs = torch.tensor(lstm_probs)

                if weighted_gru_probs is None:
                    weighted_gru_probs = weight * gru_probs
                    weighted_lstm_probs = weight * lstm_probs
                else:
                    # Handle streams whose sequence lengths differ (e.g. after shift)
                    min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
                    weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
                    weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]
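
            # Illustrative numbers for the geometric mean below (not from the data):
            # with gru = [0.7, 0.3], lstm = [0.5, 0.5] and gru_weight = 0.6,
            #   exp(0.6*log(0.7) + 0.4*log(0.5)) ≈ 0.612
            #   exp(0.6*log(0.3) + 0.4*log(0.5)) ≈ 0.368
            # which renormalizes to roughly [0.62, 0.38].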
            # Ensemble GRU and LSTM (weighted geometric mean, computed in log space)
            weighted_gru_probs = weighted_gru_probs + epsilon
            weighted_lstm_probs = weighted_lstm_probs + epsilon

            log_ensemble_probs = (gru_weight * torch.log(weighted_gru_probs) +
                                  lstm_weight * torch.log(weighted_lstm_probs))
            ensemble_probs = torch.exp(log_ensemble_probs)
            ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

            # Decode the predicted sequence: drop blanks (class 0), then collapse repeats
            pred_seq = torch.argmax(ensemble_probs[0], dim=-1).numpy()
            pred_seq = [int(p) for p in pred_seq if p != 0]
            pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
            pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]

            if eval_type == 'val':
                # Compute the phoneme error rate (PER)
                true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
                true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

                ed = editdistance.eval(true_phonemes, pred_phonemes)
                total_edit_distance += ed
                total_true_length += len(true_phonemes)

            total_trials += 1

    if eval_type == 'val' and total_true_length > 0:
        per = 100 * total_edit_distance / total_true_length
        return per
    else:
        return 0.0  # test mode returns 0

def search_optimal_parameters(predictions_cache, args):
    """Search for the optimal parameter combination."""
    print("\n=== Stage 2: parameter search ===")

    # Parse the search space
    gru_weights = [float(x) for x in args.gru_weights.split(',')]
    noise_weights = [float(x) for x in args.tta_noise_weights.split(',')]
    scale_weights = [float(x) for x in args.tta_scale_weights.split(',')]
    shift_weights = [float(x) for x in args.tta_shift_weights.split(',')]
    smooth_weights = [float(x) for x in args.tta_smooth_weights.split(',')]

    print("Search space:")
    print(f"  GRU weights: {gru_weights}")
    print(f"  Noise weights: {noise_weights}")
    print(f"  Scale weights: {scale_weights}")
    print(f"  Shift weights: {shift_weights}")
    print(f"  Smooth weights: {smooth_weights}")

    # Generate every parameter combination (with the default argument lists this
    # is a 6 * 3**4 = 486-point grid)
    all_combinations = list(itertools.product(gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights))
    total_combinations = len(all_combinations)
    print(f"Total combinations to evaluate: {total_combinations}")

    best_per = float('inf')
    best_config = None
    results = []

    with tqdm(total=total_combinations, desc='Parameter search', unit='config') as pbar:
        for gru_w, noise_w, scale_w, shift_w, smooth_w in all_combinations:
            tta_weights = {
                'original': 1.0,  # always include the original data
                'noise': noise_w,
                'scale': scale_w,
                'shift': shift_w,
                'smooth': smooth_w
            }

            per = ensemble_and_evaluate(predictions_cache, gru_w, tta_weights, args.eval_type)

            config = {
                'gru_weight': gru_w,
                'lstm_weight': 1.0 - gru_w,
                'tta_weights': tta_weights,
                'per': per
            }
            results.append(config)

            if per < best_per:
                best_per = per
                best_config = config
                print(f"\n🎯 New best PER: {per:.3f}%")
                print(f"  GRU weight: {gru_w:.1f}")
                print(f"  TTA weights: {tta_weights}")

            pbar.update(1)

    return results, best_config

def main():
    args = parse_arguments()

    print("TTA-E Parameter Search")
    print("=" * 50)

    # Stage 1: cache predictions (or load an existing cache)
    if args.force_recache or not os.path.exists(args.cache_file):
        predictions_cache = cache_model_predictions(args)
    else:
        print(f"Loading existing predictions cache from {args.cache_file}...")
        with open(args.cache_file, 'rb') as f:
            predictions_cache = pickle.load(f)
        print("✓ Cache loaded successfully!")

    # Stage 2: parameter search
    results, best_config = search_optimal_parameters(predictions_cache, args)

    # Report the best configuration
    print("\n=== Search complete ===")
    print("Best configuration:")
    print(f"  PER: {best_config['per']:.3f}%")
    print(f"  GRU weight: {best_config['gru_weight']:.1f}")
    print(f"  LSTM weight: {best_config['lstm_weight']:.1f}")
    print(f"  TTA weights: {best_config['tta_weights']}")

    # Save all results
    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(args.results_file, 'w') as f:
        json.dump(search_results, f, indent=2)

    print(f"\n✓ Results saved to {args.results_file}")

    # Show the top 10 configurations
    sorted_results = sorted(results, key=lambda x: x['per'])
    print("\nTop 10 configurations:")
    for i, config in enumerate(sorted_results[:10]):
        print(f"{i+1:2d}. PER={config['per']:6.3f}% | GRU={config['gru_weight']:.1f} | TTA={config['tta_weights']}")


if __name__ == "__main__":
    main()