299 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			299 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | ||
| """
 | ||
| 简化版TTA-E参数搜索
 | ||
| 专门搜索6个关键参数:gru_weight 和 5个TTA权重
 | ||
| """
 | ||
| 
 | ||
| import os
 | ||
| import sys
 | ||
| import argparse
 | ||
| import json
 | ||
| import numpy as np
 | ||
| from itertools import product
 | ||
| from concurrent.futures import ThreadPoolExecutor, as_completed
 | ||
| import subprocess
 | ||
| import time
 | ||
| from tqdm import tqdm
 | ||
| 
 | ||
def parse_arguments():
    """Define the command-line interface of the search and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. The six ``*_weights`` options
        are kept as comma-separated strings of candidate values;
        ``max_workers`` and ``batch_size`` are guaranteed to be >= 1.

    Raises:
        SystemExit: Raised by argparse for invalid arguments, including a
        non-positive ``--max_workers`` or ``--batch_size``.
    """
    parser = argparse.ArgumentParser(description='简化版TTA-E参数搜索')

    # Search-space definition: candidate values in steps of 0.1.
    parser.add_argument('--gru_weights', type=str, default='0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='GRU权重搜索范围')
    parser.add_argument('--original_weights', type=str, default='1.0',
                        help='原始数据权重(通常固定为1.0)')
    parser.add_argument('--noise_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='噪声增强权重搜索范围')
    parser.add_argument('--scale_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='缩放增强权重搜索范围')
    parser.add_argument('--shift_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='偏移增强权重搜索范围')
    parser.add_argument('--smooth_weights', type=str, default='0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0',
                        help='平滑增强权重搜索范围')

    # Options forwarded to the base evaluation script.
    parser.add_argument('--base_script', type=str, default='evaluate_model.py',
                        help='基础评估脚本路径')
    parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                        help='数据目录')
    parser.add_argument('--eval_type', type=str, default='val',
                        help='评估类型')
    parser.add_argument('--gpu_number', type=int, default=0,
                        help='GPU编号')

    # Output and execution control.
    parser.add_argument('--output_file', type=str, default='parameter_search_results.json',
                        help='搜索结果输出文件')
    parser.add_argument('--dry_run', action='store_true',
                        help='只显示搜索空间,不实际运行')
    # Both values must be >= 1: max_workers sizes the thread pool and
    # batch_size is used as a modulus when reporting progress.
    parser.add_argument('--max_workers', type=_positive_int, default=25,
                        help='最大并行工作线程数')
    parser.add_argument('--batch_size', type=_positive_int, default=100,
                        help='每批处理的配置数量')

    return parser.parse_args()


def _positive_int(text):
    """Convert *text* to an int and reject values below 1 (argparse ``type=``)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value
 | ||
| 
 | ||
def generate_search_space(args):
    """Build the full grid of weight configurations to evaluate.

    Args:
        args: Namespace whose ``gru_weights``, ``original_weights``,
            ``noise_weights``, ``scale_weights``, ``shift_weights`` and
            ``smooth_weights`` attributes are comma-separated number strings.

    Returns:
        list[tuple[float, ...]]: The Cartesian product of the six candidate
        lists, each tuple ordered
        ``(gru, original, noise, scale, shift, smooth)``.

    Raises:
        ValueError: If an option contains a non-numeric token or no values.
    """
    option_names = (
        'gru_weights', 'original_weights', 'noise_weights',
        'scale_weights', 'shift_weights', 'smooth_weights',
    )
    candidate_lists = [
        _parse_weight_list(getattr(args, name), name) for name in option_names
    ]
    return list(product(*candidate_lists))


def _parse_weight_list(raw, option_name):
    """Parse one comma-separated weight string into a non-empty list of floats."""
    values = []
    for token in raw.split(','):
        token = token.strip()
        if not token:
            # Tolerate stray or trailing commas such as "0.1,0.2,".
            continue
        try:
            values.append(float(token))
        except ValueError as err:
            raise ValueError(f"--{option_name}: invalid weight value {token!r}") from err
    if not values:
        raise ValueError(f"--{option_name}: no weight values given")
    return values
 | ||
| 
 | ||
# Text that precedes the PER value in the evaluation script's stdout.
_PER_MARKER = 'Aggregate Phoneme Error Rate (PER):'


def run_single_evaluation(config, args):
    """Evaluate one weight configuration by running the base evaluation script.

    Args:
        config: 6-tuple ``(gru_weight, original, noise, scale, shift, smooth)``.
        args: Namespace providing ``base_script``, ``data_dir``, ``eval_type``,
            ``gpu_number`` and ``dry_run``.

    Returns:
        dict: Always contains ``'config'``, ``'gru_weight'``, ``'tta_weights'``,
        ``'per'`` and ``'command'``. Real (non-dry) runs that finished add
        ``'success'`` (child exit code was 0). ``'error'`` is present when the
        child timed out, could not be started, or printed no parsable PER
        line; in all of those cases ``'per'`` is ``float('inf')``.
        In dry-run mode nothing is executed and ``'per'`` is a random
        placeholder in [20, 40).
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    # The evaluation script receives the five TTA weights as one CSV argument.
    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    cmd = [
        # Use the interpreter running this script rather than whatever
        # "python" resolves to on PATH (which may be missing or a different
        # environment).
        sys.executable, args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number)
    ]
    command_str = ' '.join(cmd)

    if args.dry_run:
        print(f"Would run: {command_str}")
        # Simulated PER so the rest of the pipeline can be exercised.
        return _make_result(config, np.random.uniform(20, 40), command_str)

    try:
        # 30-minute ceiling per configuration.
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
    except subprocess.TimeoutExpired:
        return _make_result(config, float('inf'), command_str, error='Timeout')
    except Exception as e:
        # Worker boundary: one configuration that cannot be launched must not
        # abort the whole search, so report it as a failed result instead.
        return _make_result(config, float('inf'), command_str, error=str(e))

    extra = {'success': completed.returncode == 0}
    per = _parse_per(completed.stdout)
    if per is None:
        print(f"Warning: Could not parse PER from output for config {config}")
        per = float('inf')
        # Keep the end of stderr so a failing child can be diagnosed later.
        stderr_tail = (completed.stderr or '').strip()[-500:]
        extra['error'] = (
            f"PER not found in output (exit code {completed.returncode})"
            + (f": {stderr_tail}" if stderr_tail else "")
        )

    return _make_result(config, per, command_str, **extra)


def _parse_per(stdout):
    """Return the first parsable PER value found in *stdout*, or None."""
    for line in stdout.split('\n'):
        if _PER_MARKER not in line:
            continue
        # The value follows the marker and may carry a trailing percent sign.
        per_str = line.split(_PER_MARKER)[-1].replace('%', '').strip()
        try:
            return float(per_str)
        except ValueError as e:
            print(f"Error parsing PER from line: {line}, error: {e}")
    return None


def _make_result(config, per, command, **extra):
    """Assemble a result record with the same core keys for every outcome."""
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config
    result = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        },
        'per': per,
        'command': command,
    }
    result.update(extra)
    return result
 | ||
| 
 | ||
def main():
    """Run the TTA-E weight grid search and write all results to a JSON file.

    Parses the command line, builds the search space, and either previews the
    first five commands (``--dry_run``) or evaluates every configuration on a
    thread pool. While running it reports each new best PER and periodic
    progress; afterwards it writes ``args.output_file`` and prints a summary.
    """
    args = parse_arguments()

    print("🔍 TTA-E参数搜索")
    print("=" * 50)

    search_space = generate_search_space(args)
    total_configs = len(search_space)

    print(f"搜索空间大小: {total_configs} 个配置")
    print("参数范围:")
    print(f"  GRU权重: {args.gru_weights}")
    print(f"  原始权重: {args.original_weights}")
    print(f"  噪声权重: {args.noise_weights}")
    print(f"  缩放权重: {args.scale_weights}")
    print(f"  偏移权重: {args.shift_weights}")
    print(f"  平滑权重: {args.smooth_weights}")
    print()

    if args.dry_run:
        print("🧪 Dry run模式 - 显示前5个配置示例:")
        for i, config in enumerate(search_space[:5]):
            result = run_single_evaluation(config, args)
            print(f"{i+1}. {result['command']}")
        print(f"\n总共会运行 {total_configs} 个配置")
        return

    print("🚀 开始参数搜索...")
    print(f"使用 {args.max_workers} 个线程并行处理...")

    results = []
    best_per = float('inf')
    best_config = None
    completed_count = 0

    with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_config = {
            executor.submit(run_single_evaluation, config, args): config
            for config in search_space
        }

        for future in as_completed(future_to_config):
            config = future_to_config[future]
            completed_count += 1

            # Only the worker call is guarded. Wrapping the bookkeeping below
            # as well would let an unrelated error record the same
            # configuration a second time as a failure.
            try:
                result = future.result()
            except Exception as e:
                print(f"\n❌ 配置失败: {config}, 错误: {e}")
                results.append({
                    'config': config,
                    'per': float('inf'),
                    'error': str(e)
                })
                continue

            results.append(result)

            if result['per'] < best_per:
                best_per = result['per']
                best_config = result
                print(f"\n🎯 新最优配置[{completed_count}/{total_configs}]: PER={best_per:.3f}%")
                print(f"   GRU={config[0]:.1f}, TTA=({config[1]},{config[2]},{config[3]},{config[4]},{config[5]})")

            # Full report every batch_size completions (guarded so that a
            # non-positive value cannot divide by zero), short tick every 50.
            if args.batch_size > 0 and completed_count % args.batch_size == 0:
                progress = 100 * completed_count / total_configs
                print(f"\n📊 进度: {completed_count}/{total_configs} ({progress:.1f}%)")
                print(f"   当前最优PER: {best_per:.3f}%")
            elif completed_count % 50 == 0:
                print(f"... {completed_count}/{total_configs} ...", end='', flush=True)

    print("\n✅ 所有任务完成!")

    # Recompute the winner from error-free results with a finite PER.
    valid_results = [r for r in results if 'error' not in r and r['per'] != float('inf')]
    if valid_results:
        best_config = min(valid_results, key=lambda x: x['per'])

    search_results = {
        'best_config': best_config,
        'all_results': results,
        'search_space_size': total_configs,
        'args': vars(args),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
    }

    # NOTE: failed configurations carry per=float('inf'), which json.dump
    # writes as the non-standard token "Infinity"; kept as-is so existing
    # readers of this file see the same format.
    with open(args.output_file, 'w', encoding='utf-8') as f:
        json.dump(search_results, f, indent=2)

    _print_summary(results, valid_results, best_config, total_configs)

    print(f"\n✅ 结果已保存到: {args.output_file}")


def _print_summary(results, valid_results, best_config, total_configs):
    """Print the best configuration, the top-10 table and success statistics."""
    print("\n" + "=" * 50)
    print("🏆 搜索完成!")

    if best_config is not None:
        print("最佳配置:")
        print(f"  PER: {best_config['per']:.3f}%")
        print(f"  GRU权重: {best_config['gru_weight']:.1f}")
        print(f"  TTA权重: {best_config['tta_weights']}")
        print(f"  命令: {best_config['command']}")

        # Rank every result that produced a finite PER.
        sorted_results = sorted((r for r in results if r['per'] != float('inf')),
                                key=lambda x: x['per'])

        print("\n📊 前10个最佳配置:")
        print("排名 | PER(%) | GRU | Original | Noise | Scale | Shift | Smooth")
        print("-" * 70)
        for rank, result in enumerate(sorted_results[:10], start=1):
            tw = result['tta_weights']
            print(f"{rank:3d}  | {result['per']:6.3f} | {result['gru_weight']:3.1f} | "
                  f"{tw['original']:8.1f} | {tw['noise']:5.1f} | {tw['scale']:5.1f} | "
                  f"{tw['shift']:5.1f} | {tw['smooth']:6.1f}")
    else:
        print("❌ 未找到有效的配置结果!所有配置都失败了。")
        print("请检查评估脚本是否正常工作。")

    print("\n📈 搜索统计:")
    print(f"  总配置数: {total_configs}")
    print(f"  成功配置数: {len(valid_results)}")
    print(f"  失败配置数: {total_configs - len(valid_results)}")

    if valid_results:
        valid_pers = [r['per'] for r in valid_results]
        print(f"  PER范围: {min(valid_pers):.3f}% - {max(valid_pers):.3f}%")
        print(f"  平均PER: {sum(valid_pers)/len(valid_pers):.3f}%")
 | ||
| 
 | ||
# Run the search only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
