#!/usr/bin/env python3
"""
Staged TTA-E parameter search.

Run a coarse search first to locate promising regions, then run a fine
search around the best coarse configurations.
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import argparse
|
|||
|
import json
|
|||
|
import numpy as np
|
|||
|
from itertools import product
|
|||
|
import subprocess
|
|||
|
import time
|
|||
|
|
|||
|
def parse_arguments():
    """Parse the command-line options of the staged TTA-E search.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); the
    # order below is the order in which the options appear in ``--help``.
    option_table = (
        # Basic evaluation settings
        ('--base_script', dict(type=str, default='evaluate_model.py')),
        ('--data_dir', dict(type=str, default='../data/hdf5_data_final')),
        ('--eval_type', dict(type=str, default='val')),
        ('--gpu_number', dict(type=int, default=0)),
        # Search stage control
        ('--stage', dict(type=str, default='coarse',
                         choices=['coarse', 'fine', 'both'],
                         help='搜索阶段:coarse=粗搜索,fine=精细搜索,both=两阶段')),
        ('--coarse_results', dict(type=str, default='coarse_results.json',
                                  help='粗搜索结果文件(用于精细搜索阶段)')),
        ('--final_results', dict(type=str, default='final_results.json',
                                 help='最终结果文件')),
        # Coarse search grid (step 0.2 for the GRU weight by default)
        ('--coarse_gru_weights', dict(type=str, default='0.2,0.4,0.6,0.8,1.0')),
        ('--coarse_tta_weights', dict(type=str, default='0.0,0.5,1.0')),
        # Fine search grid (step 0.1 around the best configurations)
        ('--fine_range', dict(type=float, default=0.3,
                              help='精细搜索范围(围绕最佳配置的±范围)')),
        ('--fine_step', dict(type=float, default=0.1,
                             help='精细搜索步长')),
        # Selection control
        ('--top_k', dict(type=int, default=5,
                         help='选择前K个最佳配置进行精细搜索')),
    )

    cli = argparse.ArgumentParser(description='分阶段TTA-E参数搜索')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
|||
|
|
|||
|
def generate_coarse_search_space(args):
    """Build the coarse search grid.

    Args:
        args: Namespace providing ``coarse_gru_weights`` and
            ``coarse_tta_weights`` as comma-separated number strings.

    Returns:
        list[tuple]: One ``(gru, 1.0, noise, scale, shift, smooth)`` tuple
        per combination. The second element (the weight of the original,
        un-augmented input) is fixed at 1.0; the last element varies fastest.
    """
    def _to_floats(raw):
        return [float(token.strip()) for token in raw.split(',')]

    gru_grid = _to_floats(args.coarse_gru_weights)
    tta_grid = _to_floats(args.coarse_tta_weights)

    # product(..., repeat=4) enumerates (noise, scale, shift, smooth) with
    # the rightmost position changing fastest, like four nested loops.
    return [
        (gru_w, 1.0, *tta_combo)
        for gru_w in gru_grid
        for tta_combo in product(tta_grid, repeat=4)
    ]
|
|||
|
|
|||
|
def _step_decimals(step, minimum=1, maximum=10):
    """Return how many decimals are needed to represent *step* exactly."""
    decimals = minimum
    while decimals < maximum and round(step, decimals) != step:
        decimals += 1
    return decimals


def _fine_axis_values(center, lower_bound, radius, step, decimals):
    """Return grid points within *radius* of *center*, clamped to [lower_bound, 1.0]."""
    start = max(lower_bound, center - radius)
    stop = min(1.0, center + radius)
    if stop < start:
        return []

    # Derive the number of points from an integer count instead of a float
    # ``arange``: a float stop value can overshoot the clamp (e.g. an extra
    # 1.1 after 1.0). The epsilon absorbs binary representation error.
    count = int((stop - start) / step + 1e-9) + 1

    values = []
    for index in range(count):
        point = round(min(stop, start + index * step), decimals)
        if point not in values:
            values.append(point)
    return values


def generate_fine_search_space(best_configs, args):
    """Build the fine search grid around the best coarse configurations.

    For every configuration the GRU weight is swept over its neighbourhood
    while one TTA weight at a time (noise, scale, shift, smooth) is swept
    over its own neighbourhood; the other TTA weights stay at their values
    from the configuration.

    Args:
        best_configs: Iterable of result dicts with a ``'gru_weight'`` number
            and a ``'tta_weights'`` dict holding ``'noise'``, ``'scale'``,
            ``'shift'`` and ``'smooth'``.
        args: Namespace providing ``fine_range`` (± radius) and ``fine_step``.

    Returns:
        list[tuple]: Unique ``(gru, 1.0, noise, scale, shift, smooth)`` tuples
        of plain floats in first-seen order. GRU weights lie in [0.1, 1.0]
        and TTA weights in [0.0, 1.0].

    Raises:
        ValueError: If ``fine_step`` is not positive or ``fine_range`` is
            negative.
    """
    step = float(args.fine_step)
    radius = float(args.fine_range)
    if step <= 0:
        raise ValueError(f"fine_step must be positive, got {args.fine_step}")
    if radius < 0:
        raise ValueError(f"fine_range must be non-negative, got {args.fine_range}")

    # Round to the precision of the step (never coarser than one decimal) so
    # that steps such as 0.05 are not collapsed onto a 0.1 grid.
    decimals = _step_decimals(step)
    tta_names = ('noise', 'scale', 'shift', 'smooth')

    fine_search_space = []
    seen = set()  # O(1) duplicate detection; the list keeps the order

    for config in best_configs:
        tta_w = config['tta_weights']
        gru_values = _fine_axis_values(config['gru_weight'], 0.1, radius, step, decimals)
        anchor = {name: round(tta_w[name], decimals) for name in tta_names}

        for param_name in tta_names:
            param_values = _fine_axis_values(tta_w[param_name], 0.0, radius, step, decimals)

            for gru_fine in gru_values:
                for value in param_values:
                    weights = dict(anchor)
                    weights[param_name] = value
                    config_tuple = (gru_fine, 1.0) + tuple(weights[name] for name in tta_names)
                    if config_tuple not in seen:
                        seen.add(config_tuple)
                        fine_search_space.append(config_tuple)

    return fine_search_space
|
|||
|
|
|||
|
def _extract_per(stdout):
    """Return the PER from the first marker line of *stdout*, or None if absent/unparseable."""
    marker = 'Aggregate Phoneme Error Rate (PER):'
    for line in stdout.split('\n'):
        if marker in line:
            try:
                return float(line.split(':')[-1].strip().replace('%', ''))
            except ValueError:
                return None
    return None


def run_evaluation(config, args):
    """Evaluate one configuration by running the evaluation script.

    The script named by ``args.base_script`` is started as a subprocess and
    its standard output is scanned for the
    ``Aggregate Phoneme Error Rate (PER):`` line.

    Args:
        config: ``(gru, original, noise, scale, shift, smooth)`` weights.
        args: Namespace providing ``base_script``, ``data_dir``,
            ``eval_type`` and ``gpu_number``.

    Returns:
        dict: Always contains ``'config'``, ``'gru_weight'``,
        ``'tta_weights'``, ``'per'`` and ``'success'``. ``'per'`` is
        ``float('inf')`` when the run timed out, could not be started, or
        printed no parseable PER. ``'stdout'`` (first 1000 characters) is
        present when the process ran; ``'stderr'`` (last 1000 characters) is
        added when it failed or printed no PER; ``'error'`` is present when
        the process timed out or could not be started.
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    # Every outcome shares this shape so callers can read the same keys for
    # failed runs as for successful ones.
    record = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w,
        },
        'per': float('inf'),
        'success': False,
    }

    cmd = [
        # Use the interpreter running this search so the evaluation sees the
        # same environment; a bare 'python' depends on PATH.
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number),
    ]

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            errors='replace',  # undecodable output must not discard a finished run
            timeout=1800,  # 30 minutes
        )
    except subprocess.TimeoutExpired:
        record['error'] = 'Timeout'
        return record
    except Exception as e:
        # Deliberately broad: one configuration that cannot be launched must
        # not abort a search that may already have run for hours.
        record['error'] = str(e)
        return record

    per = _extract_per(result.stdout)
    if per is None:
        print(f"⚠️  无法解析PER结果: {config}")
        per = float('inf')

    record['per'] = per
    record['success'] = result.returncode == 0
    record['stdout'] = result.stdout[:1000]  # keep only the first 1000 characters
    if result.returncode != 0 or per == float('inf'):
        # Keep the tail of stderr so failed runs can be diagnosed afterwards.
        record['stderr'] = result.stderr[-1000:]
    return record
|
|||
|
|
|||
|
def run_coarse_search(args):
    """Run stage one: evaluate every configuration of the coarse grid.

    All results are written as JSON to ``args.coarse_results`` once the grid
    has been evaluated.

    Args:
        args: Parsed command-line options.

    Returns:
        list[dict]: Up to ``args.top_k`` results with a finite PER, ordered
        by ascending PER.
    """
    print("🔍 第一阶段:粗搜索")
    print("=" * 50)

    candidates = generate_coarse_search_space(args)
    n_candidates = len(candidates)
    print(f"粗搜索空间: {n_candidates} 个配置")
    print(f"GRU权重: {args.coarse_gru_weights}")
    print(f"TTA权重: {args.coarse_tta_weights}")
    print()

    evaluated = []
    lowest_per = float('inf')

    for position, candidate in enumerate(candidates, start=1):
        print(f"进度: {position}/{n_candidates} ({100*position/n_candidates:.1f}%)")
        print(f"配置: GRU={candidate[0]:.1f}, TTA=({candidate[2]},{candidate[3]},{candidate[4]},{candidate[5]})")

        outcome = run_evaluation(candidate, args)
        evaluated.append(outcome)

        if outcome['per'] < lowest_per:
            lowest_per = outcome['per']
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
        else:
            print(f" PER: {outcome['per']:.3f}%")

    # Persist the complete coarse stage so a later fine-only run can load it.
    payload = {
        'results': evaluated,
        'stage': 'coarse',
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }
    with open(args.coarse_results, 'w') as f:
        json.dump(payload, f, indent=2)

    # Rank the runs that produced a PER and keep the top-k of them.
    ranked = sorted(
        (entry for entry in evaluated if entry['per'] != float('inf')),
        key=lambda entry: entry['per'],
    )
    best_configs = ranked[:args.top_k]

    print(f"\n粗搜索完成!选择前{args.top_k}个配置进行精细搜索:")
    for rank, entry in enumerate(best_configs, start=1):
        print(f"{rank}. PER={entry['per']:.3f}% | GRU={entry['gru_weight']:.1f} | {entry['tta_weights']}")

    return best_configs
|
|||
|
|
|||
|
def run_fine_search(best_configs, args):
    """Run stage two: evaluate the fine grid around the best coarse results.

    Args:
        best_configs: Result dicts selected from the coarse stage.
        args: Parsed command-line options.

    Returns:
        list[dict]: One result per fine-grid configuration, in evaluation
        order.
    """
    print("\n🔬 第二阶段:精细搜索")
    print("=" * 50)

    candidates = generate_fine_search_space(best_configs, args)
    n_candidates = len(candidates)
    print(f"精细搜索空间: {n_candidates} 个配置")
    print(f"搜索范围: ±{args.fine_range}")
    print(f"搜索步长: {args.fine_step}")
    print()

    evaluated = []
    lowest_per = float('inf')

    for index, candidate in enumerate(candidates):
        done = index + 1
        print(f"进度: {done}/{n_candidates} ({100*done/n_candidates:.1f}%)")

        outcome = run_evaluation(candidate, args)
        evaluated.append(outcome)

        if outcome['per'] < lowest_per:
            lowest_per = outcome['per']
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
            print(f" 配置: GRU={outcome['gru_weight']:.1f} | {outcome['tta_weights']}")

        # Report the current PER on the first and then every tenth configuration.
        if index % 10 == 0:
            print(f" 当前PER: {outcome['per']:.3f}%")

    return evaluated
|
|||
|
|
|||
|
def _load_coarse_results(path):
    """Return the list stored under ``'results'`` in the coarse-stage JSON file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)['results']


def _finite_per_results(results):
    """Return the results whose PER is finite, i.e. the runs that produced a score."""
    return [r for r in results if r['per'] != float('inf')]


def main():
    """Run the stage(s) selected by ``--stage`` and report the best configuration.

    * ``coarse``: run the coarse search and stop.
    * ``fine``: load the coarse results from ``args.coarse_results`` and run
      the fine search around the top-k of them.
    * ``both``: run the coarse search, then the fine search, and rank the
      results of both stages together.

    The best configuration and all fine-stage results are written as JSON to
    ``args.final_results``. The process exits with status 1 when no run
    produced a PER.
    """
    args = parse_arguments()

    print("🚀 分阶段TTA-E参数搜索")
    print("=" * 60)

    if args.stage in ('coarse', 'both'):
        best_configs = run_coarse_search(args)

        if args.stage == 'coarse':
            print(f"\n✅ 粗搜索完成,结果保存到: {args.coarse_results}")
            return

        # run_coarse_search returns only the top-k but writes every result to
        # args.coarse_results; read them back so both stages can be ranked
        # together further down.
        coarse_results = _load_coarse_results(args.coarse_results)
    else:
        # Fine-only run: take the coarse stage from a previous run's file.
        print(f"📁 加载粗搜索结果: {args.coarse_results}")
        coarse_results = _load_coarse_results(args.coarse_results)
        best_configs = sorted(
            _finite_per_results(coarse_results), key=lambda x: x['per']
        )[:args.top_k]

    if not best_configs:
        # Without a valid coarse result there is nothing to refine.
        sys.exit("❌ 粗搜索没有有效结果,无法进行精细搜索")

    fine_results = run_fine_search(best_configs, args)

    # Merge into a new list so that 'all_fine_results' below keeps holding
    # fine-stage results only.
    all_results = list(fine_results)
    if args.stage == 'both':
        all_results.extend(coarse_results)

    valid_results = _finite_per_results(all_results)
    if not valid_results:
        sys.exit("❌ 没有有效的评估结果,无法确定最佳配置")
    final_best = min(valid_results, key=lambda x: x['per'])

    final_results = {
        'best_config': final_best,
        'all_fine_results': fine_results,
        'stage': args.stage,
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }

    with open(args.final_results, 'w', encoding='utf-8') as f:
        json.dump(final_results, f, indent=2)

    print("\n🏆 最终最佳配置:")
    print(f"PER: {final_best['per']:.3f}%")
    print(f"GRU权重: {final_best['gru_weight']:.1f}")
    print(f"TTA权重: {final_best['tta_weights']}")
    print(f"结果保存到: {args.final_results}")

    # Show the ten best configurations
    sorted_results = sorted(valid_results, key=lambda x: x['per'])[:10]
    print("\n📊 Top-10配置:")
    for i, result in enumerate(sorted_results):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")
|
|||
|
|
|||
|
# Start the staged search only when this file is executed as a script;
# importing it as a module must not launch any evaluation.
if __name__ == "__main__":
    main()
|