Files
b2txt25/TTA-E/README_GA.md
2025-10-06 15:17:44 +08:00

246 lines
6.0 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# TTA-E遗传算法参数优化系统
## 概述
这个系统使用遗传算法(Genetic Algorithm)来优化TTA-ETest Time Augmentation + Ensemble模型的集成参数目标是最小化音素错误率(PER)。
## 核心特性
### 🚀 高效缓存机制
- **预计算策略**: 先生成所有GRU和LSTM在5种增强方式下的预测结果
- **智能缓存**: 避免重复计算,大幅提升遗传算法搜索速度
- **磁盘持久化**: 缓存保存到磁盘,支持断点续传
### 🧬 遗传算法优化
- **PyGAD实现**: 使用成熟的遗传算法库
- **并行处理**: 利用25个CPU核心进行并行评估
- **自适应搜索**: 自动探索最优参数组合
### 🎯 优化参数
1. **GRU权重** (`gru_weight`): [0, 1] - GRU模型在集成中的权重
2. **TTA权重** (`tta_weights`): [0, 5] × 5 - 五种增强方式的权重
- `original`: 原始数据
- `noise`: 噪声增强
- `scale`: 幅度缩放
- `shift`: 时间位移
- `smooth`: 平滑变化
### ⚡ 性能优化
- **GPU加速**: 使用4090 GPU进行模型推理
- **混合精度**: 使用bfloat16减少显存占用
- **多线程**: 并行评估多个参数组合
- **内存管理**: 智能缓存管理,避免内存溢出
## 文件结构
```
TTA-E/
├── GA_optimize.py # 主要优化模块
├── test_ga.py # 快速测试脚本
├── tta_cache/ # 缓存目录
│ ├── gru_predictions.pkl
│ └── lstm_predictions.pkl
├── ga_optimization_result_*.pkl # 优化结果
└── README_GA.md # 本文档
```
## 使用方法
### 1. 快速测试
```bash
conda activate b2txt25
cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
python test_ga.py
```
### 2. 运行完整优化
```bash
conda activate b2txt25
cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
python GA_optimize.py
```
### 3. 自定义参数优化
```python
import GA_optimize
# 创建优化器
optimizer = GA_optimize.TTAEGeneticOptimizer()
# 自定义遗传算法参数
optimizer.population_size = 50 # 种群大小
optimizer.num_generations = 100 # 迭代代数
optimizer.num_parents_mating = 20 # 父代数量
optimizer.mutation_percent_genes = 20 # 变异率
# 运行优化
result = optimizer.optimize()
```
## 核心算法
### 缓存策略
```
对于每个(session, trial, augmentation_type):
1. 应用数据增强
2. 通过GRU和LSTM模型获取预测
3. 缓存概率分布结果
4. 保存到磁盘
```
### 参数评估
```
对于给定的参数组合(gru_weight, tta_weights):
1. 从缓存加载预测结果
2. 按权重融合TTA样本
3. 使用几何平均集成GRU和LSTM
4. 计算音素错误率(PER)
5. 返回-PER作为适应度
```
### 遗传算法流程
```
1. 初始化随机种群
2. 并行评估所有个体的适应度
3. 选择优秀个体作为父代
4. 交叉产生子代
5. 变异引入多样性
6. 重复2-5直到收敛
```
## 优化结果
优化完成后,系统会输出:
```
Best GRU weight: 0.6234
Best LSTM weight: 0.3766
Best TTA weights:
- original: 2.1245
- noise: 1.3456
- scale: 0.7834
- shift: 0.0000
- smooth: 0.5621
Best PER: 18.45%
```
结果保存在 `ga_optimization_result_YYYYMMDD_HHMMSS.pkl` 文件中。
## 性能基准
### 测试环境
- GPU: NVIDIA RTX 4090
- CPU: 25核心
- 内存: 足够容纳所有缓存数据
### 预期性能
- **缓存生成**: ~20-30分钟1426个验证样本 × 5种增强 × 2个模型
- **单次评估**: ~0.1-0.5秒
- **完整优化**: ~2-4小时50个体 × 100代
### 内存使用
- **GPU显存**: ~8-12GB
- **系统内存**: ~16-32GB用于缓存
- **磁盘空间**: ~5-10GB持久化缓存
## 高级功能
### 1. 断点续传
如果优化过程中断,缓存文件会保留,重新启动时会自动加载:
```python
# 系统会自动检测并加载现有缓存
optimizer.generate_all_predictions() # 只生成缺失的预测
```
### 2. 自定义评估函数
可以修改评估函数来优化其他指标:
```python
def custom_fitness_function(self, ga_instance, solution, solution_idx):
# 自定义评估逻辑
# 例如同时考虑PER和推理速度
per = self.evaluate_parameters(solution[0], solution[1:6])
speed_penalty = calculate_speed_penalty(solution)
return -(per + speed_penalty)
```
### 3. 多目标优化
```python
# 可以扩展为多目标优化
# 例如最小化PER的同时最大化推理速度
def multi_objective_fitness(self, solution):
per = self.evaluate_parameters(solution[0], solution[1:6])
speed = self.evaluate_inference_speed(solution)
return [-per, speed] # 返回多个目标值
```
## 故障排除
### 常见问题
1. **显存不足**
```bash
# 解决方案:降低批处理大小或使用梯度累积
export CUDA_VISIBLE_DEVICES=0
```
2. **缓存损坏**
```bash
# 删除缓存重新生成
rm -rf tta_cache/
```
3. **进程被杀死**
```bash
# 检查系统资源
nvidia-smi
htop
```
### 性能调优
1. **减少评估时间**:使用更少的验证样本进行快速测试
2. **增加并行度**:调整 `parallel_processing` 参数
3. **优化内存使用**:使用更小的数据类型或分批处理
## 扩展指南
### 添加新的增强方式
```python
def new_augmentation(self, x):
# 实现新的数据增强
return augmented_x
# 在_apply_augmentation中添加新分支
elif aug_type == 'new_aug':
x_augmented = self.new_augmentation(x_augmented)
```
### 集成其他模型
```python
# 扩展为三模型集成
def load_third_model(self):
# 加载第三个模型
pass
def ensemble_three_models(self, gru_probs, lstm_probs, third_probs, weights):
# 三模型集成逻辑
pass
```
## 引用和致谢
本系统基于以下技术:
- PyGAD: 遗传算法库
- PyTorch: 深度学习框架
- CUDA: GPU加速计算
- NumPy: 数值计算
## 联系信息
如有问题或建议,请查看代码注释或联系开发团队。
---
**版本**: 1.0
**更新日期**: 2025年9月17日
**兼容性**: Python 3.10+, PyTorch 2.0+, CUDA 11.8+