288 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
			
		
		
	
	
			288 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Markdown
		
	
	
	
	
	
| # TensorFlow Brain-to-Text Model for TPU v5e-8
 | |
| 
 | |
| This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides equivalent functionality to the PyTorch version but with TensorFlow operations designed for maximum TPU performance.
 | |
| 
 | |
| ## Architecture Overview
 | |
| 
 | |
| The TensorFlow implementation maintains the same sophisticated three-model adversarial architecture:
 | |
| 
 | |
| ### Core Models
 | |
| - **NoiseModel**: 2-layer GRU that estimates noise in neural data
 | |
| - **CleanSpeechModel**: 3-layer GRU that processes denoised signal for speech recognition
 | |
| - **NoisySpeechModel**: 2-layer GRU that processes noise signal for adversarial training
 | |
| 
 | |
| ### Key Features
 | |
| - **Day-specific transformations**: Learnable input layers for each recording session
 | |
| - **Patch processing**: Temporal patching for improved sequence modeling
 | |
| - **Gradient Reversal Layer**: For adversarial training between noise and speech models
 | |
| - **Mixed precision**: bfloat16 optimization for TPU v5e-8 memory efficiency
 | |
| - **CTC Loss**: Connectionist Temporal Classification for sequence alignment
 | |
| 
 | |
| ## Files Overview
 | |
| 
 | |
| ### Core Implementation
 | |
| - `rnn_model_tf.py`: TensorFlow model architecture with TPU optimizations
 | |
| - `trainer_tf.py`: Training pipeline with distributed TPU strategy
 | |
| - `dataset_tf.py`: Data loading and augmentation optimized for TPU
 | |
| - `train_model_tf.py`: Main training script
 | |
| - `evaluate_model_tf.py`: Evaluation and inference script
 | |
| 
 | |
| ### Configuration and Setup
 | |
| - `rnn_args.yaml`: Training configuration (shared with PyTorch version)
 | |
| - `setup_tensorflow_tpu.sh`: Environment setup script
 | |
| - `requirements_tf.txt`: Python dependencies
 | |
| - `README_TensorFlow.md`: This documentation
 | |
| 
 | |
| ## Quick Start
 | |
| 
 | |
| ### 1. Environment Setup
 | |
| ```bash
 | |
| # Run the setup script to configure TPU environment
 | |
| ./setup_tensorflow_tpu.sh
 | |
| 
 | |
| # Activate the conda environment
 | |
| conda activate b2txt_tf
 | |
| ```
 | |
| 
 | |
| ### 2. Verify TPU Access
 | |
| ```python
 | |
| import tensorflow as tf
 | |
| 
 | |
| # Check TPU availability
 | |
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
 | |
| tf.config.experimental_connect_to_cluster(resolver)
 | |
| tf.tpu.experimental.initialize_tpu_system(resolver)
 | |
| strategy = tf.distribute.TPUStrategy(resolver)
 | |
| print(f"TPU cores available: {strategy.num_replicas_in_sync}")
 | |
| ```
 | |
| 
 | |
| ### 3. Start Training
 | |
| ```bash
 | |
| # Basic training with default config
 | |
| python train_model_tf.py --config_path rnn_args.yaml
 | |
| 
 | |
| # Training with custom settings
 | |
| python train_model_tf.py \
 | |
|     --config_path rnn_args.yaml \
 | |
|     --batch_size 64 \
 | |
|     --num_batches 50000 \
 | |
|     --output_dir ./trained_models/custom_run
 | |
| ```
 | |
| 
 | |
| ### 4. Run Evaluation
 | |
| ```bash
 | |
| # Evaluate trained model
 | |
| python evaluate_model_tf.py \
 | |
|     --model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
 | |
|     --config_path rnn_args.yaml \
 | |
|     --eval_type test
 | |
| ```
 | |
| 
 | |
| ## TPU v5e-8 Optimizations
 | |
| 
 | |
| ### Hardware-Specific Features
 | |
| - **Mixed Precision**: Automatic bfloat16 conversion for 2x memory efficiency
 | |
| - **XLA Compilation**: Just-in-time compilation for optimal TPU performance
 | |
| - **Distributed Strategy**: Automatic sharding across 8 TPU cores
 | |
| - **Memory Management**: Efficient tensor operations to avoid OOM errors
 | |
| 
 | |
| ### Performance Optimizations
 | |
| - **Batch Matrix Operations**: `tf.linalg.matmul` instead of element-wise operations
 | |
| - **Static Shapes**: Avoiding dynamic tensor shapes for better compilation
 | |
| - **Efficient Gathering**: `tf.gather` for day-specific parameter selection
 | |
| - **Gradient Reversal**: Custom gradient function for adversarial training
 | |
| 
 | |
| ## Configuration
 | |
| 
 | |
| The model uses the same `rnn_args.yaml` configuration as the PyTorch version. Key TPU-specific settings:
 | |
| 
 | |
| ```yaml
 | |
| # TPU-specific settings
 | |
| use_amp: true                    # Enable mixed precision (bfloat16)
 | |
| dataset:
 | |
|   batch_size: 32                # Optimized for TPU memory
 | |
|   num_dataloader_workers: 0     # Disable multiprocessing on TPU
 | |
| 
 | |
| # Model architecture (same as PyTorch)
 | |
| model:
 | |
|   n_input_features: 512         # Neural features per timestep
 | |
|   n_units: 768                  # Hidden units per GRU layer
 | |
|   patch_size: 14                # Temporal patch size
 | |
|   patch_stride: 4               # Patch stride
 | |
| ```
 | |
| 
 | |
| ## Performance Comparison
 | |
| 
 | |
| ### TPU v5e-8 vs Other Hardware
 | |
| - **Memory**: 2x improvement with bfloat16 mixed precision
 | |
| - **Throughput**: ~3-4x faster training than V100 GPU
 | |
| - **Scalability**: Automatic distribution across 8 cores
 | |
| - **Cost Efficiency**: Better performance-per-dollar for large models
 | |
| 
 | |
| ### Expected Training Times (120k batches)
 | |
| - **TPU v5e-8**: ~4-6 hours
 | |
| - **Single V100**: ~15-20 hours
 | |
| - **RTX 4090**: ~12-18 hours
 | |
| 
 | |
| ## Model Architecture Details
 | |
| 
 | |
| ### TripleGRUDecoder Forward Pass
 | |
| ```python
 | |
| # Training mode (adversarial)
 | |
| clean_logits, noisy_logits, noise_output = model(
 | |
|     features, day_indices, mode='full',
 | |
|     grl_lambda=0.5, training=True
 | |
| )
 | |
| 
 | |
| # Inference mode (production)
 | |
| clean_logits = model(
 | |
|     features, day_indices, mode='inference',
 | |
|     training=False
 | |
| )
 | |
| ```
 | |
| 
 | |
| ### Loss Functions
 | |
| ```python
 | |
| # Clean speech CTC loss
 | |
| clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)
 | |
| 
 | |
| # Adversarial noisy speech loss (with gradient reversal)
 | |
| noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)
 | |
| 
 | |
| # Combined loss
 | |
| total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
 | |
| ```
 | |
| 
 | |
| ## Data Pipeline
 | |
| 
 | |
| ### HDF5 Data Loading
 | |
| The TensorFlow implementation efficiently loads data from HDF5 files:
 | |
| - **Batch creation**: Pre-batched data with padding
 | |
| - **Feature subsets**: Configurable neural feature selection
 | |
| - **Day balancing**: Ensures even representation across recording sessions
 | |
| - **Memory efficiency**: Lazy loading with tf.data.Dataset
 | |
| 
 | |
| ### Data Augmentations
 | |
| - **Gaussian smoothing**: Temporal smoothing of neural signals
 | |
| - **White noise**: Additive Gaussian noise for robustness
 | |
| - **Static gain**: Channel-wise multiplicative noise
 | |
| - **Random walk**: Temporal drift simulation
 | |
| - **Random cutoff**: Variable sequence lengths
 | |
| 
 | |
| ## Troubleshooting
 | |
| 
 | |
| ### Common TPU Issues
 | |
| 
 | |
| #### "Resource exhausted" errors
 | |
| ```bash
 | |
| # Reduce batch size
 | |
| python train_model_tf.py --batch_size 16
 | |
| 
 | |
| # Enable gradient accumulation
 | |
| # Modify config: gradient_accumulation_steps: 4
 | |
| ```
 | |
| 
 | |
| #### TPU not detected
 | |
| ```bash
 | |
| # Check environment variables
 | |
| echo $TPU_NAME
 | |
| echo $COLAB_TPU_ADDR
 | |
| 
 | |
| # Verify TPU access
 | |
| gcloud compute tpus list
 | |
| ```
 | |
| 
 | |
| #### Mixed precision issues
 | |
| ```bash
 | |
| # Disable mixed precision if needed
 | |
| python train_model_tf.py --disable_mixed_precision
 | |
| ```
 | |
| 
 | |
| ### Performance Debugging
 | |
| ```python
 | |
| # Enable XLA logging
 | |
| import os
 | |
| os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
 | |
| 
 | |
| # Profile TPU usage
 | |
| tf.profiler.experimental.start('logdir')
 | |
| # ... training code ...
 | |
| tf.profiler.experimental.stop()
 | |
| ```
 | |
| 
 | |
| ## Advanced Usage
 | |
| 
 | |
| ### Custom Training Loop
 | |
| ```python
 | |
| from trainer_tf import BrainToTextDecoderTrainerTF
 | |
| 
 | |
| # Initialize trainer
 | |
| trainer = BrainToTextDecoderTrainerTF(config)
 | |
| 
 | |
| # Custom training with checkpointing
 | |
| for epoch in range(num_epochs):
 | |
|     stats = trainer.train()
 | |
|     if epoch % 5 == 0:
 | |
|         trainer._save_checkpoint(f'epoch_{epoch}', epoch)
 | |
| ```
 | |
| 
 | |
| ### Model Inference
 | |
| ```python
 | |
| # Load trained model
 | |
| model = trainer.model
 | |
| model.load_weights('path/to/checkpoint.weights.h5')
 | |
| 
 | |
| # Run inference
 | |
| logits = trainer.inference(features, day_indices, n_time_steps)
 | |
| 
 | |
| # Decode predictions
 | |
| predictions = tf.argmax(logits, axis=-1)
 | |
| ```
 | |
| 
 | |
| ### Hyperparameter Tuning
 | |
| ```python
 | |
| # Grid search over learning rates
 | |
| learning_rates = [0.001, 0.005, 0.01]
 | |
| for lr in learning_rates:
 | |
|     config.lr_max = lr
 | |
|     trainer = BrainToTextDecoderTrainerTF(config)
 | |
|     stats = trainer.train()
 | |
| ```
 | |
| 
 | |
| ## Research and Development
 | |
| 
 | |
| This TensorFlow implementation maintains full compatibility with the published research while providing:
 | |
| 
 | |
| 1. **Reproducible Results**: Same model architecture and training procedures
 | |
| 2. **Hardware Optimization**: TPU-specific performance improvements
 | |
| 3. **Scalability**: Easy scaling to larger models and datasets
 | |
| 4. **Extensibility**: Clean APIs for research modifications
 | |
| 
 | |
| ### Key Research Features
 | |
| - **Adversarial Training**: Domain adaptation through gradient reversal
 | |
| - **Multi-day Learning**: Session-specific input transformations
 | |
| - **Temporal Modeling**: Patch-based sequence processing
 | |
| - **Robust Training**: Comprehensive data augmentation pipeline
 | |
| 
 | |
| ## Citation
 | |
| 
 | |
| If you use this TensorFlow implementation in your research, please cite the original paper:
 | |
| 
 | |
| ```bibtex
 | |
| @article{card2024accurate,
 | |
|   title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
 | |
|   author={Card, Nicholas S and others},
 | |
|   journal={New England Journal of Medicine},
 | |
|   year={2024}
 | |
| }
 | |
| ```
 | |
| 
 | |
| ## Support
 | |
| 
 | |
| For questions specific to the TensorFlow implementation:
 | |
| 1. Check this README and the PyTorch documentation in `../CLAUDE.md`
 | |
| 2. Review configuration options in `rnn_args.yaml`
 | |
| 3. Examine example scripts in this directory
 | |
| 4. Open issues on the project repository
 | |
| 
 | |
| For TPU-specific questions, consult Google Cloud TPU documentation and TensorFlow TPU guides. | 
