299 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			299 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
"""
Simplified TTA-E parameter search.

Searches six key parameters: gru_weight and the five TTA weights.
"""
|  | 
 | |||
|  | import os | |||
|  | import sys | |||
|  | import argparse | |||
|  | import json | |||
|  | import numpy as np | |||
|  | from itertools import product | |||
|  | from concurrent.futures import ThreadPoolExecutor, as_completed | |||
|  | import subprocess | |||
|  | import time | |||
|  | from tqdm import tqdm | |||
|  | 
 | |||
def parse_arguments():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace holding the comma-separated weight grids, the
        settings forwarded to the evaluation script, and the output /
        parallelism controls.
    """
    full_grid = '0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0'

    # One (flag, add_argument keyword arguments) entry per option, listed in
    # the order they should appear in --help.
    option_table = [
        # Search-space definition, 0.1 resolution.
        ('--gru_weights', dict(type=str,
                               default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                               help='GRU权重搜索范围')),
        ('--original_weights', dict(type=str, default='1.0',
                                    help='原始数据权重(通常固定为1.0)')),
        ('--noise_weights', dict(type=str, default=full_grid,
                                 help='噪声增强权重搜索范围')),
        ('--scale_weights', dict(type=str, default=full_grid,
                                 help='缩放增强权重搜索范围')),
        ('--shift_weights', dict(type=str, default=full_grid,
                                 help='偏移增强权重搜索范围')),
        ('--smooth_weights', dict(type=str, default=full_grid,
                                  help='平滑增强权重搜索范围')),
        # Settings passed through to the evaluation script.
        ('--base_script', dict(type=str, default='evaluate_model.py',
                               help='基础评估脚本路径')),
        ('--data_dir', dict(type=str, default='../data/hdf5_data_final',
                            help='数据目录')),
        ('--eval_type', dict(type=str, default='val',
                             help='评估类型')),
        ('--gpu_number', dict(type=int, default=0,
                              help='GPU编号')),
        # Output and execution control.
        ('--output_file', dict(type=str, default='parameter_search_results.json',
                               help='搜索结果输出文件')),
        ('--dry_run', dict(action='store_true',
                           help='只显示搜索空间,不实际运行')),
        ('--max_workers', dict(type=int, default=25,
                               help='最大并行工作线程数')),
        ('--batch_size', dict(type=int, default=100,
                              help='每批处理的配置数量')),
    ]

    cli = argparse.ArgumentParser(description='简化版TTA-E参数搜索')
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|  | 
 | |||
def generate_search_space(args):
    """Expand the comma-separated weight options into the full grid.

    Args:
        args: Namespace exposing ``gru_weights``, ``original_weights``,
            ``noise_weights``, ``scale_weights``, ``shift_weights`` and
            ``smooth_weights`` as comma-separated strings of numbers.

    Returns:
        List of 6-tuples of floats ``(gru, original, noise, scale, shift,
        smooth)`` covering the Cartesian product of all six value lists.
    """
    # Tuple positions follow this order; the last axis varies fastest.
    axis_names = (
        'gru_weights', 'original_weights', 'noise_weights',
        'scale_weights', 'shift_weights', 'smooth_weights',
    )

    axes = []
    for name in axis_names:
        raw_values = getattr(args, name)
        axes.append([float(token.strip()) for token in raw_values.split(',')])

    return list(product(*axes))
|  | 
 | |||
def _build_result(config, cmd, per, **extra):
    """Assemble the result record shared by every exit path of an evaluation."""
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config
    record = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        },
        'per': per,
        'command': ' '.join(cmd),
    }
    record.update(extra)
    return record


def _parse_per(stdout):
    """Return the first parseable PER value found in *stdout*, or None."""
    marker = 'Aggregate Phoneme Error Rate (PER):'
    for line in stdout.split('\n'):
        if marker not in line:
            continue
        # Keep what follows the marker and drop the percent sign.
        try:
            return float(line.split(marker)[-1].replace('%', '').strip())
        except (ValueError, IndexError) as e:
            print(f"Error parsing PER from line: {line}, error: {e}")
    return None


def run_single_evaluation(config, args):
    """Evaluate one weight configuration by running the evaluation script.

    Args:
        config: 6-tuple ``(gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w)``.
        args: Namespace providing ``base_script``, ``data_dir``, ``eval_type``,
            ``gpu_number`` and ``dry_run``.

    Returns:
        dict that always contains ``config``, ``gru_weight``, ``tta_weights``,
        ``per`` and ``command``. Real runs add ``success`` (True when the
        subprocess exit code was 0). Runs that produced no usable PER carry
        ``per == float('inf')`` and an ``error`` message. In dry-run mode
        nothing is executed and ``per`` is a random placeholder in [20, 40).

    The function does not raise for subprocess failures; they are reported
    through the returned record.
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    # The five TTA weights are passed to the script as one comma-separated value.
    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    cmd = [
        'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number)
    ]

    if args.dry_run:
        print(f"Would run: {' '.join(cmd)}")
        # Simulated PER so callers can exercise their reporting code.
        return _build_result(config, cmd, np.random.uniform(20, 40))

    # Only the subprocess call is guarded, so a failure here is always a
    # launch/timeout problem and never a bug in the parsing below.
    try:
        completed = subprocess.run(cmd, capture_output=True, text=True,
                                   timeout=1800)  # 30-minute limit per run
    except subprocess.TimeoutExpired:
        return _build_result(config, cmd, float('inf'),
                             error='Timeout', success=False)
    except Exception as e:
        return _build_result(config, cmd, float('inf'),
                             error=str(e), success=False)

    succeeded = completed.returncode == 0
    per = _parse_per(completed.stdout)
    if per is None:
        print(f"Warning: Could not parse PER from output for config {config}")
        # Keep the tail of stderr so the failure can be diagnosed from the
        # saved results instead of being silently reduced to "inf".
        stderr_tail = (completed.stderr or '').strip()[-500:]
        return _build_result(
            config, cmd, float('inf'),
            error=f"Could not parse PER (exit code {completed.returncode}): {stderr_tail}",
            success=succeeded,
        )

    return _build_result(config, cmd, per, success=succeeded)
|  | 
 | |||
def main():
    """Run the grid search end to end and write the results to disk.

    Steps: parse the command line, build the search space, then either show
    the first five commands (dry-run mode) or evaluate every configuration
    on a thread pool. The best configuration and all individual results are
    dumped as JSON to ``args.output_file``, followed by a printed summary.
    """
    args = parse_arguments()

    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    # Build the search space.
    search_space = generate_search_space(args)
    total_configs = len(search_space)

    print(f"搜索空间大小: {total_configs} 个配置")
    print("参数范围:")
    print(f"  GRU权重: {args.gru_weights}")
    print(f"  原始权重: {args.original_weights}")
    print(f"  噪声权重: {args.noise_weights}")
    print(f"  缩放权重: {args.scale_weights}")
    print(f"  偏移权重: {args.shift_weights}")
    print(f"  平滑权重: {args.smooth_weights}")
    print()

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for i, config in enumerate(search_space[:5]):
            result = run_single_evaluation(config, args)
            print(f"{i+1}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    # Run the search.
    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")

    results = []
    best_per = float('inf')
    best_config = None
    completed_count = 0

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Submit everything up front; results are consumed as they finish.
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }

        for future in as_completed(future_to_config):
            config = future_to_config[future]
            completed_count += 1

            # Guard only the retrieval of the result. Errors in the
            # bookkeeping below must not be reported as a failed
            # configuration, nor record the same configuration twice.
            try:
                result = future.result()
            except Exception as e:
                print(f"\n❌ 配置失败: {config}, 错误: {e}")
                results.append({
                    'config': config,
                    'per': float('inf'),
                    'error': str(e)
                })
                continue

            results.append(result)

            # Track the running best.
            if result['per'] < best_per:
                best_per = result['per']
                best_config = result
                print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                print(f"   GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

            # Periodic progress report. A batch size of 0 would make the
            # modulo raise, so it simply disables the detailed report.
            if args.batch_size and completed_count % args.batch_size == 0:
                progress = 100 * completed_count / total_configs
                print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                print(f"   当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:  # lighter, more frequent tick
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

    print("\n✅ 所有任务完成!")

    # Recompute the best result from the records without an error and with
    # a finite PER, rather than trusting the running value alone.
    valid_results = [r for r in results if 'error' not in r and r['per'] != float('inf')]
    if valid_results:
        best_config = min(valid_results, key=lambda x: x['per'])

    # Save the results.
    # NOTE: failed configurations carry per=float('inf'), which json.dump
    # writes as the non-standard token `Infinity`; strict JSON parsers
    # outside Python may reject the file.
    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    # Summary.
    print("\n" + "=" * 50)
    print("🏆 搜索完成!")

    if best_config is not None:
        print("最佳配置:")
        print(f"  PER: {best_config['per']:.3f}%")
        print(f"  GRU权重: {best_config['gru_weight']:.1f}")
        print(f"  TTA权重: {best_config['tta_weights']}")
        print(f"  命令: {best_config['command']}")

        # Show the ten best results.
        sorted_results = sorted([r for r in results if r['per'] != float('inf')],
                                key=lambda x: x['per'])

        print("\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        for i, result in enumerate(sorted_results[:10]):
            tw = result['tta_weights']
            print(f"{i+1:3d}  | {result['per']:6.3f} | {result['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print("\n📈 搜索统计:")
    print(f"  总配置数: {total_configs}")
    print(f"  成功配置数: {len(valid_results)}")
    print(f"  失败配置数: {total_configs - len(valid_results)}")

    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f"  PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f"  平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")

    print(f"\n✅ 结果已保存到: {args.output_file}")
|  | 
 | |||
# Run the search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()