# TensorFlow Brain-to-Text Model for TPU v5e-8

This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides the same functionality as the PyTorch version, with TensorFlow operations designed for maximum TPU performance.

## Architecture Overview

The TensorFlow implementation maintains the same three-model adversarial architecture:

### Core Models

- **NoiseModel**: 2-layer GRU that estimates noise in the neural data
- **CleanSpeechModel**: 3-layer GRU that processes the denoised signal for speech recognition
- **NoisySpeechModel**: 2-layer GRU that processes the noise signal for adversarial training

### Key Features

- **Day-specific transformations**: Learnable input layers for each recording session
- **Patch processing**: Temporal patching for improved sequence modeling
- **Gradient Reversal Layer**: For adversarial training between the noise and speech models
- **Mixed precision**: bfloat16 optimization for TPU v5e-8 memory efficiency
- **CTC Loss**: Connectionist Temporal Classification for sequence alignment

## Files Overview

### Core Implementation

- `rnn_model_tf.py`: TensorFlow model architecture with TPU optimizations
- `trainer_tf.py`: Training pipeline with distributed TPU strategy
- `dataset_tf.py`: Data loading and augmentation optimized for TPU
- `train_model_tf.py`: Main training script
- `evaluate_model_tf.py`: Evaluation and inference script

### Configuration and Setup

- `rnn_args.yaml`: Training configuration (shared with the PyTorch version)
- `setup_tensorflow_tpu.sh`: Environment setup script
- `requirements_tf.txt`: Python dependencies
- `README_TensorFlow.md`: This documentation

## Quick Start

### 1. Environment Setup

```bash
# Run the setup script to configure the TPU environment
./setup_tensorflow_tpu.sh

# Activate the conda environment
conda activate b2txt_tf
```

### 2. Verify TPU Access

```python
import tensorflow as tf

# Check TPU availability
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print(f"TPU cores available: {strategy.num_replicas_in_sync}")
```

### 3. Start Training

```bash
# Basic training with the default config
python train_model_tf.py --config_path rnn_args.yaml

# Training with custom settings
python train_model_tf.py \
    --config_path rnn_args.yaml \
    --batch_size 64 \
    --num_batches 50000 \
    --output_dir ./trained_models/custom_run
```
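Training can also be launched from Python (e.g., in a notebook). Below is a minimal sketch, assuming `rnn_args.yaml` is loaded into an attribute-style config object (OmegaConf is used here as an illustration; adapt to however the project actually parses its config) and passed to `BrainToTextDecoderTrainerTF` as in the Advanced Usage section below:

```python
# Sketch: programmatic equivalent of the CLI call above.
# Assumes the omegaconf package is available; key names mirror rnn_args.yaml
# where known, and are otherwise hypothetical (see comments).
from omegaconf import OmegaConf
from trainer_tf import BrainToTextDecoderTrainerTF

config = OmegaConf.load('rnn_args.yaml')
config.dataset.batch_size = 64                      # same override as --batch_size above
config.output_dir = './trained_models/custom_run'   # assumed to mirror --output_dir

trainer = BrainToTextDecoderTrainerTF(config)
train_stats = trainer.train()
```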
### 4. Run Evaluation

```bash
# Evaluate trained model
python evaluate_model_tf.py \
    --model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
    --config_path rnn_args.yaml \
    --eval_type test
```

## TPU v5e-8 Optimizations

### Hardware-Specific Features

- **Mixed Precision**: Automatic bfloat16 conversion for 2x memory efficiency
- **XLA Compilation**: Just-in-time compilation for optimal TPU performance
- **Distributed Strategy**: Automatic sharding across 8 TPU cores
- **Memory Management**: Efficient tensor operations to avoid OOM errors

### Performance Optimizations

- **Batch Matrix Operations**: `tf.linalg.matmul` instead of element-wise operations
- **Static Shapes**: Avoiding dynamic tensor shapes for better compilation
- **Efficient Gathering**: `tf.gather` for day-specific parameter selection
- **Gradient Reversal**: Custom gradient function for adversarial training

## Configuration

The model uses the same `rnn_args.yaml` configuration as the PyTorch version. Key TPU-specific settings:

```yaml
# TPU-specific settings
use_amp: true                  # Enable mixed precision (bfloat16)

dataset:
  batch_size: 32               # Optimized for TPU memory
  num_dataloader_workers: 0    # Disable multiprocessing on TPU

# Model architecture (same as PyTorch)
model:
  n_input_features: 512        # Neural features per timestep
  n_units: 768                 # Hidden units per GRU layer
  patch_size: 14               # Temporal patch size
  patch_stride: 4              # Patch stride
```

## Performance Comparison

### TPU v5e-8 vs Other Hardware

- **Memory**: 2x improvement with bfloat16 mixed precision
- **Throughput**: ~3-4x faster training than a V100 GPU
- **Scalability**: Automatic distribution across 8 cores
- **Cost Efficiency**: Better performance-per-dollar for large models

### Expected Training Times (120k batches)

- **TPU v5e-8**: ~4-6 hours
- **Single V100**: ~15-20 hours
- **RTX 4090**: ~12-18 hours

## Model Architecture Details

### TripleGRUDecoder Forward Pass

```python
# Training mode (adversarial)
clean_logits, noisy_logits, noise_output = model(
    features, day_indices, mode='full', grl_lambda=0.5, training=True
)

# Inference mode (production)
clean_logits = model(
    features, day_indices, mode='inference', training=False
)
```

### Loss Functions

```python
# Clean speech CTC loss
clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)

# Adversarial noisy speech loss (with gradient reversal)
noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)

# Combined loss
total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
```

## Data Pipeline

### HDF5 Data Loading

The TensorFlow implementation efficiently loads data from HDF5 files:

- **Batch creation**: Pre-batched data with padding
- **Feature subsets**: Configurable neural feature selection
- **Day balancing**: Ensures even representation across recording sessions
- **Memory efficiency**: Lazy loading with `tf.data.Dataset`

### Data Augmentations

- **Gaussian smoothing**: Temporal smoothing of neural signals
- **White noise**: Additive Gaussian noise for robustness
- **Static gain**: Channel-wise multiplicative noise
- **Random walk**: Temporal drift simulation
- **Random cutoff**: Variable sequence lengths

## Troubleshooting

### Common TPU Issues

#### "Resource exhausted" errors

```bash
# Reduce batch size
python train_model_tf.py --batch_size 16

# Enable gradient accumulation
# Modify config: gradient_accumulation_steps: 4
```

#### TPU not detected

```bash
# Check environment variables
echo $TPU_NAME
echo $COLAB_TPU_ADDR

# Verify TPU access
gcloud compute tpus list
```
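If the shell checks above pass but TensorFlow still reports no accelerator, a short in-Python probe can help narrow down where detection fails. This is a sketch using only standard TensorFlow TPU APIs, mirroring the Quick Start verification step:

```python
import tensorflow as tf

try:
    # Auto-detects the TPU on Cloud TPU VMs; raises if none is reachable
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("TPU devices:", tf.config.list_logical_devices("TPU"))
except (ValueError, tf.errors.NotFoundError) as err:
    print("No TPU detected:", err)
```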
#### Mixed precision issues

```bash
# Disable mixed precision if needed
python train_model_tf.py --disable_mixed_precision
```

### Performance Debugging

```python
import os
import tensorflow as tf

# Enable XLA logging
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'

# Profile TPU usage
tf.profiler.experimental.start('logdir')
# ... training code ...
tf.profiler.experimental.stop()
```

## Advanced Usage

### Custom Training Loop

```python
from trainer_tf import BrainToTextDecoderTrainerTF

# Initialize trainer
trainer = BrainToTextDecoderTrainerTF(config)

# Custom training with checkpointing
for epoch in range(num_epochs):
    stats = trainer.train()
    if epoch % 5 == 0:
        trainer._save_checkpoint(f'epoch_{epoch}', epoch)
```

### Model Inference

```python
# Load trained model
model = trainer.model
model.load_weights('path/to/checkpoint.weights.h5')

# Run inference
logits = trainer.inference(features, day_indices, n_time_steps)

# Decode predictions
predictions = tf.argmax(logits, axis=-1)
```

### Hyperparameter Tuning

```python
# Grid search over learning rates
learning_rates = [0.001, 0.005, 0.01]
for lr in learning_rates:
    config.lr_max = lr
    trainer = BrainToTextDecoderTrainerTF(config)
    stats = trainer.train()
```

## Research and Development

This TensorFlow implementation maintains full compatibility with the published research while providing:

1. **Reproducible Results**: Same model architecture and training procedures
2. **Hardware Optimization**: TPU-specific performance improvements
3. **Scalability**: Easy scaling to larger models and datasets
4. **Extensibility**: Clean APIs for research modifications

### Key Research Features

- **Adversarial Training**: Domain adaptation through gradient reversal
- **Multi-day Learning**: Session-specific input transformations
- **Temporal Modeling**: Patch-based sequence processing
- **Robust Training**: Comprehensive data augmentation pipeline

## Citation

If you use this TensorFlow implementation in your research, please cite the original paper:

```bibtex
@article{card2024accurate,
  title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
  author={Card, Nicholas S and others},
  journal={New England Journal of Medicine},
  year={2024}
}
```

## Support

For questions specific to the TensorFlow implementation:

1. Check this README and the PyTorch documentation in `../CLAUDE.md`
2. Review the configuration options in `rnn_args.yaml`
3. Examine the example scripts in this directory
4. Open an issue on the project repository

For TPU-specific questions, consult the Google Cloud TPU documentation and TensorFlow TPU guides.