#!/usr/bin/env python3
"""
Simplified TTA-E parameter search.

Grid-searches the six key parameters: ``gru_weight`` and the five TTA
weights (original, noise, scale, shift, smooth). Each configuration is
evaluated by running ``--base_script`` as a subprocess and parsing the
"Aggregate Phoneme Error Rate (PER):" line from its stdout.
"""
import os
import sys
import argparse
import json
import numpy as np
from itertools import product
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import time

try:
    # tqdm is not currently used by this script; importing it optionally
    # keeps the script runnable in environments where it is not installed.
    from tqdm import tqdm  # noqa: F401
except ImportError:
    tqdm = None

# Prefix of the stdout line (printed by the evaluation script) carrying the PER.
_PER_MARKER = 'Aggregate Phoneme Error Rate (PER):'
# How many trailing characters of a failing run's stderr to keep in the results.
_STDERR_TAIL_CHARS = 2000


def parse_arguments(argv=None):
    """Parse command-line options for the search.

    Args:
        argv: optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the search ranges (comma-separated strings),
        the evaluation settings forwarded to the base script, and output controls.
    """
    parser = argparse.ArgumentParser(description='简化版TTA-E参数搜索')

    # Search-space definition (step 0.1); each value is a comma-separated list.
    parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='GRU权重搜索范围')
    parser.add_argument('--original_weights', type=str, default='1.0',
                        help='原始数据权重(通常固定为1.0)')
    parser.add_argument('--noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='噪声增强权重搜索范围')
    parser.add_argument('--scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='缩放增强权重搜索范围')
    parser.add_argument('--shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='偏移增强权重搜索范围')
    parser.add_argument('--smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='平滑增强权重搜索范围')

    # Options forwarded to the evaluation script.
    parser.add_argument('--base_script', type=str, default='evaluate_model.py',
                        help='基础评估脚本路径')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                        help='数据目录')
    parser.add_argument('--eval_type', type=str, default='val',
                        help='评估类型')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU编号')

    # Output control.
    parser.add_argument('--output_file', type=str, default='parameter_search_results.json',
                        help='搜索结果输出文件')
    parser.add_argument('--dry_run', action='store_true',
                        help='只显示搜索空间,不实际运行')
    parser.add_argument('--max_workers', type=int, default=25,
                        help='最大并行工作线程数')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='每批处理的配置数量')

    return parser.parse_args(argv)


def _parse_float_list(text):
    """Parse a comma-separated string such as '0.0, 0.5' into a list of floats."""
    return [float(x.strip()) for x in text.split(',')]


def generate_search_space(args):
    """Build the full Cartesian grid of candidate configurations.

    Args:
        args: namespace holding the six comma-separated weight ranges.

    Returns:
        list of 6-tuples ``(gru, original, noise, scale, shift, smooth)``
        in ``itertools.product`` order.
    """
    return list(product(
        _parse_float_list(args.gru_weights),
        _parse_float_list(args.original_weights),
        _parse_float_list(args.noise_weights),
        _parse_float_list(args.scale_weights),
        _parse_float_list(args.shift_weights),
        _parse_float_list(args.smooth_weights),
    ))


def _parse_per(stdout):
    """Return the PER (in percent) from the first parseable marker line, or None."""
    for line in stdout.split('\n'):
        if _PER_MARKER not in line:
            continue
        # Take the text after the marker and drop the trailing percent sign.
        per_str = line.split(_PER_MARKER)[-1].strip().replace('%', '').strip()
        try:
            return float(per_str)
        except ValueError as e:
            print(f"Error parsing PER from line: {line}, error: {e}")
    return None


def run_single_evaluation(config, args):
    """Evaluate one configuration by running ``args.base_script``.

    Args:
        config: 6-tuple ``(gru, original, noise, scale, shift, smooth)``.
        args: namespace from :func:`parse_arguments`.

    Returns:
        dict with at least ``config``, ``per`` and ``command``.
        - ``per`` is ``float('inf')`` when no PER could be obtained.
        - Runs that raised (timeout, launch failure) carry an ``error`` key.
        - Completed runs carry ``success`` (exit code == 0). A non-zero exit
          also stores the tail of stderr under ``stderr``.
        - In dry-run mode nothing is executed, and ``per`` is a random
          placeholder drawn from [20, 40).
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    cmd = [
        # Use the interpreter running this script rather than whatever
        # 'python' resolves to on PATH, so the same environment is used.
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}",
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number),
    ]
    command = ' '.join(cmd)
    tta_weights = {
        'original': orig_w,
        'noise': noise_w,
        'scale': scale_w,
        'shift': shift_w,
        'smooth': smooth_w,
    }

    if args.dry_run:
        print(f"Would run: {command}")
        return {
            'config': config,
            'gru_weight': gru_w,
            'tta_weights': tta_weights,
            'per': np.random.uniform(20, 40),  # simulated PER placeholder
            'command': command,
        }

    try:
        # 30-minute timeout per evaluation; subprocess.run kills the child on expiry.
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
    except subprocess.TimeoutExpired:
        return {'config': config, 'per': float('inf'), 'error': 'Timeout', 'command': command}
    except Exception as e:
        return {'config': config, 'per': float('inf'), 'error': str(e), 'command': command}

    per = _parse_per(proc.stdout)
    if per is None:
        print(f"Warning: Could not parse PER from output for config {config}")
        per = float('inf')

    result = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': tta_weights,
        'per': per,
        'command': command,
        'success': proc.returncode == 0,
    }
    if proc.returncode != 0:
        # Keep a diagnostic trail; previously failures were indistinguishable.
        print(f"Warning: evaluation exited with code {proc.returncode} for config {config}")
        result['stderr'] = (proc.stderr or '')[-_STDERR_TAIL_CHARS:]
    return result


def _run_search(search_space, args):
    """Evaluate every configuration in a thread pool, printing progress.

    Returns:
        list of result dicts in completion order (one per configuration).
    """
    total_configs = len(search_space)
    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")

    results = []
    best_per = float('inf')
    completed_count = 0
    # Guard against --batch_size <= 0, which would otherwise divide by zero.
    report_every = max(1, args.batch_size)

    # NOTE(review): every worker launches an evaluation on the same
    # --gpu_number; with many workers this may exceed GPU memory.
    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }

        for future in as_completed(future_to_config):
            config = future_to_config[future]
            completed_count += 1
            # Only the worker call is guarded, so one config is never recorded twice.
            try:
                result = future.result()
            except Exception as e:
                print(f"\n❌ 配置失败: {config}, 错误: {e}")
                results.append({'config': config, 'per': float('inf'), 'error': str(e)})
                continue
            results.append(result)

            if result['per'] < best_per:
                best_per = result['per']
                print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                print(f" GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

            if completed_count % report_every == 0:
                progress = 100 * completed_count / total_configs
                print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                print(f" 当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:
                # Lighter-weight, more frequent progress marker.
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

    print(f"\n✅ 所有任务完成!")
    return results


def _print_summary(best_config, valid_results, total_configs):
    """Print the best configuration, a top-10 table, and success/failure statistics."""
    print("\n" + "=" * 50)
    print("🏆 搜索完成!")

    if best_config is not None:
        print(f"最佳配置:")
        print(f" PER: {best_config['per']:.3f}%")
        print(f" GRU权重: {best_config['gru_weight']:.1f}")
        print(f" TTA权重: {best_config['tta_weights']}")
        print(f" 命令: {best_config['command']}")

        ranked = sorted(valid_results, key=lambda r: r['per'])
        print(f"\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        for rank, r in enumerate(ranked[:10], start=1):
            tw = r['tta_weights']
            print(f"{rank:3d} | {r['per']:6.3f} | {r['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print(f"\n📈 搜索统计:")
    print(f" 总配置数: {total_configs}")
    print(f" 成功配置数: {len(valid_results)}")
    print(f" 失败配置数: {total_configs - len(valid_results)}")
    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f" PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f" 平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")


def main(argv=None):
    """Run the grid search, save all results to JSON, and print a summary.

    Args:
        argv: optional list of argument strings; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_arguments(argv)
    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    search_space = generate_search_space(args)
    total_configs = len(search_space)
    print(f"搜索空间大小: {total_configs} 个配置")
    print(f"参数范围:")
    print(f" GRU权重: {args.gru_weights}")
    print(f" 原始权重: {args.original_weights}")
    print(f" 噪声权重: {args.noise_weights}")
    print(f" 缩放权重: {args.scale_weights}")
    print(f" 偏移权重: {args.shift_weights}")
    print(f" 平滑权重: {args.smooth_weights}")
    print()

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for i, config in enumerate(search_space[:5]):
            result = run_single_evaluation(config, args)
            print(f"{i+1}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    results = _run_search(search_space, args)

    # A result is valid when it completed without an error and yielded a PER.
    valid_results = [r for r in results if 'error' not in r and r['per'] != float('inf')]
    best_config = min(valid_results, key=lambda r: r['per']) if valid_results else None

    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    # NOTE: failed runs store per=inf, which json.dump writes as the
    # non-standard token Infinity (Python's json.load accepts it; strict parsers do not).
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    _print_summary(best_config, valid_results, total_configs)
    print(f"\n✅ 结果已保存到: {args.output_file}")


if __name__ == "__main__":
    main()