	TensorFlow Brain-to-Text Model for TPU v5e-8
This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides the same functionality as the PyTorch version, using TensorFlow operations chosen to run efficiently on TPUs.
Architecture Overview
The TensorFlow implementation keeps the same three-model adversarial architecture as the PyTorch version:
Core Models
- NoiseModel: 2-layer GRU that estimates noise in neural data
- CleanSpeechModel: 3-layer GRU that processes denoised signal for speech recognition
- NoisySpeechModel: 2-layer GRU that processes noise signal for adversarial training
Key Features
- Day-specific transformations: Learnable input layers for each recording session
- Patch processing: Temporal patching for improved sequence modeling
- Gradient Reversal Layer: For adversarial training between noise and speech models
- Mixed precision: bfloat16 optimization for TPU v5e-8 memory efficiency
- CTC Loss: Connectionist Temporal Classification for sequence alignment
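A gradient reversal layer is the identity in the forward pass and multiplies the incoming gradient by -λ in the backward pass, so the layers upstream of it are trained against the downstream objective. A minimal sketch with `tf.custom_gradient` follows; the function name `gradient_reversal` is hypothetical and this is not necessarily how `rnn_model_tf.py` implements it. The `grl_lambda` argument plays the role of the parameter of the same name in the forward-pass example further down.

```python
import tensorflow as tf


def gradient_reversal(x, grl_lambda=1.0):
    """Identity in the forward pass; multiplies the gradient by -grl_lambda."""

    @tf.custom_gradient
    def _reverse(inp):
        def grad(upstream):
            # Flip and scale the gradient flowing back to the upstream layers
            return -grl_lambda * upstream

        return tf.identity(inp), grad

    return _reverse(x)


x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.reduce_sum(gradient_reversal(x, grl_lambda=0.5) * 2.0)

grad = tape.gradient(y, x)
print(grad.numpy())  # [-1. -1. -1.]
```

Without the reversal the gradient would be +2 per element; the layer turns it into -0.5 × 2 = -1, while the forward value `y` is unchanged.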
Files Overview
Core Implementation
- rnn_model_tf.py: TensorFlow model architecture with TPU optimizations
- trainer_tf.py: Training pipeline with distributed TPU strategy
- dataset_tf.py: Data loading and augmentation optimized for TPU
- train_model_tf.py: Main training script
- evaluate_model_tf.py: Evaluation and inference script
Configuration and Setup
- rnn_args.yaml: Training configuration (shared with PyTorch version)
- setup_tensorflow_tpu.sh: Environment setup script
- requirements_tf.txt: Python dependencies
- README_TensorFlow.md: This documentation
Quick Start
1. Environment Setup
# Run the setup script to configure TPU environment
./setup_tensorflow_tpu.sh
# Activate the conda environment
conda activate b2txt_tf
2. Verify TPU Access
import tensorflow as tf
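# Check TPU availability (on a Cloud TPU VM, pass tpu='local' to
# TPUClusterResolver if the default lookup does not find the TPU)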
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print(f"TPU cores available: {strategy.num_replicas_in_sync}")
3. Start Training
# Basic training with default config
python train_model_tf.py --config_path rnn_args.yaml
# Training with custom settings
python train_model_tf.py \
    --config_path rnn_args.yaml \
    --batch_size 64 \
    --num_batches 50000 \
    --output_dir ./trained_models/custom_run
4. Run Evaluation
# Evaluate trained model
python evaluate_model_tf.py \
    --model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
    --config_path rnn_args.yaml \
    --eval_type test
TPU v5e-8 Optimizations
Hardware-Specific Features
- Mixed Precision: bfloat16 computation, which halves activation memory relative to float32
- XLA Compilation: Training steps are compiled with XLA, which is how TensorFlow executes on TPUs
- Distributed Strategy: TPUStrategy replicates the model on all 8 TPU cores and splits each batch across them
- Memory Management: Efficient tensor operations to avoid OOM errors
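One common way to get bfloat16 mixed precision in Keras is the global `mixed_bfloat16` policy: layers compute in bfloat16 while their variables stay in float32. The sketch below assumes that `use_amp: true` maps onto this policy, which may differ from what `trainer_tf.py` actually does; `USE_AMP` and the layer sizes are illustrative.

```python
import tensorflow as tf

USE_AMP = True  # stands in for `use_amp` in rnn_args.yaml (assumed mapping)

if USE_AMP:
    # Layers created after this call compute in bfloat16 but keep float32 variables.
    tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

hidden = tf.keras.layers.Dense(16, activation="relu")
# Keep the output layer in float32 so the logits fed to the loss stay full precision.
output = tf.keras.layers.Dense(8, dtype="float32")

x = tf.random.normal((2, 32))
h = hidden(x)        # bfloat16 activations
logits = output(h)   # inputs are cast back to float32 by the float32 layer
```

Keeping the final layer in float32 is a common precaution so that numerically sensitive losses such as CTC see full-precision logits.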
Performance Optimizations
- Batch Matrix Operations: tf.linalg.matmul instead of element-wise operations
- Static Shapes: Avoiding dynamic tensor shapes for better compilation
- Efficient Gathering: tf.gather for day-specific parameter selection
- Gradient Reversal: Custom gradient function for adversarial training
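The day-specific input layer can be expressed with `tf.gather` and a single batched `tf.linalg.matmul`: each trial selects its session's weight matrix and bias by day index. In this sketch the function name, the identity initialisation and the softsign activation are assumptions for illustration, not details confirmed by this README.

```python
import tensorflow as tf


def day_specific_transform(features, day_indices, day_weights, day_biases):
    """Apply a per-session affine map, selecting parameters with tf.gather.

    features:    [batch, time, n_features]
    day_indices: [batch] int32 session index for each trial
    day_weights: [n_days, n_features, n_features]
    day_biases:  [n_days, n_features]
    """
    w = tf.gather(day_weights, day_indices)             # [batch, n_features, n_features]
    b = tf.gather(day_biases, day_indices)[:, None, :]  # [batch, 1, n_features]
    x = tf.linalg.matmul(features, w) + b               # one batched matmul for all trials
    return tf.nn.softsign(x)                            # activation choice is an assumption


n_days, n_features = 3, 4
# Identity initialisation (assumed): every session starts as a pass-through map.
day_weights = tf.Variable(tf.eye(n_features, batch_shape=[n_days]))
day_biases = tf.Variable(tf.zeros([n_days, n_features]))

features = tf.random.normal([2, 5, n_features])
day_indices = tf.constant([0, 2])
out = day_specific_transform(features, day_indices, day_weights, day_biases)
```

Because the gather and matmul have fixed shapes for a given batch size, XLA can compile them once regardless of which sessions appear in the batch.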
Configuration
The model uses the same rnn_args.yaml configuration as the PyTorch version. Key TPU-specific settings:
# TPU-specific settings
use_amp: true                    # Enable mixed precision (bfloat16)
dataset:
  batch_size: 32                # Optimized for TPU memory
  num_dataloader_workers: 0     # Disable multiprocessing on TPU
# Model architecture (same as PyTorch)
model:
  n_input_features: 512         # Neural features per timestep
  n_units: 768                  # Hidden units per GRU layer
  patch_size: 14                # Temporal patch size
  patch_stride: 4               # Patch stride
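With `patch_size: 14` and `patch_stride: 4`, temporal patching can be sketched with `tf.signal.frame`: each output step concatenates 14 consecutive input frames, and consecutive patches start 4 frames apart, so a trial of T frames yields (T - 14) // 4 + 1 patches. The helper names are hypothetical, and flattening each patch frame by frame (all features of the first frame, then the second, and so on) is an assumption about the layout.

```python
import tensorflow as tf


def make_patches(features, patch_size=14, patch_stride=4):
    """Slice [batch, time, n_features] into overlapping temporal patches.

    Returns [batch, n_patches, patch_size * n_features] with
    n_patches = (time - patch_size) // patch_stride + 1 (trailing frames dropped).
    The feature dimension must be static.
    """
    n_features = features.shape[-1]
    # Framing the time axis gives [batch, n_patches, patch_size, n_features]
    patches = tf.signal.frame(features, patch_size, patch_stride, axis=1)
    batch = tf.shape(patches)[0]
    n_patches = tf.shape(patches)[1]
    return tf.reshape(patches, [batch, n_patches, patch_size * n_features])


def patched_lengths(n_time_steps, patch_size=14, patch_stride=4):
    """Number of patches per trial (the sequence length the GRUs see)."""
    return (n_time_steps - patch_size) // patch_stride + 1


x = tf.random.normal([2, 30, 512])
p = make_patches(x)
print(p.shape)  # (2, 5, 7168)
```

In this sketch `patched_lengths` also gives the input lengths for the CTC loss, since the decoder emits one output per patch; trials shorter than 14 frames would need padding or filtering first.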
Performance Comparison
TPU v5e-8 vs Other Hardware
- Memory: bfloat16 mixed precision halves activation memory relative to float32
- Throughput: ~3-4x faster training than V100 GPU
- Scalability: Automatic distribution across 8 cores
- Cost Efficiency: Better performance-per-dollar for large models
Expected Training Times (120k batches)
- TPU v5e-8: ~4-6 hours
- Single V100: ~15-20 hours
- RTX 4090: ~12-18 hours
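As a rough consistency check, the estimates above translate into per-batch times and a speed-up ratio. The calculation uses only the numbers listed in this section.

```python
# Per-batch time implied by the training-time estimates (120k batches).
N_BATCHES = 120_000
estimates_hours = {
    "TPU v5e-8": (4, 6),
    "Single V100": (15, 20),
    "RTX 4090": (12, 18),
}


def seconds_per_batch(hours, n_batches=N_BATCHES):
    return hours * 3600 / n_batches


for name, (low, high) in estimates_hours.items():
    print(f"{name}: {seconds_per_batch(low):.3f}-{seconds_per_batch(high):.3f} s/batch")

# Speed-up of the TPU over the V100, using the midpoint of each range
tpu_mid = sum(estimates_hours["TPU v5e-8"]) / 2     # 5.0 hours
v100_mid = sum(estimates_hours["Single V100"]) / 2  # 17.5 hours
print(f"midpoint speed-up: {v100_mid / tpu_mid:.1f}x")  # 3.5x
```

The TPU estimate works out to about 0.12-0.18 s per batch, and the midpoint ratio of 3.5x falls within the ~3-4x throughput figure quoted above.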
Model Architecture Details
TripleGRUDecoder Forward Pass
# Training mode (adversarial)
clean_logits, noisy_logits, noise_output = model(
    features, day_indices, mode='full',
    grl_lambda=0.5, training=True
)
# Inference mode (production)
clean_logits = model(
    features, day_indices, mode='inference',
    training=False
)
Loss Functions
# Clean speech CTC loss
clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)
# Adversarial noisy speech loss: gradient reversal is applied inside the
# model's forward pass (grl_lambda), so this is an ordinary CTC loss
noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)
# Combined loss (noise_l2_loss: L2 penalty on the noise model's output)
total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
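`ctc_loss` is not defined in this README. One way to write it on top of `tf.nn.ctc_loss` is sketched below, assuming batch-major logits, class 0 as the CTC blank (the PyTorch `nn.CTCLoss` default) and a mean over the batch as the reduction. `noise_l2` is a hypothetical stand-in for `noise_l2_loss`, taken here as the mean squared value of `noise_output`.

```python
import tensorflow as tf

BLANK_INDEX = 0  # assumption: blank is class 0, as in PyTorch's nn.CTCLoss default


def ctc_loss(logits, labels, input_lengths, label_lengths):
    """Mean CTC loss for batch-major logits.

    logits:        [batch, time, n_classes] unnormalised scores
    labels:        [batch, max_label_len] int, padding past label_lengths is ignored
    input_lengths: [batch] valid logit steps per trial (after patching)
    label_lengths: [batch] valid labels per trial
    """
    per_example = tf.nn.ctc_loss(
        labels=tf.cast(labels, tf.int32),
        logits=tf.cast(logits, tf.float32),  # compute the loss in float32
        label_length=tf.cast(label_lengths, tf.int32),
        logit_length=tf.cast(input_lengths, tf.int32),
        logits_time_major=False,
        blank_index=BLANK_INDEX,
    )
    return tf.reduce_mean(per_example)


def noise_l2(noise_output):
    """Hypothetical L2 penalty: mean of squared noise-model outputs."""
    return tf.reduce_mean(tf.square(noise_output))


logits = tf.random.normal([2, 20, 5])
labels = tf.constant([[1, 2, 3], [4, 4, 0]])
loss = ctc_loss(logits, labels, tf.constant([20, 15]), tf.constant([3, 2]))
```

Casting to float32 inside the helper keeps the loss numerically stable even if the logits come out of a bfloat16 layer.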
Data Pipeline
HDF5 Data Loading
The TensorFlow implementation efficiently loads data from HDF5 files:
- Batch creation: Pre-batched data with padding
- Feature subsets: Configurable neural feature selection
- Day balancing: Ensures even representation across recording sessions
- Memory efficiency: Lazy loading with tf.data.Dataset
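A minimal `tf.data` pipeline with padding might look like the sketch below. The generator is a hypothetical stand-in for reading trials from HDF5, and `MAX_TIME_STEPS` and `MAX_LABEL_LEN` are illustrative values. Padding to fixed shapes with `drop_remainder=True` gives the static batch shapes that XLA compiles best on TPU.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 512      # matches model.n_input_features in rnn_args.yaml
MAX_TIME_STEPS = 64   # illustrative fixed padding lengths
MAX_LABEL_LEN = 16


def trial_generator(n_trials=6, seed=0):
    """Hypothetical stand-in for reading trials from an HDF5 file."""
    rng = np.random.default_rng(seed)
    for i in range(n_trials):
        n_steps = int(rng.integers(20, 40))
        n_labels = int(rng.integers(3, 8))
        yield {
            "features": rng.standard_normal((n_steps, N_FEATURES)).astype(np.float32),
            "labels": rng.integers(1, 10, size=n_labels).astype(np.int32),  # 0 left for blank
            "n_time_steps": np.int32(n_steps),
            "label_length": np.int32(n_labels),
            "day_index": np.int32(i % 3),
        }


signature = {
    "features": tf.TensorSpec([None, N_FEATURES], tf.float32),
    "labels": tf.TensorSpec([None], tf.int32),
    "n_time_steps": tf.TensorSpec([], tf.int32),
    "label_length": tf.TensorSpec([], tf.int32),
    "day_index": tf.TensorSpec([], tf.int32),
}

dataset = (
    tf.data.Dataset.from_generator(trial_generator, output_signature=signature)
    .padded_batch(
        4,
        padded_shapes={
            "features": [MAX_TIME_STEPS, N_FEATURES],  # zero-pad time to a fixed length
            "labels": [MAX_LABEL_LEN],
            "n_time_steps": [],
            "label_length": [],
            "day_index": [],
        },
        drop_remainder=True,  # fixed batch size -> fully static shapes
    )
    .prefetch(tf.data.AUTOTUNE)
)

for batch in dataset.take(1):
    print({name: tensor.shape for name, tensor in batch.items()})
```

The true lengths travel alongside the padded tensors, so the loss can ignore the zero padding.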
Data Augmentations
- Gaussian smoothing: Temporal smoothing of neural signals
- White noise: Additive Gaussian noise for robustness
- Static gain: Channel-wise multiplicative noise
- Random walk: Temporal drift simulation
- Random cutoff: Variable sequence lengths
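The augmentations listed above can be sketched with plain TensorFlow ops. Every function name and standard deviation here is a hypothetical placeholder rather than a value from `rnn_args.yaml`; static gain is modelled as a per-channel scale factor as described above, and applying smoothing after the noise terms is an assumption. Random cutoff is omitted because it changes sequence lengths and is simpler to apply before batching.

```python
import tensorflow as tf


def gaussian_smooth(x, std=2.0, kernel_size=21):
    """Smooth every channel of x: [batch, time, channels] along time."""
    t = tf.range(kernel_size, dtype=tf.float32) - (kernel_size - 1) / 2
    kernel = tf.exp(-0.5 * tf.square(t / std))
    kernel = kernel / tf.reduce_sum(kernel)  # unit gain for constant signals
    b, n_t, n_ch = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
    # Fold channels into the batch so one 1-channel conv1d filters all of them
    flat = tf.reshape(tf.transpose(x, [0, 2, 1]), [-1, n_t, 1])
    y = tf.nn.conv1d(flat, tf.reshape(kernel, [kernel_size, 1, 1]), stride=1, padding="SAME")
    return tf.transpose(tf.reshape(y, [b, n_ch, n_t]), [0, 2, 1])


def augment(x, white_noise_std=0.5, static_gain_std=0.05, random_walk_std=0.01):
    """Add multiplicative, white and drift noise to x: [batch, time, channels]."""
    shape = tf.shape(x)
    # Static gain: one scale factor per trial and channel, constant over time
    gain = 1.0 + tf.random.normal([shape[0], 1, shape[2]], stddev=static_gain_std)
    x = x * gain
    # White noise: independent Gaussian noise on every sample
    x = x + tf.random.normal(shape, stddev=white_noise_std)
    # Random walk: cumulative sum of small steps along time simulates slow drift
    x = x + tf.cumsum(tf.random.normal(shape, stddev=random_walk_std), axis=1)
    return x


x = tf.random.normal([2, 50, 8])
x_aug = gaussian_smooth(augment(x))
```

The `SAME` padding fills the edges with zeros, so smoothed values near the start and end of a trial are pulled toward zero.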
Troubleshooting
Common TPU Issues
"Resource exhausted" errors
# Reduce batch size
python train_model_tf.py --batch_size 16
# Enable gradient accumulation
# Modify config: gradient_accumulation_steps: 4
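If your setup lacks a gradient accumulation option, the idea is easy to write by hand: run several micro-batches, sum their gradients and apply the average once. The sketch below is single-device and uses a toy model; `accumulated_gradients` is a hypothetical helper, not part of `trainer_tf.py`. With a mean-reduced loss and equal micro-batch sizes, it reproduces the full-batch gradient while only one micro-batch's activations are held in memory at a time.

```python
import tensorflow as tf


def accumulated_gradients(model, loss_fn, micro_batches):
    """Average gradients over equally sized micro-batches."""
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for inputs, targets in micro_batches:
        with tf.GradientTape() as tape:
            loss = loss_fn(targets, model(inputs, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    return [a / len(micro_batches) for a in accum]


model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

x = tf.random.normal([16, 3])
y = tf.random.normal([16, 1])
micro_batches = [(x[i:i + 4], y[i:i + 4]) for i in range(0, 16, 4)]  # 4 steps of 4

grads = accumulated_gradients(model, loss_fn, micro_batches)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```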
TPU not detected
# Check environment variables
echo $TPU_NAME
echo $COLAB_TPU_ADDR
# Verify TPU access
gcloud compute tpus tpu-vm list --zone=<your-zone>
Mixed precision issues
# Disable mixed precision if needed
python train_model_tf.py --disable_mixed_precision
Performance Debugging
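# Enable XLA auto-clustering; set these flags before importing TensorFlow.
# (Code run under TPUStrategy is always XLA-compiled; this mainly affects CPU/GPU.)
import os
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
import tensorflow as tf
# Profile TPU usage
tf.profiler.experimental.start('logdir')
# ... training code ...
tf.profiler.experimental.stop()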
Advanced Usage
Custom Training Loop
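from omegaconf import OmegaConf  # assumption: config loaded with OmegaConf (attribute access)
from trainer_tf import BrainToTextDecoderTrainerTF
# Load the shared configuration
config = OmegaConf.load('rnn_args.yaml')
# Initialize trainer
trainer = BrainToTextDecoderTrainerTF(config)
# Custom training with checkpointing (num_epochs is an example value)
num_epochs = 20
for epoch in range(num_epochs):
    stats = trainer.train()
    if epoch % 5 == 0:
        trainer._save_checkpoint(f'epoch_{epoch}', epoch)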
Model Inference
# Load trained model
model = trainer.model
model.load_weights('path/to/checkpoint.weights.h5')
# Run inference
logits = trainer.inference(features, day_indices, n_time_steps)
# Decode predictions
predictions = tf.argmax(logits, axis=-1)
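`tf.argmax` yields one label per output step, including CTC blanks and repeated labels. Greedy CTC decoding additionally merges consecutive repeats and then drops blanks. The sketch assumes logits of shape [batch, time, classes], blank index 0 and eager execution; `greedy_ctc_decode` is a hypothetical helper. TensorFlow's built-in `tf.nn.ctc_greedy_decoder` does the same job but expects time-major input and, unless told otherwise, treats the last class as the blank.

```python
import numpy as np
import tensorflow as tf

BLANK_ID = 0  # assumption: CTC blank is class 0


def greedy_ctc_decode(logits, lengths, blank_id=BLANK_ID):
    """Frame-wise argmax, merge repeated labels, then drop blanks.

    logits: [batch, time, classes]; lengths: [batch] valid steps per trial.
    Returns one list of class IDs per trial.
    """
    best = tf.argmax(logits, axis=-1).numpy()  # [batch, time]
    decoded = []
    for seq, n in zip(best, np.asarray(lengths)):
        out, prev = [], None
        for token in seq[:n]:
            if token != prev and token != blank_id:
                out.append(int(token))
            prev = token  # a blank between repeats lets the label appear twice
        decoded.append(out)
    return decoded


ids = tf.constant([[0, 3, 3, 0, 3, 5, 5, 0], [2, 2, 0, 2, 1, 1, 1, 1]])
logits = tf.one_hot(ids, depth=6)
print(greedy_ctc_decode(logits, [8, 4]))  # [[3, 3, 5], [2, 2]]
```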
Hyperparameter Tuning
# Grid search over learning rates
learning_rates = [0.001, 0.005, 0.01]
for lr in learning_rates:
    config.lr_max = lr
    trainer = BrainToTextDecoderTrainerTF(config)
    stats = trainer.train()
Research and Development
This TensorFlow implementation aims to match the published research while providing:
- Reproducible Results: Same model architecture and training procedures
- Hardware Optimization: TPU-specific performance improvements
- Scalability: Easy scaling to larger models and datasets
- Extensibility: Clean APIs for research modifications
Key Research Features
- Adversarial Training: Domain adaptation through gradient reversal
- Multi-day Learning: Session-specific input transformations
- Temporal Modeling: Patch-based sequence processing
- Robust Training: Comprehensive data augmentation pipeline
Citation
If you use this TensorFlow implementation in your research, please cite the original paper:
@article{card2024accurate,
  title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
  author={Card, Nicholas S and others},
  journal={New England Journal of Medicine},
  year={2024}
}
Support
For questions specific to the TensorFlow implementation:
- Check this README and the PyTorch documentation in ../CLAUDE.md
- Review configuration options in rnn_args.yaml
- Examine example scripts in this directory
- Open issues on the project repository
For TPU-specific questions, consult Google Cloud TPU documentation and TensorFlow TPU guides.
