TPU-Optimized Brain-to-Text Model Training

This directory contains TPU-optimized code for training the brain-to-text RNN model with advanced adversarial training architecture. The model is based on "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" by Card et al. (2024), enhanced with three-model adversarial training and comprehensive XLA optimizations for efficient TPU training.

Key Features

  • Triple-Model Adversarial Architecture: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
  • XLA/TPU Optimizations: Comprehensive optimizations for fast compilation and efficient TPU utilization
  • Mixed Precision Training: bfloat16 support with full dtype consistency
  • Distributed Training: 8-core TPU support with Accelerate library integration
  • 687M Parameters: Large-scale model with patch processing and day-specific adaptations

For detailed technical documentation, see TPU_MODEL_SUMMARY.md.

Setup

  1. Set up the b2txt25 conda environment by following the instructions in the root README.md file; this installs the dependencies needed to run the model training and evaluation code.

  2. Download the dataset from Dryad (Dryad Dataset) and place the downloaded data in the data directory. See the main README.md file for details on the included datasets and the expected data directory structure.

TPU Training

Triple-Model Adversarial Architecture

This implementation features an advanced three-model adversarial training system:

  • NoiseModel: 2-layer GRU that estimates noise in neural data
  • CleanSpeechModel: 3-layer GRU that processes denoised signals for speech recognition
  • NoisySpeechModel: 2-layer GRU that processes noise signals for adversarial training

The architecture uses residual connections and gradient reversal layers (GRL) to improve robustness. All models include day-specific input layers (512x512 linear with softsign activation), patch processing (14 timesteps), and are optimized for XLA compilation on TPU.
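
The gradient reversal layer is the standard construction from domain-adversarial training: an identity map in the forward pass whose gradient is negated and scaled in the backward pass. The sketch below shows a minimal PyTorch version together with the data flow the description above implies; the subtraction-based residual connection, the non-overlapping patching, and the function names (make_patches, adversarial_forward) are illustrative assumptions, not the repository's actual code.

import torch
from torch.autograd import Function

class GradientReversal(Function):
    # Identity in the forward pass; negated, scaled gradient in the backward pass.
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_out):
        return -ctx.lambd * grad_out, None

def make_patches(x, patch_size=14):
    # Assumption: non-overlapping patches over a (batch, time, channels) input.
    return x.unfold(1, patch_size, patch_size).flatten(2)

def adversarial_forward(x, noise_model, clean_model, noisy_model, grl_lambda=0.5):
    noise = noise_model(x)                        # 2-layer GRU noise estimate
    clean_logits = clean_model(x - noise)         # residual connection: denoised input
    reversed_noise = GradientReversal.apply(noise, grl_lambda)
    noisy_logits = noisy_model(reversed_noise)    # adversarial branch through the GRL
    return clean_logits, noisy_logits, noise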

Training Methods

Option 1: Direct Training

conda activate b2txt25
python train_model.py --config_path rnn_args.yaml

Option 2: TPU Training Launcher

python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Option 3: Accelerate

accelerate launch --config_file accelerate_config_tpu.yaml train_model.py

The model trains for 120,000 mini-batches with mixed precision (bfloat16) and distributed training across 8 TPU cores. Expected training time varies based on TPU type and configuration. All hyperparameters are specified in rnn_args.yaml.
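
In Accelerate terms, the distributed bfloat16 setup described above reduces to something like the following sketch (model, optimizer, and dataloader construction are elided, and compute_loss is a hypothetical stand-in for the CTC loss computation):

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

model.train()
for step, batch in enumerate(train_loader):
    with accelerator.accumulate(model):    # transparent gradient accumulation
        loss = compute_loss(model, batch)  # hypothetical CTC loss helper
        accelerator.backward(loss)         # mixed-precision-aware backward
        optimizer.step()
        optimizer.zero_grad()
    if step + 1 >= 120_000:                # train for 120,000 mini-batches
        break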

Model Configuration

Key Configuration Files

  • rnn_args.yaml: Main training configuration with adversarial training settings
  • accelerate_config_tpu.yaml: Accelerate library configuration for TPU
  • launch_tpu_training.py: Convenient TPU training launcher

Adversarial Training Settings

adversarial:
  enabled: true
  grl_lambda: 0.5        # Gradient Reversal Layer strength
  noisy_loss_weight: 0.2 # Weight for noisy branch CTC loss
  noise_l2_weight: 0.0   # L2 regularization on noise output
  warmup_steps: 0        # Steps before enabling adversarial training
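
Read together, these settings imply a combined objective along the following lines (a sketch; the variable names are hypothetical, and clean_ctc and noisy_ctc denote the CTC losses of the two speech branches):

# cfg mirrors the adversarial block above
grl_lambda = cfg.grl_lambda if step >= cfg.warmup_steps else 0.0
loss = clean_ctc + cfg.noisy_loss_weight * noisy_ctc
if cfg.noise_l2_weight > 0:
    loss = loss + cfg.noise_l2_weight * noise.pow(2).mean()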

TPU-Specific Settings

use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true  # bfloat16 mixed precision
batch_size: 32  # Per-core batch size
num_dataloader_workers: 0  # Required for TPU
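
With these defaults, the effective global batch size is 32 sequences per core × 8 cores × 2 gradient-accumulation steps = 512 sequences per optimizer step.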

Evaluation

Model evaluation using the trained TripleGRUDecoder requires the language model pipeline. Please refer to the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are currently being adapted for TPU compatibility.