# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This repository contains the code and data for "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" published in the New England Journal of Medicine (2024). It implements a brain-to-text system that converts neural signals from speech motor cortex into text using RNN models and n-gram language models.

## Development Environment Setup

### Main Environment (b2txt25)
```bash
./setup.sh
conda activate b2txt25
```

### Language Model Environment (b2txt25_lm)
```bash
./setup_lm.sh
conda activate b2txt25_lm
```

**Important**: The project requires two separate conda environments because its components need conflicting PyTorch versions:
- `b2txt25`: PyTorch with CUDA 12.6 for model training/evaluation
- `b2txt25_lm`: PyTorch 1.13.1 for the Kaldi-based n-gram language models

### Redis Setup
Redis is required for inter-process communication. To install it on Ubuntu:
```bash
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update && sudo apt-get install redis
# Keep Redis from starting at boot; the evaluation pipeline starts redis-server manually
sudo systemctl disable redis-server
```

## Architecture Overview

### High-Level System Flow
1. **Neural Data Input**: 512 features (2 per electrode × 256 electrodes) binned at 20 ms resolution
2. **RNN Model**: Converts neural features to phoneme logits; trained with CTC loss
3. **Language Model**: Decodes phoneme logits to words using n-gram models with OPT rescoring
4. **Redis Communication**: Coordinates between the RNN inference and language model processes

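Because the RNN is trained with CTC loss, its per-step logits include a blank token, and a label may repeat across consecutive steps. The simplest way to read this kind of output is greedy (best-path) CTC decoding: take the arg-max at each step, merge consecutive repeats, then drop blanks. This repository does not decode greedily. It uses n-gram language models with OPT rescoring, so the sketch below only shows what the logits mean. The blank index of 0 and the toy class labels are assumptions.

```python
from itertools import groupby

BLANK = 0  # assumption: class index 0 is the CTC blank


def ctc_greedy_decode(logits, blank=BLANK):
    """Best-path CTC decoding of a [T][C] list of per-step scores.

    Returns class indices after merging consecutive repeats and removing blanks.
    """
    # Arg-max class at each output time step
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in logits]
    # Collapse runs of the same label, then drop the blank label
    return [label for label, _ in groupby(best_path) if label != blank]


# Toy example: 3 hypothetical classes (0 = blank, 1 = "AA", 2 = "B"), 6 time steps
toy = [
    [0.1, 0.8, 0.1],    # AA
    [0.1, 0.7, 0.2],    # AA again: merged with the previous step
    [0.9, 0.05, 0.05],  # blank
    [0.2, 0.1, 0.7],    # B
    [0.8, 0.1, 0.1],    # blank separates the two B's
    [0.1, 0.1, 0.8],    # B: kept as a second token
]
print(ctc_greedy_decode(toy))  # [1, 2, 2]
```

The blank between the two `B` steps is what lets CTC emit the same label twice in a row.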
### Key Components

#### Model Training (`model_training/`)
- **Core Script**: `train_model.py` (loads config from `rnn_args.yaml`)
- **Model Architecture**: `rnn_model.py` - 5-layer GRU with 768 hidden units
- **Trainer**: `rnn_trainer.py` - Custom PyTorch trainer with CTC loss
- **Evaluation**: `evaluate_model.py` - Inference pipeline with Redis communication

#### Language Model (`language_model/`)
- **Standalone Server**: `language-model-standalone.py` - Redis-based LM server
- **Kaldi Integration**: Uses custom C++ bindings for efficient n-gram decoding
- **OPT Rescoring**: Facebook OPT 6.7B for language model rescoring
- **Build System**: Complex CMake-based build for Kaldi/SRILM integration

#### Utilities (`nejm_b2txt_utils/`)
- **General Utils**: `general_utils.py` - Shared utility functions
- **Package**: Installed via `setup.py` as `nejm_b2txt_utils`

#### Analysis (`analyses/`)
- **Jupyter Notebooks**: `figure_2.ipynb`, `figure_4.ipynb` for paper figures

## Common Development Tasks

### Training a Model
```bash
conda activate b2txt25
cd model_training
python train_model.py
```

### Running Evaluation Pipeline
1. Start Redis server:
   ```bash
   redis-server
   ```

2. Start language model (separate terminal):
   ```bash
   conda activate b2txt25_lm
   python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
   ```

3. Run evaluation (separate terminal):
   ```bash
   conda activate b2txt25
   cd model_training
   python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/hdf5_data_final --eval_type test --gpu_number 1
   ```

4. Shut down Redis:
   ```bash
   redis-cli shutdown
   ```

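The `--nbest 100`, `--do_opt`, and `--alpha 0.55` flags in step 2 suggest that the language model server builds an n-best list of candidate sentences and then rescores it with OPT. The sketch below shows one common form of that step: the first-pass acoustic score plus a linear interpolation between n-gram and LLM log-probabilities, weighted by `alpha`. This file does not document the scoring formula used by `language-model-standalone.py`. The `Hypothesis` class, the `rescore_nbest` function, and the combination rule are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Hypothesis:
    text: str
    acoustic_score: float  # log-likelihood from the first-pass decode
    ngram_score: float     # log-probability under the n-gram LM
    llm_score: float       # log-probability under the rescoring LLM (e.g. OPT)


def rescore_nbest(hyps, alpha=0.55):
    """Sort an n-best list best-first under an ASSUMED interpolation rule.

    Hypothetical rule (not taken from language-model-standalone.py):
        score = acoustic + (1 - alpha) * ngram + alpha * llm
    alpha = 0 keeps only the n-gram score; alpha = 1 replaces it with the LLM score.
    """
    def combined(h):
        return h.acoustic_score + (1.0 - alpha) * h.ngram_score + alpha * h.llm_score

    return sorted(hyps, key=combined, reverse=True)


# Toy n-best list with made-up scores, for illustration only
nbest = [
    Hypothesis("eye want a drink", acoustic_score=-9.5, ngram_score=-26.0, llm_score=-30.0),
    Hypothesis("i want a drink", acoustic_score=-10.0, ngram_score=-20.0, llm_score=-15.0),
]
print(rescore_nbest(nbest)[0].text)  # i want a drink
```

Here the acoustically preferred homophone loses once the language model scores are added in. That kind of correction is what an LLM rescoring pass is meant to provide.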
### Building Language Model from Scratch
```bash
# Build SRILM (run from language_model/srilm-1.7.3/)
export SRILM=$PWD
make MAKE_PIC=yes World

# Build Kaldi components (run from language_model/runtime/server/x86/)
mkdir build && cd build
cmake .. && make -j8
```

## Data Structure

### Neural Data Format
- **File Type**: HDF5 files in `data/hdf5_data_final/`
- **Features**: 512 neural features per 20 ms bin (column indices are 0-based and inclusive):
  - 0-63: ventral 6v threshold crossings
  - 64-127: area 4 threshold crossings
  - 128-191: 55b threshold crossings
  - 192-255: dorsal 6v threshold crossings
  - 256-319: ventral 6v spike band power
  - 320-383: area 4 spike band power
  - 384-447: 55b spike band power
  - 448-511: dorsal 6v spike band power

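The layout above is two feature types × four arrays × 64 columns, with threshold crossings first and spike band power second. A small helper that maps a column index to its array and feature type can make slicing the HDF5 features less error-prone. This sketch is derived only from the ranges listed above. The constant and function names are hypothetical, and the order of electrodes within each 64-column block is not specified here, so the returned offset is just a position within the block.

```python
# Layout from the list above: 2 feature types x 4 arrays x 64 columns each
FEATURE_TYPES = ["threshold crossings", "spike band power"]
ARRAYS = ["ventral 6v", "area 4", "55b", "dorsal 6v"]
CHANNELS_PER_ARRAY = 64
N_FEATURES = len(FEATURE_TYPES) * len(ARRAYS) * CHANNELS_PER_ARRAY  # 512


def describe_feature(idx):
    """Map a 0-based feature column to (feature type, array, offset within the array's block)."""
    if not 0 <= idx < N_FEATURES:
        raise IndexError(f"feature index {idx} is outside 0..{N_FEATURES - 1}")
    type_idx, rem = divmod(idx, len(ARRAYS) * CHANNELS_PER_ARRAY)
    array_idx, offset = divmod(rem, CHANNELS_PER_ARRAY)
    return FEATURE_TYPES[type_idx], ARRAYS[array_idx], offset


def columns_for(feature_type, array):
    """Return the half-open column range [start, stop) for one feature type and array."""
    block = FEATURE_TYPES.index(feature_type) * len(ARRAYS) + ARRAYS.index(array)
    start = block * CHANNELS_PER_ARRAY
    return range(start, start + CHANNELS_PER_ARRAY)


print(describe_feature(300))                    # ('spike band power', 'ventral 6v', 44)
print(columns_for("spike band power", "55b"))  # range(384, 448)
```

With a NumPy feature matrix `x` of shape `[T, 512]`, `r = columns_for(...)` selects a block via `x[:, r.start:r.stop]`.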
### Data Loading
Use `load_h5py_file()` in `model_training/evaluate_model_helpers.py` as a reference for loading the HDF5 data.

## Important Notes

- **GPU Requirements**: OPT 6.7B requires ~12.4 GB VRAM; RTX 4090s recommended
- **Memory Requirements**: 3-gram LM needs ~60 GB RAM; 5-gram needs ~300 GB RAM
- **Environment Isolation**: Always activate the correct conda environment for each component
- **Redis Dependency**: Many scripts require a running Redis server
- **Build Dependencies**: CMake ≥3.14 and GCC ≥10.1 are required for language model builds

## XLA Optimizations (TPU-Friendly Model)

The RNN model has been optimized for XLA compilation and TPU training while preserving the original model architecture. These optimizations improve compilation speed and reduce memory usage on TPUs.

### Applied XLA Optimizations

#### 1. Dynamic Shape Operations → Static Operations
**Problem**: Iterating over `day_idx` in Python forces each day index to be read back to the host, and the traced graph then depends on the specific indices in each batch, which XLA handles poorly
**Solution**: Stack all day-specific parameters once and select the per-sample entries with a single `torch.index_select` gather

```python
# Before (XLA-unfriendly): Python loop over the per-sample day indices
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

# After (XLA-friendly): day_idx must be an integer (long) tensor of shape [B]
all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, D, K]
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, K]
day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [B, D, K], one gather
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [B, 1, K]
```

#### 2. Matrix Operations → XLA Primitives
**Problem**: `torch.einsum` may lower to less efficient XLA code than a direct call to the matching primitive
**Solution**: Use the equivalent batched matrix multiplication (`torch.bmm`) for better XLA performance

```python
# Before:
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases

# After (XLA-optimized): bmm plus explicit casts so all operands share x's dtype
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
```

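Both versions compute the same per-sample affine transform. Writing out the subscripts `btd,bdk->btk` shows that index $d$ is summed while index $b$ only selects the sample, which is exactly a batched matrix product. That makes `torch.bmm` a drop-in replacement. The bias broadcasts over time steps because `day_biases` has shape `[B, 1, K]`. The symbols are introduced here for illustration only: $X_b$ is sample $b$'s `[T, D]` feature matrix, $W_b$ its selected `[D, K]` day weights, and $\beta_b$ its day bias.

$$
y_{b,t,k} = \sum_{d=1}^{D} x_{b,t,d}\, W_{b,d,k} + \beta_{b,k}
\quad\Longleftrightarrow\quad
Y_b = X_b W_b + \mathbf{1}_T\, \beta_b^{\top}, \qquad b = 1, \dots, B
$$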
#### 3. Hidden State Initialization
**Problem**: XLA compiles a separate graph for every distinct input shape, so hidden states sized from a batch dimension that changes between steps trigger recompilation
**Solution**: Read the batch size once with `x.size(0)` and reuse it for all tensor creation. `x.size(0)` and `x.shape[0]` return the same value, so recompilation is avoided only if the batch size itself stays fixed across steps (for example, by dropping or padding the last partial batch)

```python
# Before:
if states is None:
    states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()

# After (XLA-friendly):
batch_size = x.size(0)  # Extract once and reuse
if states is None:
    states = self.h0.expand(2, batch_size, self.input_size).contiguous()
```

#### 4. Return Value Optimization
**Problem**: Dictionary return values can complicate XLA graph tracing
**Solution**: Return tuples instead of dictionaries for cleaner XLA graphs

```python
# Before (XLA-unfriendly):
return {
    'clean_logits': clean_logits,
    'noisy_logits': noisy_logits,
    'noise_output': noise_output
}

# After (XLA-friendly):
return clean_logits, noisy_logits, noise_output  # Simple tuple return
```

#### 5. Mixed Precision Dtype Consistency (Comprehensive Fix)
**Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline
**Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`

**Root Cause Analysis**: The mismatch occurred at dimension 7168 = 512 × 14, which points to the patch processing step with `patch_size=14`. The dtype mismatch cascaded through multiple layers:
1. Initial bmm operations in the day-specific transformations
2. Adversarial training residual connections between models
3. Patch processing operations (unfold, permute, reshape)
4. Gradient Reversal Layer (GRL) operations
5. Hidden state initialization in the adversarial training helper methods

**Comprehensive Solution**: Enforce dtype consistency across the entire adversarial training data flow:

```python
# Fix 1: Basic bmm operations with dtype consistency
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

# Fix 2: Patch processing with explicit dtype preservation
# Shapes: B = batch, T = time steps, D = features, N = number of patches, P = patch_size
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)                                          # [B, 1, T, D]
    x = x.permute(0, 3, 1, 2)                                   # [B, D, 1, T]
    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)  # [B, D, 1, N, P]
    x_unfold = x_unfold.squeeze(2)                              # [B, D, N, P]
    x_unfold = x_unfold.permute(0, 2, 3, 1)                     # [B, N, P, D]
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)      # [B, N, P * D]
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)

# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output

# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (GRL preserves input dtype; this is an explicit check)
noisy_input = noisy_input.to(x_processed.dtype)

# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
    states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)

# In _noisy_forward_with_processed_input:
if states is None:
    states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)
```

**Key Implementation Details**:
- **GradientReversalFn**: Preserves input dtype automatically (identity forward, gradient reversal backward)
- **Patch Processing**: An explicit cast back to the input dtype guards the unfold/permute/reshape sequence. These operations do not normally change dtype, so the cast is defensive
- **Residual Connections**: All tensor arithmetic operations use matching dtypes
- **Helper Methods**: Hidden state initialization matches the processed input dtype
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout

### Files Modified for XLA Optimization

- **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
  - **`GradientReversalFn`**: Added adversarial training gradient reversal layer
  - **`NoiseModel.forward()`**: Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
  - **`CleanSpeechModel.forward()`**: Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
  - **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
  - **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
  - **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations + dtype consistency + patch processing dtype preservation
  - **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision
  - **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden state dtype consistency for mixed precision

**Specific Dtype Consistency Fixes Applied**:
1. **Basic Operations**: All `torch.bmm()` operations with `.to(x.dtype)` conversions
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
3. **Adversarial Training**: Residual connections with `.to(x_processed.dtype)` conversions
4. **Gradient Reversal**: Dtype consistency after GRL operations
5. **Hidden States**: All hidden state initialization with `.to(x_processed.dtype)` conversions
6. **Data Flow**: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline

**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training

### Benefits of XLA Optimizations

1. **Faster Compilation**: Static shapes allow XLA to pre-compile optimized kernels
2. **Better Memory Usage**: Reduced dynamic allocation during training
3. **Improved TPU Utilization**: XLA primitives map directly to TPU matrix units
4. **Consistent Performance**: Eliminates recompilation caused by dynamic shapes

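Static shapes matter because XLA compiles one graph per distinct input shape. When sequence lengths vary from batch to batch, a common way to keep the number of shapes small is to pad each batch along the time axis to one of a few fixed bucket lengths, and keep the true lengths alongside it. This file does not say whether the training code pads this way. The sketch below is a generic illustration, and the function names and bucket sizes are hypothetical.

```python
import bisect


def bucket_length(seq_len, buckets):
    """Return the smallest bucket length >= seq_len (buckets sorted ascending)."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"length {seq_len} exceeds the largest bucket ({buckets[-1]})")
    return buckets[i]


def pad_batch(sequences, buckets, pad_value=0.0):
    """Pad non-empty [T][D] sequences to one shared bucketed length along time.

    Returns (padded, lengths). Every padded sequence has the same length, so
    batches whose longest trial is 130 or 200 steps both produce the same
    shape when 256 is the covering bucket.
    """
    lengths = [len(seq) for seq in sequences]
    target = bucket_length(max(lengths), buckets)
    dim = len(sequences[0][0])
    padded = [list(seq) + [[pad_value] * dim for _ in range(target - len(seq))]
              for seq in sequences]
    return padded, lengths


# Hypothetical bucket sizes, chosen only for illustration
batch = [[[1.0, 2.0]] * 130, [[3.0, 4.0]] * 90]
padded, lengths = pad_batch(batch, buckets=[128, 256, 512])
print(len(padded[0]), len(padded[1]), lengths)  # 256 256 [130, 90]
```

The returned `lengths` are what a CTC loss or an attention/recurrence mask would use to ignore the padded steps.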
### Testing and Validation

Created test scripts to verify model consistency:
- **`test_xla_model.py`**: Comprehensive model validation testing
- **`quick_test_xla.py`**: Fast verification of basic functionality

**Important**: These optimizations preserve the exact model architecture and mathematical operations. Only the implementation has been made XLA-friendly.

### Usage Notes

- All original model interfaces remain unchanged
- Both 'inference' and 'full' modes are supported
- Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency

### Troubleshooting Dtype Issues in Mixed Precision Training

**Common Error Pattern**:
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
```

**Diagnosis Steps**:
1. **Identify Operation**: Look at the tensor dimensions to identify which operation is failing
   - `7168 = 512 * 14`: Patch processing operation with patch_size=14
   - `512`: Basic neural features
   - Other patterns may indicate different operations

2. **Check Data Flow**: Trace the tensor through the adversarial training pipeline
   - Input → NoiseModel → residual connection → CleanSpeechModel
   - Input → NoiseModel → GRL → NoisySpeechModel

3. **Verify Dtype Consistency**: Ensure all operations maintain input dtype
   - Use `.to(x.dtype)` for all operand tensors
   - Preserve dtype through complex operations (unfold, permute, reshape)
   - Match hidden state dtype to input tensor dtype

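The `7168 = 512 * 14` reading follows from the shape arithmetic of the patch processing code. Each patch concatenates `patch_size` consecutive 512-feature frames, so the feature dimension becomes `512 * patch_size`, and the number of patches is `(T - patch_size) // patch_stride + 1`. The pure-Python sketch below mirrors the time-major ordering produced by the `unsqueeze → permute → unfold → squeeze → permute → reshape` sequence, so the numbers in an error message can be checked without a TPU. The function names are hypothetical, and the stride of 4 in the example is an assumption. Only `patch_size=14` comes from the error above.

```python
def patchify(frames, patch_size, patch_stride):
    """Split a [T][D] sequence into patches, each flattened to patch_size * D values.

    Patch i concatenates frames[i*stride], frames[i*stride + 1], ... in time order,
    matching the [B, N, P, D] -> [B, N, P * D] reshape in the model code.
    Returns [] when the sequence is shorter than one patch.
    """
    n_patches = (len(frames) - patch_size) // patch_stride + 1
    return [
        [value
         for frame in frames[i * patch_stride : i * patch_stride + patch_size]
         for value in frame]
        for i in range(n_patches)
    ]


def explain_dim(dim, n_features=512):
    """Return the implied patch_size if dim is a multiple of n_features, else None."""
    patch_size, remainder = divmod(dim, n_features)
    return patch_size if remainder == 0 else None


frames = [[float(t)] * 512 for t in range(30)]  # 30 time steps x 512 features
patches = patchify(frames, patch_size=14, patch_stride=4)
print(len(patches), len(patches[0]))  # 5 7168
print(explain_dim(7168))              # 14
```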
**Quick Fix Template**:
```python
# `operation` and `complex_operation` are placeholders for your own code.

# For any tensor operation between tensors a and b:
result = operation(a, b.to(a.dtype))

# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)

# For hidden state initialization:
states = states.to(input_tensor.dtype)
```

## Competition Context
This codebase also serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.
