Files
b2txt25/TTA-E/staged_search.py
2025-10-06 15:17:44 +08:00

294 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
分阶段TTA-E参数搜索
先粗搜索找到有希望的区域,再精细搜索
"""
import os
import sys
import argparse
import json
import numpy as np
from itertools import product
import subprocess
import time
def parse_arguments():
    """Build the command-line parser and return the parsed arguments.

    The options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence as the table.

    Returns:
        argparse.Namespace: values parsed from ``sys.argv``.
    """
    option_table = [
        # Basic evaluation settings, forwarded to the evaluation script.
        ('--base_script', dict(type=str, default='evaluate_model.py')),
        ('--data_dir', dict(type=str, default='../data/hdf5_data_final')),
        ('--eval_type', dict(type=str, default='val')),
        ('--gpu_number', dict(type=int, default=0)),
        # Which search stage(s) to run and where results are stored.
        ('--stage', dict(type=str, default='coarse',
                         choices=['coarse', 'fine', 'both'],
                         help='搜索阶段coarse=粗搜索fine=精细搜索both=两阶段')),
        ('--coarse_results', dict(type=str, default='coarse_results.json',
                                  help='粗搜索结果文件(用于精细搜索阶段)')),
        ('--final_results', dict(type=str, default='final_results.json',
                                 help='最终结果文件')),
        # Coarse-search grids (the default GRU grid uses a 0.2 step).
        ('--coarse_gru_weights', dict(type=str, default='0.2,0.4,0.6,0.8,1.0')),
        ('--coarse_tta_weights', dict(type=str, default='0.0,0.5,1.0')),
        # Fine-search neighbourhood around the best coarse configurations.
        ('--fine_range', dict(type=float, default=0.3,
                              help='精细搜索范围(围绕最佳配置的±范围)')),
        ('--fine_step', dict(type=float, default=0.1,
                             help='精细搜索步长')),
        # How many coarse configurations are refined.
        ('--top_k', dict(type=int, default=5,
                         help='选择前K个最佳配置进行精细搜索')),
    ]

    cli = argparse.ArgumentParser(description='分阶段TTA-E参数搜索')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def generate_coarse_search_space(args):
    """Enumerate every configuration of the coarse grid.

    Args:
        args: namespace providing ``coarse_gru_weights`` and
            ``coarse_tta_weights`` as comma-separated strings of floats.

    Returns:
        list[tuple]: ``(gru, 1.0, noise, scale, shift, smooth)`` tuples. The
        second element (weight of the original input) is fixed at 1.0; the
        four TTA weights each range over the same TTA grid. The GRU weight
        varies slowest and the smooth weight fastest.
    """
    def _as_floats(csv_text):
        return [float(token.strip()) for token in csv_text.split(',')]

    gru_grid = _as_floats(args.coarse_gru_weights)
    tta_grid = _as_floats(args.coarse_tta_weights)

    return [
        (gru_w, 1.0, noise_w, scale_w, shift_w, smooth_w)
        for gru_w, noise_w, scale_w, shift_w, smooth_w
        in product(gru_grid, tta_grid, tta_grid, tta_grid, tta_grid)
    ]
def _step_decimals(step):
    """Return the smallest number of decimals (1..9) that represents *step*."""
    for digits in range(1, 10):
        if round(step, digits) == step:
            return digits
    return 9


def _axis_grid(center, lower, upper, radius, step, decimals):
    """Return grid points around *center*, clamped to ``[lower, upper]``.

    The number of points is computed as an integer first, instead of relying
    on a floating-point ``stop + step`` end value, so accumulated float error
    can never produce an extra point beyond the clamped upper bound.
    """
    start = max(lower, center - radius)
    stop = min(upper, center + radius)
    if stop < start:
        return []
    # The small epsilon absorbs float error such as 0.3 / 0.1 == 2.9999999999999996.
    count = int((stop - start) / step + 1e-9) + 1
    return [round(min(start + i * step, stop), decimals) for i in range(count)]


def generate_fine_search_space(best_configs, args):
    """Build the fine search space around the best coarse configurations.

    For each configuration, and for each TTA weight in turn, the GRU weight
    and that single TTA weight are varied over ``+/- args.fine_range`` in
    steps of ``args.fine_step`` while the other three TTA weights stay at
    their base value. The GRU weight is clamped to [0.1, 1.0] and TTA weights
    to [0.0, 1.0].

    Args:
        best_configs: iterable of result dicts holding ``'gru_weight'`` and
            ``'tta_weights'`` (a dict with keys noise/scale/shift/smooth).
        args: namespace providing ``fine_range`` and ``fine_step``.

    Returns:
        list[tuple]: unique ``(gru, 1.0, noise, scale, shift, smooth)``
        tuples of plain Python floats, in first-seen order.

    Raises:
        ValueError: if ``args.fine_step`` is not a positive number.
    """
    step = float(args.fine_step)
    radius = float(args.fine_range)
    if not step > 0:
        raise ValueError(f"fine_step must be positive, got {args.fine_step!r}")

    # Round to the precision of the step (1 decimal for the default 0.1) so
    # that finer steps such as 0.05 are not collapsed onto a 0.1 grid.
    decimals = _step_decimals(step)
    tta_names = ('noise', 'scale', 'shift', 'smooth')

    fine_search_space = []
    seen = set()  # O(1) duplicate detection; the list keeps insertion order
    for config in best_configs:
        tta_base = config['tta_weights']
        fixed = {name: round(float(tta_base[name]), decimals) for name in tta_names}
        gru_grid = _axis_grid(float(config['gru_weight']), 0.1, 1.0,
                              radius, step, decimals)

        for varied in tta_names:
            varied_grid = _axis_grid(float(tta_base[varied]), 0.0, 1.0,
                                     radius, step, decimals)
            # Neighbourhood: GRU weight x the one TTA weight being varied.
            for gru_fine in gru_grid:
                for value in varied_grid:
                    weights = dict(fixed)
                    weights[varied] = value
                    candidate = (gru_fine, 1.0) + tuple(weights[n] for n in tta_names)
                    if candidate not in seen:
                        seen.add(candidate)
                        fine_search_space.append(candidate)
    return fine_search_space
def _parse_per(stdout):
    """Return the PER value printed by the evaluation script, or None.

    Looks for the first line containing the aggregate-PER marker and parses
    the text after the last colon, with any ``%`` sign removed.

    Raises:
        ValueError: if the marker line is present but the value is not a float.
    """
    marker = 'Aggregate Phoneme Error Rate (PER):'
    for line in stdout.split('\n'):
        if marker in line:
            return float(line.split(':')[-1].strip().replace('%', ''))
    return None


def run_evaluation(config, args):
    """Evaluate one configuration by running the evaluation script.

    Args:
        config: ``(gru, original, noise, scale, shift, smooth)`` weights.
        args: namespace providing ``base_script``, ``data_dir``,
            ``eval_type`` and ``gpu_number``.

    Returns:
        dict: always contains ``config``, ``gru_weight``, ``tta_weights``,
        ``per`` and ``success``. ``per`` is ``float('inf')`` when the run
        timed out, raised, or printed no parsable PER; in those cases an
        ``error`` and/or ``stderr`` entry describes the failure. ``stdout``
        holds the first 1000 characters of output when the process ran.
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    # Every outcome shares this schema so callers can read the weights of
    # failed runs as well as successful ones.
    record = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        },
        'per': float('inf'),
        'success': False,
    }

    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"
    cmd = [
        # Use the interpreter running this script rather than whatever
        # 'python' resolves to on PATH (which may be missing or another env).
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number)
    ]

    try:
        # 30-minute timeout per configuration.
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
    except subprocess.TimeoutExpired:
        record['error'] = 'Timeout'
        return record
    except Exception as e:
        # Best-effort search: a failing configuration must not abort the sweep.
        record['error'] = str(e)
        return record

    record['success'] = completed.returncode == 0
    record['stdout'] = completed.stdout[:1000]  # keep only the first 1000 chars

    try:
        per = _parse_per(completed.stdout)
    except ValueError as e:
        record['error'] = str(e)
        per = None

    if per is None:
        print(f"⚠️ 无法解析PER结果: {config}")
    else:
        record['per'] = per

    if per is None or not record['success']:
        # Keep the tail of stderr so failed runs can be diagnosed afterwards.
        record['stderr'] = completed.stderr[-1000:]
    return record
def run_coarse_search(args):
    """Run stage one: evaluate every configuration of the coarse grid.

    Each configuration from ``generate_coarse_search_space`` is passed to
    ``run_evaluation``. When the sweep finishes, all results are written as
    JSON to ``args.coarse_results``.

    Args:
        args: parsed command-line namespace.

    Returns:
        list[dict]: up to ``args.top_k`` results with a finite PER, ordered
        from lowest to highest PER.
    """
    print("🔍 第一阶段:粗搜索")
    print("=" * 50)

    candidates = generate_coarse_search_space(args)
    n_total = len(candidates)
    print(f"粗搜索空间: {n_total} 个配置")
    print(f"GRU权重: {args.coarse_gru_weights}")
    print(f"TTA权重: {args.coarse_tta_weights}")
    print()

    evaluated = []
    lowest_per = float('inf')
    for done, candidate in enumerate(candidates, start=1):
        print(f"进度: {done}/{n_total} ({100*done/n_total:.1f}%)")
        print(f"配置: GRU={candidate[0]:.1f}, TTA=({candidate[2]},{candidate[3]},{candidate[4]},{candidate[5]})")

        outcome = run_evaluation(candidate, args)
        evaluated.append(outcome)

        if outcome['per'] < lowest_per:
            lowest_per = outcome['per']
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
        else:
            print(f" PER: {outcome['per']:.3f}%")

    # Persist the full coarse sweep so the fine stage can be run separately.
    payload = {
        'results': evaluated,
        'stage': 'coarse',
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }
    with open(args.coarse_results, 'w') as out_file:
        json.dump(payload, out_file, indent=2)

    # Keep only runs with a finite PER, then take the top-k lowest.
    ranked = [entry for entry in evaluated if entry['per'] != float('inf')]
    ranked.sort(key=lambda entry: entry['per'])
    best_configs = ranked[:args.top_k]

    print(f"\n粗搜索完成!选择前{args.top_k}个配置进行精细搜索:")
    for rank, entry in enumerate(best_configs, start=1):
        print(f"{rank}. PER={entry['per']:.3f}% | GRU={entry['gru_weight']:.1f} | {entry['tta_weights']}")
    return best_configs
def run_fine_search(best_configs, args):
    """Run stage two: evaluate the neighbourhood of the best coarse configs.

    The search space comes from ``generate_fine_search_space``; every entry
    is evaluated with ``run_evaluation``.

    Args:
        best_configs: coarse results to refine.
        args: parsed command-line namespace.

    Returns:
        list[dict]: one result per evaluated configuration, in search order.
    """
    print("\n🔬 第二阶段:精细搜索")
    print("=" * 50)

    candidates = generate_fine_search_space(best_configs, args)
    n_total = len(candidates)
    print(f"精细搜索空间: {n_total} 个配置")
    print(f"搜索范围: ±{args.fine_range}")
    print(f"搜索步长: {args.fine_step}")
    print()

    outcomes = []
    lowest_per = float('inf')
    for idx, candidate in enumerate(candidates):
        done = idx + 1
        print(f"进度: {done}/{n_total} ({100*done/n_total:.1f}%)")

        outcome = run_evaluation(candidate, args)
        outcomes.append(outcome)

        improved = outcome['per'] < lowest_per
        if improved:
            lowest_per = outcome['per']
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
            print(f" 配置: GRU={outcome['gru_weight']:.1f} | {outcome['tta_weights']}")

        # Report the current PER on every tenth configuration (0, 10, 20, ...).
        if not idx % 10:
            print(f" 当前PER: {outcome['per']:.3f}%")
    return outcomes
def _load_coarse_results(path):
    """Return the list stored under ``'results'`` in a coarse-results JSON file."""
    with open(path, 'r') as f:
        return json.load(f)['results']


def _select_top_configs(results, top_k):
    """Return the *top_k* results with the lowest finite PER, best first."""
    valid = [r for r in results if r['per'] != float('inf')]
    return sorted(valid, key=lambda r: r['per'])[:top_k]


def main():
    """Entry point: run the coarse stage, the fine stage, or both.

    * ``coarse``: run the coarse sweep and stop.
    * ``fine``: load coarse results from ``args.coarse_results`` and refine
      the top-k configurations.
    * ``both``: run the coarse sweep, refine its top-k configurations, and
      rank coarse and fine results together.

    The best configuration and all fine results are written as JSON to
    ``args.final_results``. Exits with status 1 when no configuration
    produced a usable PER.
    """
    args = parse_arguments()
    print("🚀 分阶段TTA-E参数搜索")
    print("=" * 60)

    if args.stage in ('coarse', 'both'):
        best_configs = run_coarse_search(args)
        if args.stage == 'coarse':
            print(f"\n✅ 粗搜索完成,结果保存到: {args.coarse_results}")
            return
        # run_coarse_search has just written the complete sweep to this file;
        # read it back so coarse results can be merged into the final ranking.
        coarse_results = _load_coarse_results(args.coarse_results)
    else:
        print(f"📁 加载粗搜索结果: {args.coarse_results}")
        coarse_results = _load_coarse_results(args.coarse_results)
        best_configs = _select_top_configs(coarse_results, args.top_k)

    if not best_configs:
        # Nothing to refine: every coarse evaluation failed.
        print("⚠️ 粗搜索没有产生有效配置,无法进行精细搜索")
        sys.exit(1)

    fine_results = run_fine_search(best_configs, args)

    # Copy before merging so 'all_fine_results' keeps fine-stage results only.
    all_results = list(fine_results)
    if args.stage == 'both':
        all_results.extend(coarse_results)

    valid_results = [r for r in all_results if r['per'] != float('inf')]
    final_best = min(valid_results, key=lambda x: x['per']) if valid_results else None

    final_results = {
        'best_config': final_best,
        'all_fine_results': fine_results,
        'stage': args.stage,
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }
    with open(args.final_results, 'w') as f:
        json.dump(final_results, f, indent=2)

    if final_best is None:
        # Results are still saved above so the failed runs can be inspected.
        print("⚠️ 所有配置评估均失败,未找到有效结果")
        print(f"结果保存到: {args.final_results}")
        sys.exit(1)

    print(f"\n🏆 最终最佳配置:")
    print(f"PER: {final_best['per']:.3f}%")
    print(f"GRU权重: {final_best['gru_weight']:.1f}")
    print(f"TTA权重: {final_best['tta_weights']}")
    print(f"结果保存到: {args.final_results}")

    # Show the ten best configurations.
    sorted_results = sorted(valid_results, key=lambda x: x['per'])[:10]
    print(f"\n📊 Top-10配置:")
    for i, result in enumerate(sorted_results):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")


if __name__ == "__main__":
    main()