288 lines
8.7 KiB
Markdown
288 lines
8.7 KiB
Markdown
![]() |
# TensorFlow Brain-to-Text Model for TPU v5e-8
|
||
|
|
||
|
This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides equivalent functionality to the PyTorch version but with TensorFlow operations designed for maximum TPU performance.
|
||
|
|
||
|
## Architecture Overview
|
||
|
|
||
|
The TensorFlow implementation maintains the same sophisticated three-model adversarial architecture:
|
||
|
|
||
|
### Core Models
|
||
|
- **NoiseModel**: 2-layer GRU that estimates noise in neural data
|
||
|
- **CleanSpeechModel**: 3-layer GRU that processes denoised signal for speech recognition
|
||
|
- **NoisySpeechModel**: 2-layer GRU that processes noise signal for adversarial training
|
||
|
|
||
|
### Key Features
|
||
|
- **Day-specific transformations**: Learnable input layers for each recording session
|
||
|
- **Patch processing**: Temporal patching for improved sequence modeling
|
||
|
- **Gradient Reversal Layer**: For adversarial training between noise and speech models
|
||
|
- **Mixed precision**: bfloat16 optimization for TPU v5e-8 memory efficiency
|
||
|
- **CTC Loss**: Connectionist Temporal Classification for sequence alignment
|
||
|
|
||
|
## Files Overview
|
||
|
|
||
|
### Core Implementation
|
||
|
- `rnn_model_tf.py`: TensorFlow model architecture with TPU optimizations
|
||
|
- `trainer_tf.py`: Training pipeline with distributed TPU strategy
|
||
|
- `dataset_tf.py`: Data loading and augmentation optimized for TPU
|
||
|
- `train_model_tf.py`: Main training script
|
||
|
- `evaluate_model_tf.py`: Evaluation and inference script
|
||
|
|
||
|
### Configuration and Setup
|
||
|
- `rnn_args.yaml`: Training configuration (shared with PyTorch version)
|
||
|
- `setup_tensorflow_tpu.sh`: Environment setup script
|
||
|
- `requirements_tf.txt`: Python dependencies
|
||
|
- `README_TensorFlow.md`: This documentation
|
||
|
|
||
|
## Quick Start
|
||
|
|
||
|
### 1. Environment Setup
|
||
|
```bash
|
||
|
# Run the setup script to configure TPU environment
|
||
|
./setup_tensorflow_tpu.sh
|
||
|
|
||
|
# Activate the conda environment
|
||
|
conda activate b2txt_tf
|
||
|
```
|
||
|
|
||
|
### 2. Verify TPU Access
|
||
|
```python
|
||
|
import tensorflow as tf
|
||
|
|
||
|
# Check TPU availability
|
||
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||
|
tf.config.experimental_connect_to_cluster(resolver)
|
||
|
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||
|
strategy = tf.distribute.TPUStrategy(resolver)
|
||
|
print(f"TPU cores available: {strategy.num_replicas_in_sync}")
|
||
|
```
|
||
|
|
||
|
### 3. Start Training
|
||
|
```bash
|
||
|
# Basic training with default config
|
||
|
python train_model_tf.py --config_path rnn_args.yaml
|
||
|
|
||
|
# Training with custom settings
|
||
|
python train_model_tf.py \
|
||
|
--config_path rnn_args.yaml \
|
||
|
--batch_size 64 \
|
||
|
--num_batches 50000 \
|
||
|
--output_dir ./trained_models/custom_run
|
||
|
```
|
||
|
|
||
|
### 4. Run Evaluation
|
||
|
```bash
|
||
|
# Evaluate trained model
|
||
|
python evaluate_model_tf.py \
|
||
|
--model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
|
||
|
--config_path rnn_args.yaml \
|
||
|
--eval_type test
|
||
|
```
|
||
|
|
||
|
## TPU v5e-8 Optimizations
|
||
|
|
||
|
### Hardware-Specific Features
|
||
|
- **Mixed Precision**: Automatic bfloat16 conversion for 2x memory efficiency
|
||
|
- **XLA Compilation**: Just-in-time compilation for optimal TPU performance
|
||
|
- **Distributed Strategy**: Automatic sharding across 8 TPU cores
|
||
|
- **Memory Management**: Efficient tensor operations to avoid OOM errors
|
||
|
|
||
|
### Performance Optimizations
|
||
|
- **Batch Matrix Operations**: `tf.linalg.matmul` instead of element-wise operations
|
||
|
- **Static Shapes**: Avoiding dynamic tensor shapes for better compilation
|
||
|
- **Efficient Gathering**: `tf.gather` for day-specific parameter selection
|
||
|
- **Gradient Reversal**: Custom gradient function for adversarial training
|
||
|
|
||
|
## Configuration
|
||
|
|
||
|
The model uses the same `rnn_args.yaml` configuration as the PyTorch version. Key TPU-specific settings:
|
||
|
|
||
|
```yaml
|
||
|
# TPU-specific settings
|
||
|
use_amp: true # Enable mixed precision (bfloat16)
|
||
|
dataset:
|
||
|
batch_size: 32 # Optimized for TPU memory
|
||
|
num_dataloader_workers: 0 # Disable multiprocessing on TPU
|
||
|
|
||
|
# Model architecture (same as PyTorch)
|
||
|
model:
|
||
|
n_input_features: 512 # Neural features per timestep
|
||
|
n_units: 768 # Hidden units per GRU layer
|
||
|
patch_size: 14 # Temporal patch size
|
||
|
patch_stride: 4 # Patch stride
|
||
|
```
|
||
|
|
||
|
## Performance Comparison
|
||
|
|
||
|
### TPU v5e-8 vs Other Hardware
|
||
|
- **Memory**: 2x improvement with bfloat16 mixed precision
|
||
|
- **Throughput**: ~3-4x faster training than V100 GPU
|
||
|
- **Scalability**: Automatic distribution across 8 cores
|
||
|
- **Cost Efficiency**: Better performance-per-dollar for large models
|
||
|
|
||
|
### Expected Training Times (120k batches)
|
||
|
- **TPU v5e-8**: ~4-6 hours
|
||
|
- **Single V100**: ~15-20 hours
|
||
|
- **RTX 4090**: ~12-18 hours
|
||
|
|
||
|
## Model Architecture Details
|
||
|
|
||
|
### TripleGRUDecoder Forward Pass
|
||
|
```python
|
||
|
# Training mode (adversarial)
|
||
|
clean_logits, noisy_logits, noise_output = model(
|
||
|
features, day_indices, mode='full',
|
||
|
grl_lambda=0.5, training=True
|
||
|
)
|
||
|
|
||
|
# Inference mode (production)
|
||
|
clean_logits = model(
|
||
|
features, day_indices, mode='inference',
|
||
|
training=False
|
||
|
)
|
||
|
```
|
||
|
|
||
|
### Loss Functions
|
||
|
```python
|
||
|
# Clean speech CTC loss
|
||
|
clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)
|
||
|
|
||
|
# Adversarial noisy speech loss (with gradient reversal)
|
||
|
noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)
|
||
|
|
||
|
# Combined loss
|
||
|
total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
|
||
|
```
|
||
|
|
||
|
## Data Pipeline
|
||
|
|
||
|
### HDF5 Data Loading
|
||
|
The TensorFlow implementation efficiently loads data from HDF5 files:
|
||
|
- **Batch creation**: Pre-batched data with padding
|
||
|
- **Feature subsets**: Configurable neural feature selection
|
||
|
- **Day balancing**: Ensures even representation across recording sessions
|
||
|
- **Memory efficiency**: Lazy loading with tf.data.Dataset
|
||
|
|
||
|
### Data Augmentations
|
||
|
- **Gaussian smoothing**: Temporal smoothing of neural signals
|
||
|
- **White noise**: Additive Gaussian noise for robustness
|
||
|
- **Static gain**: Channel-wise multiplicative noise
|
||
|
- **Random walk**: Temporal drift simulation
|
||
|
- **Random cutoff**: Variable sequence lengths
|
||
|
|
||
|
## Troubleshooting
|
||
|
|
||
|
### Common TPU Issues
|
||
|
|
||
|
#### "Resource exhausted" errors
|
||
|
```bash
|
||
|
# Reduce batch size
|
||
|
python train_model_tf.py --batch_size 16
|
||
|
|
||
|
# Enable gradient accumulation
|
||
|
# Modify config: gradient_accumulation_steps: 4
|
||
|
```
|
||
|
|
||
|
#### TPU not detected
|
||
|
```bash
|
||
|
# Check environment variables
|
||
|
echo $TPU_NAME
|
||
|
echo $COLAB_TPU_ADDR
|
||
|
|
||
|
# Verify TPU access
|
||
|
gcloud compute tpus list
|
||
|
```
|
||
|
|
||
|
#### Mixed precision issues
|
||
|
```bash
|
||
|
# Disable mixed precision if needed
|
||
|
python train_model_tf.py --disable_mixed_precision
|
||
|
```
|
||
|
|
||
|
### Performance Debugging
|
||
|
```python
|
||
|
# Enable XLA logging
|
||
|
import os
|
||
|
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
|
||
|
|
||
|
# Profile TPU usage
|
||
|
tf.profiler.experimental.start('logdir')
|
||
|
# ... training code ...
|
||
|
tf.profiler.experimental.stop()
|
||
|
```
|
||
|
|
||
|
## Advanced Usage
|
||
|
|
||
|
### Custom Training Loop
|
||
|
```python
|
||
|
from trainer_tf import BrainToTextDecoderTrainerTF
|
||
|
|
||
|
# Initialize trainer
|
||
|
trainer = BrainToTextDecoderTrainerTF(config)
|
||
|
|
||
|
# Custom training with checkpointing
|
||
|
for epoch in range(num_epochs):
|
||
|
stats = trainer.train()
|
||
|
if epoch % 5 == 0:
|
||
|
trainer._save_checkpoint(f'epoch_{epoch}', epoch)
|
||
|
```
|
||
|
|
||
|
### Model Inference
|
||
|
```python
|
||
|
# Load trained model
|
||
|
model = trainer.model
|
||
|
model.load_weights('path/to/checkpoint.weights.h5')
|
||
|
|
||
|
# Run inference
|
||
|
logits = trainer.inference(features, day_indices, n_time_steps)
|
||
|
|
||
|
# Decode predictions
|
||
|
predictions = tf.argmax(logits, axis=-1)
|
||
|
```
|
||
|
|
||
|
### Hyperparameter Tuning
|
||
|
```python
|
||
|
# Grid search over learning rates
|
||
|
learning_rates = [0.001, 0.005, 0.01]
|
||
|
for lr in learning_rates:
|
||
|
config.lr_max = lr
|
||
|
trainer = BrainToTextDecoderTrainerTF(config)
|
||
|
stats = trainer.train()
|
||
|
```
|
||
|
|
||
|
## Research and Development
|
||
|
|
||
|
This TensorFlow implementation maintains full compatibility with the published research while providing:
|
||
|
|
||
|
1. **Reproducible Results**: Same model architecture and training procedures
|
||
|
2. **Hardware Optimization**: TPU-specific performance improvements
|
||
|
3. **Scalability**: Easy scaling to larger models and datasets
|
||
|
4. **Extensibility**: Clean APIs for research modifications
|
||
|
|
||
|
### Key Research Features
|
||
|
- **Adversarial Training**: Domain adaptation through gradient reversal
|
||
|
- **Multi-day Learning**: Session-specific input transformations
|
||
|
- **Temporal Modeling**: Patch-based sequence processing
|
||
|
- **Robust Training**: Comprehensive data augmentation pipeline
|
||
|
|
||
|
## Citation
|
||
|
|
||
|
If you use this TensorFlow implementation in your research, please cite the original paper:
|
||
|
|
||
|
```bibtex
|
||
|
@article{card2024accurate,
|
||
|
title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
|
||
|
author={Card, Nicholas S and others},
|
||
|
journal={New England Journal of Medicine},
|
||
|
year={2024}
|
||
|
}
|
||
|
```
|
||
|
|
||
|
## Support
|
||
|
|
||
|
For questions specific to the TensorFlow implementation:
|
||
|
1. Check this README and the PyTorch documentation in `../CLAUDE.md`
|
||
|
2. Review configuration options in `rnn_args.yaml`
|
||
|
3. Examine example scripts in this directory
|
||
|
4. Open issues on the project repository
|
||
|
|
||
|
For TPU-specific questions, consult Google Cloud TPU documentation and TensorFlow TPU guides.
|