TPU-Optimized Brain-to-Text Model Training
This directory contains TPU-optimized code for training the brain-to-text RNN model with advanced adversarial training architecture. The model is based on "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" by Card et al. (2024), enhanced with three-model adversarial training and comprehensive XLA optimizations for efficient TPU training.
Key Features
- Triple-Model Adversarial Architecture: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
- XLA/TPU Optimizations: Comprehensive optimizations for fast compilation and efficient TPU utilization
- Mixed Precision Training: bfloat16 support with full dtype consistency
- Distributed Training: 8-core TPU support with Accelerate library integration
- 687M Parameters: Large-scale model with patch processing and day-specific adaptations
For detailed technical documentation, see TPU_MODEL_SUMMARY.md.
Setup
- Install the required b2txt25 conda environment by following the instructions in the root README.md file. This will set up the necessary dependencies for running the model training and evaluation code.
- Download the dataset from Dryad (Dryad Dataset). Place the downloaded data in the data directory. See the main README.md file for more details on the included datasets and the proper data directory structure.
TPU Training
Triple-Model Adversarial Architecture
This implementation features an advanced three-model adversarial training system:
- NoiseModel: 2-layer GRU that estimates noise in neural data
- CleanSpeechModel: 3-layer GRU that processes denoised signals for speech recognition
- NoisySpeechModel: 2-layer GRU that processes noise signals for adversarial training
The architecture uses residual connections and gradient reversal layers (GRL) to improve robustness. All three models share day-specific input layers (512x512 linear with softsign activation) and patch processing (14 timesteps), and all are optimized for XLA compilation on TPU; a sketch of the GRL wiring follows below.
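The gradient reversal trick is easy to misread from prose alone, so here is a minimal sketch of a GRL and of how the three branches could be wired together. The class name, the lambda value, and the branch wiring are illustrative assumptions, not the repository's exact code.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies gradients by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip (and scale) the gradient flowing back into the noise estimator
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=0.5):
    return GradReverse.apply(x, lambd)

# Conceptual wiring of the three branches (hypothetical names):
#   noise     = noise_model(x)                           # 2-layer GRU noise estimate
#   clean_out = clean_speech_model(x - noise)            # residual connection: denoised input
#   noisy_out = noisy_speech_model(grad_reverse(noise))  # adversarial branch through the GRL
```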
Training Methods
Option 1: Direct Training
```bash
conda activate b2txt25
python train_model.py --config_path rnn_args.yaml
```
Option 2: Launcher Script (Recommended)
```bash
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
```
Option 3: Accelerate
```bash
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
```
The model trains for 120,000 mini-batches with mixed precision (bfloat16) and distributed training across 8 TPU cores. Expected training time varies based on TPU type and configuration. All hyperparameters are specified in rnn_args.yaml.
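As a rough illustration of how those settings fit together, below is a minimal Accelerate-based training step with bf16 mixed precision and 2-step gradient accumulation. The model, dataset, and loss here are throwaway placeholders standing in for the TripleGRUDecoder and its CTC loss, not the repository's actual training loop.

```python
import torch
from accelerate import Accelerator

# Placeholder model, optimizer, and data; the real run builds TripleGRUDecoder from rnn_args.yaml.
model = torch.nn.GRU(input_size=512, hidden_size=512, num_layers=3, batch_first=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 140, 512))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=0)

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

model.train()
for step, (batch,) in enumerate(train_loader):
    with accelerator.accumulate(model):   # handles the 2-step gradient accumulation
        output, _ = model(batch)
        loss = output.pow(2).mean()       # stand-in for the real CTC loss
        accelerator.backward(loss)        # use instead of loss.backward() on XLA/TPU
        optimizer.step()
        optimizer.zero_grad()
    if step + 1 >= 120_000:               # the full run trains for 120,000 mini-batches
        break
```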
Model Configuration
Key Configuration Files
- rnn_args.yaml: Main training configuration, including adversarial training settings
- accelerate_config_tpu.yaml: Accelerate library configuration for TPU
- launch_tpu_training.py: Convenient TPU training launcher
Adversarial Training Settings
```yaml
adversarial:
  enabled: true
  grl_lambda: 0.5        # Gradient Reversal Layer strength
  noisy_loss_weight: 0.2 # Weight for noisy branch CTC loss
  noise_l2_weight: 0.0   # L2 regularization on noise output
  warmup_steps: 0        # Steps before enabling adversarial training
```
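For reference, a plausible way these weights enter the training objective is sketched below. This is an assumption about the loss composition, not the repository's exact formula; note that grl_lambda acts inside the gradient reversal layer rather than as a loss weight.

```python
# Hypothetical combination of the loss terms controlled by the settings above.
NOISY_LOSS_WEIGHT = 0.2
NOISE_L2_WEIGHT = 0.0
WARMUP_STEPS = 0

def combined_loss(clean_ctc, noisy_ctc, noise_output, step):
    loss = clean_ctc                                   # CTC loss on the denoised branch
    if step >= WARMUP_STEPS:                           # adversarial terms only after warmup
        loss = loss + NOISY_LOSS_WEIGHT * noisy_ctc    # CTC loss on the adversarial branch
        loss = loss + NOISE_L2_WEIGHT * noise_output.pow(2).mean()  # optional L2 on the noise estimate
    return loss
```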
TPU-Specific Settings
```yaml
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true                  # bfloat16 mixed precision
batch_size: 32                 # Per-core batch size
num_dataloader_workers: 0      # Required for TPU
```
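With these settings, each optimizer update sees 32 sequences per core × 8 cores × 2 accumulation steps = 512 sequences in total. num_dataloader_workers stays at 0 because multi-process data loading generally conflicts with the process-per-core model used for TPU training.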
Evaluation
Model evaluation using the trained TripleGRUDecoder requires the language model pipeline. Please refer to the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are currently being adapted for TPU compatibility.