294 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			294 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | |||
|  | """
 | |||
|  | 分阶段TTA-E参数搜索 | |||
|  | 先粗搜索找到有希望的区域,再精细搜索 | |||
|  | """
 | |||
|  | 
 | |||
|  | import os | |||
|  | import sys | |||
|  | import argparse | |||
|  | import json | |||
|  | import numpy as np | |||
|  | from itertools import product | |||
|  | import subprocess | |||
|  | import time | |||
|  | 
 | |||
def parse_arguments():
    """Define the CLI of the staged TTA-E parameter search and parse it.

    Options are declared in one table (flag, ``add_argument`` keyword
    arguments) and registered in order, so the ``--help`` listing follows
    the table from top to bottom.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    option_table = (
        # Basic settings forwarded to the evaluation script.
        ('--base_script', dict(type=str, default='evaluate_model.py')),
        ('--data_dir', dict(type=str, default='../data/hdf5_data_final')),
        ('--eval_type', dict(type=str, default='val')),
        ('--gpu_number', dict(type=int, default=0)),
        # Which search stage(s) to run and where results are stored.
        ('--stage', dict(type=str, default='coarse',
                         choices=['coarse', 'fine', 'both'],
                         help='搜索阶段:coarse=粗搜索,fine=精细搜索,both=两阶段')),
        ('--coarse_results', dict(type=str, default='coarse_results.json',
                                  help='粗搜索结果文件(用于精细搜索阶段)')),
        ('--final_results', dict(type=str, default='final_results.json',
                                 help='最终结果文件')),
        # Coarse grid: comma-separated candidate weights.
        ('--coarse_gru_weights', dict(type=str, default='0.2,0.4,0.6,0.8,1.0')),
        ('--coarse_tta_weights', dict(type=str, default='0.0,0.5,1.0')),
        # Fine grid: +/- range and step around each selected configuration.
        ('--fine_range', dict(type=float, default=0.3,
                              help='精细搜索范围(围绕最佳配置的±范围)')),
        ('--fine_step', dict(type=float, default=0.1,
                             help='精细搜索步长')),
        # How many coarse configurations are carried into the fine stage.
        ('--top_k', dict(type=int, default=5,
                         help='选择前K个最佳配置进行精细搜索')),
    )

    cli = argparse.ArgumentParser(description='分阶段TTA-E参数搜索')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|  | 
 | |||
def generate_coarse_search_space(args):
    """Build the coarse grid of candidate configurations.

    ``args.coarse_gru_weights`` and ``args.coarse_tta_weights`` are
    comma-separated lists of floats. Every GRU weight is combined with every
    4-way combination of TTA weights (noise, scale, shift, smooth); the
    weight of the original (un-augmented) input is fixed at 1.0.

    Returns:
        list[tuple]: ``(gru_w, 1.0, noise_w, scale_w, shift_w, smooth_w)``
        tuples, with the GRU weight varying slowest and smooth fastest.
    """
    def _as_floats(csv_text):
        return [float(token.strip()) for token in csv_text.split(',')]

    gru_candidates = _as_floats(args.coarse_gru_weights)
    tta_candidates = _as_floats(args.coarse_tta_weights)

    # product(..., repeat=4) iterates with the right-most element fastest,
    # i.e. in the order noise -> scale -> shift -> smooth.
    return [
        (gru_w, 1.0, *augmentation_weights)
        for gru_w in gru_candidates
        for augmentation_weights in product(tta_candidates, repeat=4)
    ]
|  | 
 | |||
def _fine_axis_values(center, lower, upper, radius, step, decimals):
    """Return grid points ``step`` apart covering ``center +/- radius``.

    The interval is clamped to ``[lower, upper]``. Points are produced from
    an integer step index rather than a float-stepped range, so accumulated
    floating-point error can never add an extra point past the upper bound.
    Values are native floats rounded to ``decimals`` places, de-duplicated
    in order.
    """
    start = max(lower, center - radius)
    stop = min(upper, center + radius)
    if stop < start:
        return []

    # Small tolerance so that e.g. (0.9 - 0.3) / 0.1 == 5.999... still counts
    # as six whole steps.
    step_count = int((stop - start) / step + 1e-9)

    values = []
    for index in range(step_count + 1):
        value = round(float(start + index * step), decimals)
        values.append(min(upper, max(lower, value)))
    # dict.fromkeys keeps first-seen order while dropping rounding collisions.
    return list(dict.fromkeys(values))


def generate_fine_search_space(best_configs, args):
    """Build the fine search space around the best coarse configurations.

    For every entry of ``best_configs`` the GRU weight is varied within
    ``+/- args.fine_range`` (clamped to [0.1, 1.0]) and, one TTA parameter at
    a time (noise, scale, shift, smooth), that parameter is varied within
    ``+/- args.fine_range`` (clamped to [0.0, 1.0]) while the other three
    keep the configuration's own value. The grid spacing is
    ``args.fine_step``.

    Args:
        best_configs: iterable of dicts with a ``'gru_weight'`` float and a
            ``'tta_weights'`` dict holding ``'noise'``, ``'scale'``,
            ``'shift'`` and ``'smooth'``.
        args: object exposing ``fine_range`` and ``fine_step``.

    Returns:
        list[tuple]: unique ``(gru_w, 1.0, noise_w, scale_w, shift_w,
        smooth_w)`` tuples of native floats, in generation order.

    Raises:
        ValueError: if ``args.fine_step`` is not positive.
    """
    step = args.fine_step
    if step <= 0:
        raise ValueError(f"fine_step must be positive, got {step!r}")
    radius = args.fine_range

    # Round to as many decimals as the step needs (at least one), so that a
    # step such as 0.05 does not collapse neighbouring grid points.
    decimals = 1
    while decimals < 10 and round(step, decimals) != step:
        decimals += 1

    tta_names = ('noise', 'scale', 'shift', 'smooth')
    fine_search_space = []
    seen = set()  # O(1) duplicate detection; the list keeps the order

    for config in best_configs:
        tta_w = config['tta_weights']
        gru_values = _fine_axis_values(
            config['gru_weight'], 0.1, 1.0, radius, step, decimals
        )
        baseline = {
            name: round(float(tta_w[name]), decimals) for name in tta_names
        }

        for varied_name in tta_names:
            varied_values = _fine_axis_values(
                tta_w[varied_name], 0.0, 1.0, radius, step, decimals
            )
            for gru_fine in gru_values:
                for varied_value in varied_values:
                    point = dict(baseline)
                    point[varied_name] = varied_value
                    candidate = (
                        gru_fine, 1.0,
                        point['noise'], point['scale'],
                        point['shift'], point['smooth'],
                    )
                    if candidate not in seen:
                        seen.add(candidate)
                        fine_search_space.append(candidate)

    return fine_search_space
|  | 
 | |||
def _parse_per(stdout_text):
    """Extract the PER value from the evaluation script's stdout.

    Looks for the first line containing
    ``'Aggregate Phoneme Error Rate (PER):'`` and parses the text after the
    last colon (with any ``%`` removed) as a float. Returns ``None`` when no
    such line exists or the value is not a number.
    """
    for line in stdout_text.split('\n'):
        if 'Aggregate Phoneme Error Rate (PER):' in line:
            per_text = line.split(':')[-1].strip().replace('%', '')
            try:
                return float(per_text)
            except ValueError:
                return None
    return None


def run_evaluation(config, args):
    """Evaluate one weight configuration by running the base script.

    The script ``args.base_script`` is launched as a subprocess of the
    current interpreter with ``--gru_weight``, ``--tta_weights``,
    ``--data_dir``, ``--eval_type`` and ``--gpu_number``; its stdout is
    scanned for the aggregate PER line.

    Args:
        config: 6-tuple ``(gru_w, orig_w, noise_w, scale_w, shift_w,
            smooth_w)``.
        args: object exposing ``base_script``, ``data_dir``, ``eval_type``
            and ``gpu_number``.

    Returns:
        dict: always contains ``'config'``, ``'gru_weight'``,
        ``'tta_weights'``, ``'per'`` (``float('inf')`` when the run failed or
        no PER could be parsed), ``'success'`` and ``'stdout'`` (first 1000
        characters). ``'error'`` is added on timeout, launch failure or a
        non-zero exit status.
    """
    gru_w, orig_w, noise_w, scale_w, shift_w, smooth_w = config

    tta_weights_str = f"{orig_w},{noise_w},{scale_w},{shift_w},{smooth_w}"

    # Same schema on every path, so callers can read any key regardless of
    # whether the run succeeded.
    outcome = {
        'config': config,
        'gru_weight': gru_w,
        'tta_weights': {
            'original': orig_w,
            'noise': noise_w,
            'scale': scale_w,
            'shift': shift_w,
            'smooth': smooth_w
        },
        'per': float('inf'),
        'success': False,
        'stdout': '',
    }

    cmd = [
        # Use the interpreter running this search so the child sees the same
        # environment; a bare 'python' may resolve to a different one.
        sys.executable or 'python', args.base_script,
        '--gru_weight', str(gru_w),
        '--tta_weights', tta_weights_str,
        '--data_dir', args.data_dir,
        '--eval_type', args.eval_type,
        '--gpu_number', str(args.gpu_number)
    ]

    try:
        # 30 minute timeout per configuration.
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=1800)
    except subprocess.TimeoutExpired:
        outcome['error'] = 'Timeout'
        return outcome
    except Exception as e:
        # Best effort: one configuration that cannot be launched must not
        # abort the whole search.
        outcome['error'] = str(e)
        return outcome

    per = _parse_per(completed.stdout)
    if per is None:
        print(f"⚠️  无法解析PER结果: {config}")
        per = float('inf')

    outcome['per'] = per
    outcome['success'] = completed.returncode == 0
    outcome['stdout'] = completed.stdout[:1000]  # keep only the first 1000 characters
    if completed.returncode != 0:
        # Keep the tail of stderr so a failed run can be diagnosed later.
        outcome['error'] = (completed.stderr or '')[-1000:]
    return outcome
|  | 
 | |||
def run_coarse_search(args):
    """Run the first (coarse) stage of the search.

    Evaluates every configuration of the coarse grid, writes all results to
    the JSON file ``args.coarse_results`` and returns the ``args.top_k``
    results with the lowest PER (results whose PER is infinite are excluded).

    Returns:
        list[dict]: the selected result dicts, best first.
    """
    print("🔍 第一阶段:粗搜索")
    print("=" * 50)

    candidates = generate_coarse_search_space(args)
    n_candidates = len(candidates)
    print(f"粗搜索空间: {n_candidates} 个配置")
    print(f"GRU权重: {args.coarse_gru_weights}")
    print(f"TTA权重: {args.coarse_tta_weights}")
    print()

    outcomes = []
    lowest_per = float('inf')

    for done, candidate in enumerate(candidates, start=1):
        gru_w, _orig_w, noise_w, scale_w, shift_w, smooth_w = candidate
        print(f"进度: {done}/{n_candidates} ({100*done/n_candidates:.1f}%)")
        print(f"配置: GRU={gru_w:.1f}, TTA=({noise_w},{scale_w},{shift_w},{smooth_w})")

        outcome = run_evaluation(candidate, args)
        outcomes.append(outcome)

        current_per = outcome['per']
        if current_per < lowest_per:
            lowest_per = current_per
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
        else:
            print(f"   PER: {current_per:.3f}%")

    # Persist every coarse result so the fine stage can be run separately.
    payload = {
        'results': outcomes,
        'stage': 'coarse',
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }
    with open(args.coarse_results, 'w') as handle:
        json.dump(payload, handle, indent=2)

    # Rank the usable results (finite PER) and keep the top-k.
    ranked = sorted(
        (entry for entry in outcomes if entry['per'] != float('inf')),
        key=lambda entry: entry['per'],
    )
    best_configs = ranked[:args.top_k]

    print(f"\n粗搜索完成!选择前{args.top_k}个配置进行精细搜索:")
    for rank, entry in enumerate(best_configs, start=1):
        print(f"{rank}. PER={entry['per']:.3f}% | GRU={entry['gru_weight']:.1f} | {entry['tta_weights']}")

    return best_configs
|  | 
 | |||
def run_fine_search(best_configs, args):
    """Run the second (fine) stage of the search.

    Builds the fine search space around ``best_configs``, evaluates every
    configuration in it and reports progress on stdout.

    Returns:
        list[dict]: one result dict per evaluated configuration, in
        evaluation order.
    """
    print("\n🔬 第二阶段:精细搜索")
    print("=" * 50)

    candidates = generate_fine_search_space(best_configs, args)
    n_candidates = len(candidates)
    print(f"精细搜索空间: {n_candidates} 个配置")
    print(f"搜索范围: ±{args.fine_range}")
    print(f"搜索步长: {args.fine_step}")
    print()

    outcomes = []
    lowest_per = float('inf')

    for index, candidate in enumerate(candidates):
        done = index + 1
        print(f"进度: {done}/{n_candidates} ({100*done/n_candidates:.1f}%)")

        outcome = run_evaluation(candidate, args)
        outcomes.append(outcome)
        current_per = outcome['per']

        if current_per < lowest_per:
            lowest_per = current_per
            print(f"🎯 新最佳PER: {lowest_per:.3f}%")
            print(f"   配置: GRU={outcome['gru_weight']:.1f} | {outcome['tta_weights']}")

        # Show the current PER on every 10th configuration (starting with
        # the first one).
        if index % 10 == 0:
            print(f"   当前PER: {current_per:.3f}%")

    return outcomes
|  | 
 | |||
def _load_coarse_results(path):
    """Return the ``'results'`` list stored in a coarse-search JSON file."""
    with open(path, 'r') as f:
        return json.load(f)['results']


def _rank_valid_results(results):
    """Return the results with a finite PER, sorted by ascending PER."""
    usable = [r for r in results if r['per'] != float('inf')]
    return sorted(usable, key=lambda r: r['per'])


def main():
    """Entry point: run the coarse stage, the fine stage, or both.

    * ``coarse``: run the coarse search and stop.
    * ``fine``: load earlier coarse results from ``args.coarse_results``,
      pick the ``args.top_k`` best and run the fine search around them.
    * ``both``: run the coarse search, then the fine search, and choose the
      final best configuration among the results of both stages.

    The final report is written to ``args.final_results``. The process exits
    with a non-zero status if no configuration could be evaluated
    successfully.
    """
    args = parse_arguments()

    print("🚀 分阶段TTA-E参数搜索")
    print("=" * 60)

    coarse_results = []
    if args.stage in ('coarse', 'both'):
        best_configs = run_coarse_search(args)

        if args.stage == 'coarse':
            print(f"\n✅ 粗搜索完成,结果保存到: {args.coarse_results}")
            return

        # run_coarse_search returns only the top-k entries; the complete
        # coarse result list is in the file it has just written.
        coarse_results = _load_coarse_results(args.coarse_results)
    else:
        # Fine stage only: reuse the coarse results from an earlier run.
        print(f"📁 加载粗搜索结果: {args.coarse_results}")
        best_configs = _rank_valid_results(
            _load_coarse_results(args.coarse_results)
        )[:args.top_k]

    if not best_configs:
        sys.exit("❌ 粗搜索没有有效结果,无法进行精细搜索")

    fine_results = run_fine_search(best_configs, args)

    # Merge into a new list so that 'all_fine_results' below stays limited
    # to the fine stage.
    all_results = list(fine_results)
    if args.stage == 'both':
        all_results.extend(coarse_results)

    valid_results = _rank_valid_results(all_results)
    if not valid_results:
        sys.exit("❌ 没有有效的评估结果,无法确定最佳配置")
    final_best = valid_results[0]

    final_results = {
        'best_config': final_best,
        'all_fine_results': fine_results,
        'stage': args.stage,
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'args': vars(args)
    }

    with open(args.final_results, 'w') as f:
        json.dump(final_results, f, indent=2)

    print(f"\n🏆 最终最佳配置:")
    print(f"PER: {final_best['per']:.3f}%")
    print(f"GRU权重: {final_best['gru_weight']:.1f}")
    print(f"TTA权重: {final_best['tta_weights']}")
    print(f"结果保存到: {args.final_results}")

    # Show the ten best configurations.
    print(f"\n📊 Top-10配置:")
    for i, result in enumerate(valid_results[:10]):
        tw = result['tta_weights']
        print(f"{i+1:2d}. PER={result['per']:6.3f}% | GRU={result['gru_weight']:.1f} | "
              f"TTA=({tw['noise']:.1f},{tw['scale']:.1f},{tw['shift']:.1f},{tw['smooth']:.1f})")


if __name__ == "__main__":
    main()