Files
b2txt25/TTA-E/README_GA.md

246 lines
6.0 KiB
Markdown
Raw Normal View History

2025-10-06 15:17:44 +08:00
# TTA-E遗传算法参数优化系统
## 概述
这个系统使用遗传算法(Genetic Algorithm)来优化TTA-ETest Time Augmentation + Ensemble模型的集成参数目标是最小化音素错误率(PER)。
## 核心特性
### 🚀 高效缓存机制
- **预计算策略**: 先生成所有GRU和LSTM在5种增强方式下的预测结果
- **智能缓存**: 避免重复计算,大幅提升遗传算法搜索速度
- **磁盘持久化**: 缓存保存到磁盘,支持断点续传
### 🧬 遗传算法优化
- **PyGAD实现**: 使用成熟的遗传算法库
- **并行处理**: 利用25个CPU核心进行并行评估
- **自适应搜索**: 自动探索最优参数组合
### 🎯 优化参数
1. **GRU权重** (`gru_weight`): [0, 1] - GRU模型在集成中的权重
2. **TTA权重** (`tta_weights`): [0, 5] × 5 - 五种增强方式的权重
- `original`: 原始数据
- `noise`: 噪声增强
- `scale`: 幅度缩放
- `shift`: 时间位移
- `smooth`: 平滑变化
### ⚡ 性能优化
- **GPU加速**: 使用4090 GPU进行模型推理
- **混合精度**: 使用bfloat16减少显存占用
- **多线程**: 并行评估多个参数组合
- **内存管理**: 智能缓存管理,避免内存溢出
## 文件结构
```
TTA-E/
├── GA_optimize.py # 主要优化模块
├── test_ga.py # 快速测试脚本
├── tta_cache/ # 缓存目录
│ ├── gru_predictions.pkl
│ └── lstm_predictions.pkl
├── ga_optimization_result_*.pkl # 优化结果
└── README_GA.md # 本文档
```
## 使用方法
### 1. 快速测试
```bash
conda activate b2txt25
cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
python test_ga.py
```
### 2. 运行完整优化
```bash
conda activate b2txt25
cd /root/autodl-tmp/nejm-brain-to-text/TTA-E
python GA_optimize.py
```
### 3. 自定义参数优化
```python
import GA_optimize
# 创建优化器
optimizer = GA_optimize.TTAEGeneticOptimizer()
# 自定义遗传算法参数
optimizer.population_size = 50 # 种群大小
optimizer.num_generations = 100 # 迭代代数
optimizer.num_parents_mating = 20 # 父代数量
optimizer.mutation_percent_genes = 20 # 变异率
# 运行优化
result = optimizer.optimize()
```
## 核心算法
### 缓存策略
```
对于每个(session, trial, augmentation_type):
1. 应用数据增强
2. 通过GRU和LSTM模型获取预测
3. 缓存概率分布结果
4. 保存到磁盘
```
### 参数评估
```
对于给定的参数组合(gru_weight, tta_weights):
1. 从缓存加载预测结果
2. 按权重融合TTA样本
3. 使用几何平均集成GRU和LSTM
4. 计算音素错误率(PER)
5. 返回-PER作为适应度
```
### 遗传算法流程
```
1. 初始化随机种群
2. 并行评估所有个体的适应度
3. 选择优秀个体作为父代
4. 交叉产生子代
5. 变异引入多样性
6. 重复2-5直到收敛
```
## 优化结果
优化完成后,系统会输出:
```
Best GRU weight: 0.6234
Best LSTM weight: 0.3766
Best TTA weights:
- original: 2.1245
- noise: 1.3456
- scale: 0.7834
- shift: 0.0000
- smooth: 0.5621
Best PER: 18.45%
```
结果保存在 `ga_optimization_result_YYYYMMDD_HHMMSS.pkl` 文件中。
## 性能基准
### 测试环境
- GPU: NVIDIA RTX 4090
- CPU: 25核心
- 内存: 足够容纳所有缓存数据
### 预期性能
- **缓存生成**: ~20-30分钟1426个验证样本 × 5种增强 × 2个模型
- **单次评估**: ~0.1-0.5秒
- **完整优化**: ~2-4小时50个体 × 100代
### 内存使用
- **GPU显存**: ~8-12GB
- **系统内存**: ~16-32GB用于缓存
- **磁盘空间**: ~5-10GB持久化缓存
## 高级功能
### 1. 断点续传
如果优化过程中断,缓存文件会保留,重新启动时会自动加载:
```python
# 系统会自动检测并加载现有缓存
optimizer.generate_all_predictions() # 只生成缺失的预测
```
### 2. 自定义评估函数
可以修改评估函数来优化其他指标:
```python
def custom_fitness_function(self, ga_instance, solution, solution_idx):
# 自定义评估逻辑
# 例如同时考虑PER和推理速度
per = self.evaluate_parameters(solution[0], solution[1:6])
speed_penalty = calculate_speed_penalty(solution)
return -(per + speed_penalty)
```
### 3. 多目标优化
```python
# 可以扩展为多目标优化
# 例如最小化PER的同时最大化推理速度
def multi_objective_fitness(self, solution):
per = self.evaluate_parameters(solution[0], solution[1:6])
speed = self.evaluate_inference_speed(solution)
return [-per, speed] # 返回多个目标值
```
## 故障排除
### 常见问题
1. **显存不足**
```bash
# 解决方案:降低批处理大小或使用梯度累积
export CUDA_VISIBLE_DEVICES=0
```
2. **缓存损坏**
```bash
# 删除缓存重新生成
rm -rf tta_cache/
```
3. **进程被杀死**
```bash
# 检查系统资源
nvidia-smi
htop
```
### 性能调优
1. **减少评估时间**:使用更少的验证样本进行快速测试
2. **增加并行度**:调整 `parallel_processing` 参数
3. **优化内存使用**:使用更小的数据类型或分批处理
## 扩展指南
### 添加新的增强方式
```python
def new_augmentation(self, x):
# 实现新的数据增强
return augmented_x
# 在_apply_augmentation中添加新分支
elif aug_type == 'new_aug':
x_augmented = self.new_augmentation(x_augmented)
```
### 集成其他模型
```python
# 扩展为三模型集成
def load_third_model(self):
# 加载第三个模型
pass
def ensemble_three_models(self, gru_probs, lstm_probs, third_probs, weights):
# 三模型集成逻辑
pass
```
## 引用和致谢
本系统基于以下技术:
- PyGAD: 遗传算法库
- PyTorch: 深度学习框架
- CUDA: GPU加速计算
- NumPy: 数值计算
## 联系信息
如有问题或建议,请查看代码注释或联系开发团队。
---
**版本**: 1.0
**更新日期**: 2025年9月17日
**兼容性**: Python 3.10+, PyTorch 2.0+, CUDA 11.8+