# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This repository contains the code and data for "An Accurate and Rapidly Calibrating Speech Neuroprosthesis," published in the New England Journal of Medicine (2024). It implements a brain-to-text system that converts neural signals from speech motor cortex into text using RNN models and n-gram language models.

## Development Environment Setup

### Main Environment (b2txt25)
```bash
./setup.sh
conda activate b2txt25
```

### Language Model Environment (b2txt25_lm)
```bash
./setup_lm.sh
conda activate b2txt25_lm
```

**Important**: The project requires two separate conda environments due to conflicting PyTorch versions:
- `b2txt25`: PyTorch with CUDA 12.6 for model training/evaluation
- `b2txt25_lm`: PyTorch 1.13.1 for the Kaldi-based n-gram language models

### Redis Setup
Redis is required for inter-process communication. Install on Ubuntu:
```bash
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update && sudo apt-get install redis
sudo systemctl disable redis-server
```

## Architecture Overview

### High-Level System Flow
1. **Neural Data Input**: 512 features (2 per electrode × 256 electrodes) binned at 20 ms resolution
2. **RNN Model**: Converts neural features to phoneme logits (trained with CTC loss)
3. **Language Model**: Decodes phoneme logits to words using n-gram models plus OPT rescoring
4. **Redis Communication**: Coordinates the RNN inference and language model processes (see the sketch below)
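
The sketch below makes this hand-off concrete. It is a minimal illustration only: the queue names (`logits_queue`, `decoded_queue`), the 41-class logit width, and the byte-level payload layout are hypothetical placeholders, not the repository's actual keys; see `evaluate_model.py` and `language-model-standalone.py` for the real protocol.

```python
# Hypothetical sketch of the RNN <-> LM hand-off over Redis. Queue names,
# logit width, and payload layout are placeholders, not the actual keys.
import numpy as np
import redis

r = redis.Redis(host="localhost", port=6379)

# RNN process: push one trial's phoneme logits (time x classes) as bytes.
logits = np.random.randn(100, 41).astype(np.float32)  # dummy CTC logits
r.rpush("logits_queue", logits.tobytes())

# LM process: block until logits arrive, decode them, return the text.
_, payload = r.blpop("logits_queue")
received = np.frombuffer(payload, dtype=np.float32).reshape(-1, 41)
r.rpush("decoded_queue", "decoded sentence goes here")  # stand-in decode

# RNN process: collect the decoded sentence.
print(r.blpop("decoded_queue")[1].decode())
```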

### Key Components

#### Model Training (`model_training/`)
- **Core Script**: `train_model.py` (loads config from `rnn_args.yaml`)
- **Model Architecture**: `rnn_model.py` - 5-layer GRU with 768 hidden units (see the sketch after this list)
- **Trainer**: `rnn_trainer.py` - Custom PyTorch trainer with CTC loss
- **Evaluation**: `evaluate_model.py` - Inference pipeline with Redis communication
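
Below is a minimal sketch of the architecture these files implement as described above: 512 input features, a 5-layer GRU with 768 hidden units, and per-timestep phoneme logits trained with CTC. The class name, the 41-class output (assumed: phonemes plus silence and the CTC blank), and the training plumbing are illustrative assumptions; `rnn_model.py` and `rnn_trainer.py` contain the real implementation.

```python
# Minimal sketch of the decoder described above; names and the 41-class
# output are assumptions, not the actual rnn_model.py implementation.
import torch
import torch.nn as nn

class GRUDecoderSketch(nn.Module):
    def __init__(self, n_features=512, hidden=768, n_layers=5, n_phonemes=41):
        super().__init__()
        self.gru = nn.GRU(n_features, hidden, num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(hidden, n_phonemes)

    def forward(self, x):                  # x: (batch, time, 512)
        h, _ = self.gru(x)
        return self.out(h)                 # (batch, time, n_phonemes)

model = GRUDecoderSketch()
x = torch.randn(2, 100, 512)               # two trials of 100 x 20 ms bins
logits = model(x)

# CTC expects (time, batch, classes) log-probs; class 0 is the blank here.
log_probs = logits.permute(1, 0, 2).log_softmax(-1)
targets = torch.randint(1, 41, (2, 30))     # dummy phoneme label sequences
input_lengths = torch.full((2,), 100)
target_lengths = torch.full((2,), 30)
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
```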

#### Language Model (`language_model/`)
- **Standalone Server**: `language-model-standalone.py` - Redis-based LM server
- **Kaldi Integration**: Uses custom C++ bindings for efficient n-gram decoding
- **OPT Rescoring**: Facebook OPT 6.7B for language model rescoring
- **Build System**: Complex CMake-based build for Kaldi/SRILM integration

#### Utilities (`nejm_b2txt_utils/`)
- **General Utils**: `general_utils.py` - Shared utility functions
- **Package**: Installed via `setup.py` as `nejm_b2txt_utils`

#### Analysis (`analyses/`)
- **Jupyter Notebooks**: `figure_2.ipynb` and `figure_4.ipynb` reproduce the paper figures

## Common Development Tasks

### Training a Model
```bash
conda activate b2txt25
cd model_training
python train_model.py
```

### Running Evaluation Pipeline
1. Start Redis server:
```bash
redis-server
```

2. Start language model (separate terminal):
```bash
conda activate b2txt25_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```

3. Run evaluation (separate terminal):
```bash
conda activate b2txt25
cd model_training
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/hdf5_data_final --eval_type test --gpu_number 1
```

4. Shut down Redis:
```bash
redis-cli shutdown
```

### Building Language Model from Scratch
```bash
# Build SRILM (in language_model/srilm-1.7.3/)
export SRILM=$PWD
make MAKE_PIC=yes World

# Build Kaldi components (in language_model/runtime/server/x86/)
mkdir build && cd build
cmake .. && make -j8
```

## Data Structure

### Neural Data Format
- **File Type**: HDF5 files in `data/hdf5_data_final/`
- **Features**: 512 neural features per 20 ms bin (8 blocks of 64 electrodes: 2 feature types × 4 arrays):
  - 0-63: ventral 6v threshold crossings
  - 64-127: area 4 threshold crossings
  - 128-191: 55b threshold crossings
  - 192-255: dorsal 6v threshold crossings
  - 256-319: ventral 6v spike band power
  - 320-383: area 4 spike band power
  - 384-447: 55b spike band power
  - 448-511: dorsal 6v spike band power

### Data Loading
Use `load_h5py_file()` in `model_training/evaluate_model_helpers.py` as the reference for HDF5 data loading.
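
As a rough illustration, a loading loop might look like the sketch below, assuming one HDF5 group per trial with a 2-D feature dataset; the file, group, and dataset names here are hypothetical, and `load_h5py_file()` defines the actual layout.

```python
# Hypothetical HDF5 loading sketch; the real layout is defined by
# load_h5py_file() in model_training/evaluate_model_helpers.py.
import h5py
import numpy as np

with h5py.File("data/hdf5_data_final/example_session.hdf5", "r") as f:
    for trial_key in f.keys():                            # one group per trial
        feats = np.asarray(f[trial_key]["input_features"])  # (time, 512)
        print(trial_key, feats.shape)
```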

## Important Notes

- **GPU Requirements**: OPT 6.7B requires ~12.4 GB VRAM; RTX 4090s recommended
- **Memory Requirements**: the 3-gram LM needs ~60 GB RAM; the 5-gram needs ~300 GB RAM
- **Environment Isolation**: always use the correct conda environment for each component
- **Redis Dependency**: many scripts require a running Redis server
- **Build Dependencies**: CMake ≥3.14 and GCC ≥10.1 are required for language model builds

## XLA Optimizations (TPU-Friendly Model)

The RNN model has been optimized for XLA compilation and TPU training while preserving the original model architecture. These optimizations improve compilation speed and reduce memory usage on TPUs.

### Applied XLA Optimizations

#### 1. Dynamic Shape Operations → Static Operations
**Problem**: The XLA compiler struggles with dynamic tensor shapes and indexing
**Solution**: Replace dynamic operations with XLA-friendly alternatives

```python
# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)  # Static stack
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)  # Static gather
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
```
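
A self-contained toy check (not code from the repository) confirms that the static gather reproduces the dynamic indexing exactly:

```python
# Toy check: index_select over a pre-stacked tensor reproduces the
# original Python-level list indexing bit for bit.
import torch

day_weights = [torch.randn(512, 512) for _ in range(5)]  # e.g. 5 sessions
day_idx = torch.tensor([3, 0, 3])

ref = torch.stack([day_weights[i] for i in day_idx], dim=0)   # dynamic
opt = torch.index_select(torch.stack(day_weights, dim=0), 0, day_idx)
print(torch.equal(ref, opt))  # True
```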

#### 2. Matrix Operations → XLA Primitives
**Problem**: Complex einsum operations are less optimized than native XLA ops
**Solution**: Use batch matrix multiplication (bmm) for better XLA performance

```python
# Before:
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases

# After (XLA-optimized):
x = torch.bmm(x, day_weights) + day_biases  # bmm is highly optimized in XLA
```
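
A similar toy check for the `bmm` rewrite, with shapes following the einsum signature `btd,bdk->btk`:

```python
# Toy check: bmm matches the original einsum up to float tolerance.
import torch

x = torch.randn(2, 100, 512)   # (batch, time, features)  "btd"
w = torch.randn(2, 512, 512)   # per-trial day weights     "bdk"
b = torch.randn(2, 1, 512)     # day biases, broadcast over time

ref = torch.einsum("btd,bdk->btk", x, w) + b
opt = torch.bmm(x, w) + b
print(torch.allclose(ref, opt, rtol=1e-4, atol=1e-4))  # True
```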

#### 3. Hidden State Initialization
**Problem**: Dynamic batch size allocation causes XLA recompilation
**Solution**: Use static shapes and avoid `x.shape[0]` in tensor creation

```python
# Before:
if states is None:
    states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()

# After (XLA-friendly):
batch_size = x.size(0)  # Extract once
if states is None:
    states = self.h0.expand(2, batch_size, self.input_size).contiguous()
```

#### 4. Return Value Optimization
**Problem**: Complex dictionary returns cause XLA compilation issues
**Solution**: Use tuples instead of dictionaries for cleaner XLA graphs

```python
# Before (XLA-unfriendly):
return {
    'clean_logits': clean_logits,
    'noisy_logits': noisy_logits,
    'noise_output': noise_output
}

# After (XLA-friendly):
return clean_logits, noisy_logits, noise_output  # Simple tuple return
```

### Files Modified for XLA Optimization

- **`model_training_nnn/rnn_model.py`**: All three models optimized
  - `NoiseModel.forward()`: Dynamic indexing → static gather operations
  - `CleanSpeechModel.forward()`: Same optimizations + bmm for matrix ops
  - `NoisySpeechModel.forward()`: Hidden state optimization
  - `TripleGRUDecoder.forward()`: Complex return values → tuple returns
  - `TripleGRUDecoder._apply_preprocessing()`: Static preprocessing operations

### Benefits of XLA Optimizations

1. **Faster Compilation**: Static shapes allow XLA to pre-compile optimized kernels
2. **Better Memory Usage**: Reduced dynamic allocation during training
3. **Improved TPU Utilization**: XLA primitives map directly to TPU matrix units
4. **Consistent Performance**: Eliminates recompilation caused by dynamic shapes

### Testing and Validation

Test scripts were created to verify model consistency:
- **`test_xla_model.py`**: Comprehensive model validation testing
- **`quick_test_xla.py`**: Fast verification of basic functionality

**Important**: These optimizations preserve the exact model architecture and mathematical operations. Only the implementation has been made XLA-friendly.
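
The core of such a consistency check can be sketched as below; the `(x, day_idx)` forward signature and the tuple return are assumptions based on the sections above, and `test_xla_model.py` contains the actual tests.

```python
# Sketch of an output-consistency check between two model builds (e.g. the
# original and the XLA-friendly one). The (x, day_idx) forward signature
# and tuple return are assumptions; see test_xla_model.py for the real tests.
import torch

def outputs_match(model_a, model_b, atol=1e-5):
    model_a.eval()
    model_b.eval()
    x = torch.randn(4, 200, 512)                # dummy neural features
    day_idx = torch.zeros(4, dtype=torch.long)  # all trials from day 0
    with torch.no_grad():
        outs_a = model_a(x, day_idx)
        outs_b = model_b(x, day_idx)
    return all(torch.allclose(a, b, atol=atol) for a, b in zip(outs_a, outs_b))
```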

### Usage Notes

- All original model interfaces remain unchanged
- Both 'inference' and 'full' modes are supported
- Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency

## Competition Context
This codebase also serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.