246 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
			
		
		
	
	
			246 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
| # TTA-E遗传算法参数优化系统
 | ||
| 
 | ||
| ## 概述
 | ||
| 
 | ||
| 这个系统使用遗传算法(Genetic Algorithm)来优化TTA-E(Test Time Augmentation + Ensemble)模型的集成参数,目标是最小化音素错误率(PER)。
 | ||
| 
 | ||
| ## 核心特性
 | ||
| 
 | ||
| ### 🚀 高效缓存机制
 | ||
| - **预计算策略**: 先生成所有GRU和LSTM在5种增强方式下的预测结果
 | ||
| - **智能缓存**: 避免重复计算,大幅提升遗传算法搜索速度
 | ||
| - **磁盘持久化**: 缓存保存到磁盘,支持断点续传
 | ||
| 
 | ||
| ### 🧬 遗传算法优化
 | ||
| - **PyGAD实现**: 使用成熟的遗传算法库
 | ||
| - **并行处理**: 利用25个CPU核心进行并行评估
 | ||
| - **自适应搜索**: 自动探索最优参数组合
 | ||
| 
 | ||
| ### 🎯 优化参数
 | ||
| 1. **GRU权重** (`gru_weight`): [0, 1] - GRU模型在集成中的权重
 | ||
| 2. **TTA权重** (`tta_weights`): [0, 5] × 5 - 五种增强方式的权重
 | ||
|    - `original`: 原始数据
 | ||
|    - `noise`: 噪声增强  
 | ||
|    - `scale`: 幅度缩放
 | ||
|    - `shift`: 时间位移
 | ||
|    - `smooth`: 平滑变化
 | ||
| 
 | ||
| ### ⚡ 性能优化
 | ||
| - **GPU加速**: 使用4090 GPU进行模型推理
 | ||
| - **混合精度**: 使用bfloat16减少显存占用
 | ||
| - **多线程**: 并行评估多个参数组合
 | ||
| - **内存管理**: 智能缓存管理,避免内存溢出
 | ||
| 
 | ||
| ## 文件结构
 | ||
| 
 | ||
| ```
 | ||
| TTA-E/
 | ||
| ├── GA_optimize.py          # 主要优化模块
 | ||
| ├── test_ga.py             # 快速测试脚本
 | ||
| ├── tta_cache/             # 缓存目录
 | ||
| │   ├── gru_predictions.pkl
 | ||
| │   └── lstm_predictions.pkl
 | ||
| ├── ga_optimization_result_*.pkl  # 优化结果
 | ||
| └── README_GA.md           # 本文档
 | ||
| ```
 | ||
| 
 | ||
| ## 使用方法
 | ||
| 
 | ||
| ### 1. 快速测试
 | ||
| ```bash
 | ||
| conda activate b2txt25
 | ||
| cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
 | ||
| python test_ga.py
 | ||
| ```
 | ||
| 
 | ||
| ### 2. 运行完整优化
 | ||
| ```bash
 | ||
| conda activate b2txt25
 | ||
| cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
 | ||
| python GA_optimize.py
 | ||
| ```
 | ||
| 
 | ||
| ### 3. 自定义参数优化
 | ||
| ```python
 | ||
| import GA_optimize
 | ||
| 
 | ||
| # 创建优化器
 | ||
| optimizer = GA_optimize.TTAEGeneticOptimizer()
 | ||
| 
 | ||
| # 自定义遗传算法参数
 | ||
| optimizer.population_size = 50      # 种群大小
 | ||
| optimizer.num_generations = 100     # 迭代代数
 | ||
| optimizer.num_parents_mating = 20   # 父代数量
 | ||
| optimizer.mutation_percent_genes = 20  # 变异率
 | ||
| 
 | ||
| # 运行优化
 | ||
| result = optimizer.optimize()
 | ||
| ```
 | ||
| 
 | ||
| ## 核心算法
 | ||
| 
 | ||
| ### 缓存策略
 | ||
| ```
 | ||
| 对于每个(session, trial, augmentation_type):
 | ||
|     1. 应用数据增强
 | ||
|     2. 通过GRU和LSTM模型获取预测
 | ||
|     3. 缓存概率分布结果
 | ||
|     4. 保存到磁盘
 | ||
| ```
 | ||
| 
 | ||
| ### 参数评估
 | ||
| ```
 | ||
| 对于给定的参数组合(gru_weight, tta_weights):
 | ||
|     1. 从缓存加载预测结果
 | ||
|     2. 按权重融合TTA样本
 | ||
|     3. 使用几何平均集成GRU和LSTM
 | ||
|     4. 计算音素错误率(PER)
 | ||
|     5. 返回-PER作为适应度
 | ||
| ```
 | ||
| 
 | ||
| ### 遗传算法流程
 | ||
| ```
 | ||
| 1. 初始化随机种群
 | ||
| 2. 并行评估所有个体的适应度
 | ||
| 3. 选择优秀个体作为父代
 | ||
| 4. 交叉产生子代
 | ||
| 5. 变异引入多样性
 | ||
| 6. 重复2-5直到收敛
 | ||
| ```
 | ||
| 
 | ||
| ## 优化结果
 | ||
| 
 | ||
| 优化完成后,系统会输出:
 | ||
| 
 | ||
| ```
 | ||
| Best GRU weight: 0.6234
 | ||
| Best LSTM weight: 0.3766
 | ||
| Best TTA weights:
 | ||
|   - original: 2.1245
 | ||
|   - noise: 1.3456
 | ||
|   - scale: 0.7834
 | ||
|   - shift: 0.0000
 | ||
|   - smooth: 0.5621
 | ||
| Best PER: 18.45%
 | ||
| ```
 | ||
| 
 | ||
| 结果保存在 `ga_optimization_result_YYYYMMDD_HHMMSS.pkl` 文件中。
 | ||
| 
 | ||
| ## 性能基准
 | ||
| 
 | ||
| ### 测试环境
 | ||
| - GPU: NVIDIA RTX 4090
 | ||
| - CPU: 25核心
 | ||
| - 内存: 足够容纳所有缓存数据
 | ||
| 
 | ||
| ### 预期性能
 | ||
| - **缓存生成**: ~20-30分钟(1426个验证样本 × 5种增强 × 2个模型)
 | ||
| - **单次评估**: ~0.1-0.5秒
 | ||
| - **完整优化**: ~2-4小时(50个体 × 100代)
 | ||
| 
 | ||
| ### 内存使用
 | ||
| - **GPU显存**: ~8-12GB
 | ||
| - **系统内存**: ~16-32GB(用于缓存)
 | ||
| - **磁盘空间**: ~5-10GB(持久化缓存)
 | ||
| 
 | ||
| ## 高级功能
 | ||
| 
 | ||
| ### 1. 断点续传
 | ||
| 如果优化过程中断,缓存文件会保留,重新启动时会自动加载:
 | ||
| ```python
 | ||
| # 系统会自动检测并加载现有缓存
 | ||
| optimizer.generate_all_predictions()  # 只生成缺失的预测
 | ||
| ```
 | ||
| 
 | ||
| ### 2. 自定义评估函数
 | ||
| 可以修改评估函数来优化其他指标:
 | ||
| ```python
 | ||
| def custom_fitness_function(self, ga_instance, solution, solution_idx):
 | ||
|     # 自定义评估逻辑
 | ||
|     # 例如:同时考虑PER和推理速度
 | ||
|     per = self.evaluate_parameters(solution[0], solution[1:6])
 | ||
|     speed_penalty = calculate_speed_penalty(solution)
 | ||
|     return -(per + speed_penalty)
 | ||
| ```
 | ||
| 
 | ||
| ### 3. 多目标优化
 | ||
| ```python
 | ||
| # 可以扩展为多目标优化
 | ||
| # 例如:最小化PER的同时最大化推理速度
 | ||
| def multi_objective_fitness(self, solution):
 | ||
|     per = self.evaluate_parameters(solution[0], solution[1:6])
 | ||
|     speed = self.evaluate_inference_speed(solution)
 | ||
|     return [-per, speed]  # 返回多个目标值
 | ||
| ```
 | ||
| 
 | ||
| ## 故障排除
 | ||
| 
 | ||
| ### 常见问题
 | ||
| 
 | ||
| 1. **显存不足**
 | ||
|    ```bash
 | ||
|    # 解决方案:降低批处理大小或使用梯度累积
 | ||
|    export CUDA_VISIBLE_DEVICES=0
 | ||
|    ```
 | ||
| 
 | ||
| 2. **缓存损坏**
 | ||
|    ```bash
 | ||
|    # 删除缓存重新生成
 | ||
|    rm -rf tta_cache/
 | ||
|    ```
 | ||
| 
 | ||
| 3. **进程被杀死**
 | ||
|    ```bash
 | ||
|    # 检查系统资源
 | ||
|    nvidia-smi
 | ||
|    htop
 | ||
|    ```
 | ||
| 
 | ||
| ### 性能调优
 | ||
| 
 | ||
| 1. **减少评估时间**:使用更少的验证样本进行快速测试
 | ||
| 2. **增加并行度**:调整 `parallel_processing` 参数
 | ||
| 3. **优化内存使用**:使用更小的数据类型或分批处理
 | ||
| 
 | ||
| ## 扩展指南
 | ||
| 
 | ||
| ### 添加新的增强方式
 | ||
| ```python
 | ||
| def new_augmentation(self, x):
 | ||
|     # 实现新的数据增强
 | ||
|     return augmented_x
 | ||
| 
 | ||
| # 在_apply_augmentation中添加新分支
 | ||
| elif aug_type == 'new_aug':
 | ||
|     x_augmented = self.new_augmentation(x_augmented)
 | ||
| ```
 | ||
| 
 | ||
| ### 集成其他模型
 | ||
| ```python
 | ||
| # 扩展为三模型集成
 | ||
| def load_third_model(self):
 | ||
|     # 加载第三个模型
 | ||
|     pass
 | ||
| 
 | ||
| def ensemble_three_models(self, gru_probs, lstm_probs, third_probs, weights):
 | ||
|     # 三模型集成逻辑
 | ||
|     pass
 | ||
| ```
 | ||
| 
 | ||
| ## 引用和致谢
 | ||
| 
 | ||
| 本系统基于以下技术:
 | ||
| - PyGAD: 遗传算法库
 | ||
| - PyTorch: 深度学习框架  
 | ||
| - CUDA: GPU加速计算
 | ||
| - NumPy: 数值计算
 | ||
| 
 | ||
| ## 联系信息
 | ||
| 
 | ||
| 如有问题或建议,请查看代码注释或联系开发团队。
 | ||
| 
 | ||
| ---
 | ||
| 
 | ||
| **版本**: 1.0  
 | ||
| **更新日期**: 2025年9月17日  
 | ||
| **兼容性**: Python 3.10+, PyTorch 2.0+, CUDA 11.8+ | 
