Files
b2txt25/TTA-E/simple_search.py

299 lines
11 KiB
Python
Raw Normal View History

2025-10-06 15:17:44 +08:00
#!/usr/bin/env python3
"""
简化版TTA-E参数搜索
专门搜索6个关键参数：gru_weight 以及 5 个TTA权重（original / noise / scale / shift / smooth）
"""
import os
import sys
import argparse
import json
import numpy as np
from itertools import product
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import time
from tqdm import tqdm
def parse_arguments():
    """Build the CLI parser for the simplified TTA-E search and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding:
          * the six search ranges as raw comma-separated strings
            (``gru_weights``, ``original_weights``, ``noise_weights``,
            ``scale_weights``, ``shift_weights``, ``smooth_weights``);
          * the settings forwarded to the evaluation script
            (``base_script``, ``data_dir``, ``eval_type``, ``gpu_number``);
          * output / execution controls (``output_file``, ``dry_run``,
            ``max_workers``, ``batch_size``).
    """
    parser = argparse.ArgumentParser(description='简化版TTA-E参数搜索')

    # Search-space definition: every option is a comma-separated list of weights (step 0.1).
    tenths_from_zero = '0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0'
    weight_options = (
        ('--gru_weights', '0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0', 'GRU权重搜索范围'),
        ('--original_weights', '1.0', '原始数据权重通常固定为1.0'),
        ('--noise_weights', tenths_from_zero, '噪声增强权重搜索范围'),
        ('--scale_weights', tenths_from_zero, '缩放增强权重搜索范围'),
        ('--shift_weights', tenths_from_zero, '偏移增强权重搜索范围'),
        ('--smooth_weights', tenths_from_zero, '平滑增强权重搜索范围'),
    )
    for flag, default_value, help_text in weight_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    # Settings passed through to the base evaluation script.
    parser.add_argument('--base_script', type=str, default='evaluate_model.py', help='基础评估脚本路径')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final', help='数据目录')
    parser.add_argument('--eval_type', type=str, default='val', help='评估类型')
    parser.add_argument('--gpu_number', type=int, default=0, help='GPU编号')

    # Output and execution control.
    parser.add_argument('--output_file', type=str, default='parameter_search_results.json',
                        help='搜索结果输出文件')
    parser.add_argument('--dry_run', action='store_true', help='只显示搜索空间,不实际运行')
    parser.add_argument('--max_workers', type=int, default=25, help='最大并行工作线程数')
    parser.add_argument('--batch_size', type=int, default=100, help='每批处理的配置数量')

    parsed = parser.parse_args()
    return parsed
def _parse_weight_list(raw, option_name):
    """Parse a comma-separated weight list (e.g. ``"0.0, 0.5,1.0"``) into floats.

    Surrounding whitespace is stripped and blank entries (such as the one
    produced by a trailing comma, ``"0.1,0.2,"``) are ignored.

    Raises:
        ValueError: if an entry is not a number, or the list contains no values.
    """
    values = []
    for token in raw.split(','):
        token = token.strip()
        if not token:
            # Tolerate stray/trailing commas instead of failing on float('').
            continue
        try:
            values.append(float(token))
        except ValueError as err:
            raise ValueError(f"--{option_name}: invalid weight {token!r} in {raw!r}") from err
    if not values:
        raise ValueError(f"--{option_name}: no weights given (got {raw!r})")
    return values


def generate_search_space(args):
    """Return the Cartesian product of all weight ranges given on the command line.

    Args:
        args: namespace with the comma-separated string attributes
            ``gru_weights``, ``original_weights``, ``noise_weights``,
            ``scale_weights``, ``shift_weights`` and ``smooth_weights``.

    Returns:
        list of 6-tuples ``(gru, original, noise, scale, shift, smooth)``,
        ordered as ``itertools.product`` yields them (last dimension varies fastest).

    Raises:
        ValueError: if any range contains a non-numeric entry or no values at all.
    """
    dimension_names = (
        'gru_weights', 'original_weights', 'noise_weights',
        'scale_weights', 'shift_weights', 'smooth_weights',
    )
    dimensions = [_parse_weight_list(getattr(args, name), name) for name in dimension_names]
    # Materialized as a list because callers need len() and slicing.
    return list(product(*dimensions))
# Marker line printed by the evaluation script in front of the aggregate PER value.
_PER_MARKER = 'Aggregate Phoneme Error Rate (PER):'


def _tta_weight_dict(orig_w, noise_w, scale_w, shift_w, smooth_w):
    """Map the five TTA weights to the named keys stored in every result record."""
    return {
        'original': orig_w,
        'noise': noise_w,
        'scale': scale_w,
        'shift': shift_w,
        'smooth': smooth_w,
    }


def _parse_per(stdout):
    """Return the PER (percent) from the first parseable marker line of *stdout*, or None."""
    for line in stdout.split('\n'):
        if _PER_MARKER not in line:
            continue
        try:
            # Take the text after the marker and drop the trailing '%'.
            per_str = line.split(_PER_MARKER)[-1].strip()
            return float(per_str.replace('%', '').strip())
        except (ValueError, IndexError) as e:
            print(f"Error parsing PER from line: {line}, error: {e}")
    return None


def run_single_evaluation(config, args, timeout=1800):
    """Evaluate one (GRU weight, 5 TTA weights) configuration by running the base script.

    Args:
        config: 6-tuple ``(gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w)``.
        args: parsed CLI namespace; reads ``base_script``, ``data_dir``,
            ``eval_type``, ``gpu_number`` and ``dry_run``.
        timeout: seconds to wait for the evaluation subprocess (default 30 minutes).

    Returns:
        dict with ``config``, ``gru_weight``, ``tta_weights``, ``per`` and ``command``.
        A completed run also carries ``success`` (exit status was 0). When the exit
        status is non-zero or no PER could be parsed, ``returncode`` and
        ``stderr_tail`` (last 2000 characters of stderr) are added for debugging.
        Timeouts and launch errors set ``per`` to ``inf`` and add an ``error`` message.
        In dry-run mode nothing is executed and ``per`` is a random placeholder
        drawn from [20, 40) -- it is not a real metric.
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config
    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    # Use the interpreter running this search rather than whatever `python` is on
    # PATH, so the evaluation runs in the same virtualenv/conda environment.
    cmd = [
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number),
    ]
    command_str = ' '.join(cmd)
    # Fields shared by every record, including timeout/error records, so that
    # downstream reporting can always read the weights.
    base_record = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': _tta_weight_dict(orig_w, noise_w, scale_w, shift_w, smooth_w),
    }

    if args.dry_run:
        print(f"Would run: {command_str}")
        return {**base_record, 'per': np.random.uniform(20, 40), 'command': command_str}  # simulated PER

    try:
        # errors='replace': a stray non-UTF-8 byte in the child's output must not
        # turn an otherwise successful evaluation into a failure.
        result = subprocess.run(cmd, capture_output=True, text=True, errors='replace',
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        return {**base_record, 'per': float('inf'), 'error': 'Timeout', 'command': command_str}
    except Exception as e:  # e.g. OSError when the interpreter/script cannot be launched
        return {**base_record, 'per': float('inf'), 'error': str(e), 'command': command_str}

    per = _parse_per(result.stdout)
    parse_failed = per is None
    if parse_failed:
        print(f"Warning: Could not parse PER from output for config {config}")
        per = float('inf')

    record = {
        **base_record,
        'per': per,
        'command': command_str,
        'success': result.returncode == 0,
    }
    if parse_failed or result.returncode != 0:
        # Keep some context for diagnosing failed runs in the saved JSON.
        record['returncode'] = result.returncode
        record['stderr_tail'] = (result.stderr or '')[-2000:]
    return record
def _print_search_header(args, total_configs):
    """Print the search-space size and the raw CLI range of every weight."""
    print(f"搜索空间大小: {total_configs} 个配置")
    print("参数范围:")
    print(f" GRU权重: {args.gru_weights}")
    print(f" 原始权重: {args.original_weights}")
    print(f" 噪声权重: {args.noise_weights}")
    print(f" 缩放权重: {args.scale_weights}")
    print(f" 偏移权重: {args.shift_weights}")
    print(f" 平滑权重: {args.smooth_weights}")
    print()


def _run_search(search_space, args):
    """Evaluate every configuration on a thread pool and return all result records.

    Prints a line whenever a new best PER appears, a progress summary every
    ``args.batch_size`` completions (skipped when batch_size is 0) and a short
    ticker every 50 completions otherwise.
    """
    total_configs = len(search_space)
    results = []
    best_per = float('inf')
    completed_count = 0

    # Each worker mostly waits on its evaluation subprocess, so threads suffice.
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit every configuration up front.
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }
        for future in as_completed(future_to_config):
            config = future_to_config[future]
            completed_count += 1
            # Only the future itself is guarded: an error in the bookkeeping below
            # must not record a successful configuration a second time as a failure.
            try:
                result = future.result()
            except Exception as e:
                print(f"\n❌ 配置失败: {config}, 错误: {e}")
                results.append({
                    'config': config,
                    'per': float('inf'),
                    'error': str(e)
                })
                continue
            results.append(result)

            # Track the best PER seen so far.
            if result['per'] < best_per:
                best_per = result['per']
                print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                print(f" GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

            # Periodic progress report; batch_size == 0 would otherwise divide by zero.
            if args.batch_size and completed_count % args.batch_size == 0:
                progress = 100 * completed_count / total_configs
                print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                print(f" 当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:  # lighter, more frequent progress ticker
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)
    return results


def _print_summary(valid_results, best_config, total_configs, output_file):
    """Print the best configuration, a top-10 table and success/failure statistics."""
    print("\n" + "=" * 50)
    print("🏆 搜索完成!")
    if best_config is not None:
        print("最佳配置:")
        print(f" PER: {best_config['per']:.3f}%")
        print(f" GRU权重: {best_config['gru_weight']:.1f}")
        print(f" TTA权重: {best_config['tta_weights']}")
        print(f" 命令: {best_config['command']}")

        # Ten lowest-PER configurations (sorted() is stable, so ties keep completion order).
        ranked = sorted(valid_results, key=lambda r: r['per'])
        print(f"\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        for rank, result in enumerate(ranked[:10], start=1):
            tw = result['tta_weights']
            print(f"{rank:3d} | {result['per']:6.3f} | {result['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print(f"\n📈 搜索统计:")
    print(f" 总配置数: {total_configs}")
    print(f" 成功配置数: {len(valid_results)}")
    print(f" 失败配置数: {total_configs - len(valid_results)}")
    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f" PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f" 平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")
    print(f"\n✅ 结果已保存到: {output_file}")


def main():
    """Entry point: parse CLI options, run the grid search, then save and report results.

    In ``--dry_run`` mode only the first five commands and the total count are
    printed. Otherwise every configuration is evaluated in parallel. The results
    are written once, to ``args.output_file``, after all evaluations finish.
    """
    args = parse_arguments()
    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    search_space = generate_search_space(args)
    total_configs = len(search_space)
    _print_search_header(args, total_configs)

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for i, config in enumerate(search_space[:5]):
            result = run_single_evaluation(config, args)
            print(f"{i+1}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")
    results = _run_search(search_space, args)
    print(f"\n✅ 所有任务完成!")

    # Pick the best result from the successful evaluations only.
    valid_results = [r for r in results if 'error' not in r and r['per'] != float('inf')]
    best_config = min(valid_results, key=lambda x: x['per']) if valid_results else None

    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }
    # NOTE: failed runs store per=inf, which json.dump writes as the non-standard
    # token `Infinity` (Python's json.load accepts it; strict JSON parsers may not).
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    _print_summary(valid_results, best_config, total_configs, args.output_file)
if __name__ == "__main__":
    # Start the search only when this file is executed directly, not when imported.
    main()