# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This repository contains the code and data for "An Accurate and Rapidly Calibrating Speech Neuroprosthesis," published in the New England Journal of Medicine (2024). It implements a brain-to-text system that converts neural signals from the speech motor cortex into text using an RNN model and n-gram language models.
## Development Environment Setup

### Main Environment (b2txt25)

```bash
./setup.sh
conda activate b2txt25
```
### Language Model Environment (b2txt25_lm)

```bash
./setup_lm.sh
conda activate b2txt25_lm
```
**Important**: The project requires two separate conda environments because of conflicting PyTorch versions (a quick way to confirm which build is active is sketched after this list):

- `b2txt25`: PyTorch with CUDA 12.6 for model training and evaluation
- `b2txt25_lm`: PyTorch 1.13.1 for the Kaldi-based n-gram language models
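Before launching a long job, a minimal version check like the following confirms you are in the right environment (the exact versions printed depend on your installation):

```python
# Quick sanity check: confirm the active environment's PyTorch build.
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Expect a CUDA 12.6 build in b2txt25 and 1.13.1 in b2txt25_lm.
```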
### Redis Setup

Redis is required for inter-process communication. Install on Ubuntu:

```bash
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update && sudo apt-get install redis
sudo systemctl disable redis-server
```
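To verify that Redis is reachable from Python before starting the pipeline, a minimal connectivity check (assuming the default localhost:6379 configuration) looks like this:

```python
# Minimal Redis connectivity check (assumes the default host/port).
import redis

r = redis.Redis(host="localhost", port=6379)
try:
    r.ping()
    print("Redis is up")
except redis.exceptions.ConnectionError:
    print("Redis is not running; start it with `redis-server`")
```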
## Architecture Overview

### High-Level System Flow

1. **Neural Data Input**: 512 features (2 per electrode × 256 electrodes) binned at 20 ms resolution
2. **RNN Model**: Maps neural features to phoneme logits (trained with CTC loss)
3. **Language Model**: Decodes phoneme logits into words using n-gram models plus OPT rescoring
4. **Redis Communication**: Coordinates the RNN inference and language model processes (see the sketch below)
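As an illustration of that coordination pattern only — the stream names and field layout below are hypothetical, not the schema the repo's scripts actually use — the two processes might exchange data through Redis like this:

```python
# Hypothetical sketch of RNN-process -> LM-process coordination over Redis.
# Stream/key names here are illustrative, not the repo's actual schema.
import numpy as np
import redis

r = redis.Redis(host="localhost", port=6379)

# RNN side: publish a frame of phoneme logits.
logits = np.random.randn(41).astype(np.float32)  # e.g., one logit per phoneme class
r.xadd("lm_input", {"logits": logits.tobytes()})

# LM side: block until a new frame arrives, then decode it.
entries = r.xread({"lm_input": "0"}, count=1, block=1000)
for stream_name, messages in entries:
    for message_id, fields in messages:
        frame = np.frombuffer(fields[b"logits"], dtype=np.float32)
        # ... feed `frame` into the n-gram decoder ...
```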
### Key Components

#### Model Training (`model_training/`)

- **Core Script**: `train_model.py` (loads config from `rnn_args.yaml`)
- **Model Architecture**: `rnn_model.py` - 5-layer GRU with 768 hidden units
- **Trainer**: `rnn_trainer.py` - Custom PyTorch trainer with CTC loss
- **Evaluation**: `evaluate_model.py` - Inference pipeline with Redis communication

#### Language Model (`language_model/`)

- **Standalone Server**: `language-model-standalone.py` - Redis-based LM server
- **Kaldi Integration**: Uses custom C++ bindings for efficient n-gram decoding
- **OPT Rescoring**: Facebook OPT 6.7B for language model rescoring
- **Build System**: Complex CMake-based build for Kaldi/SRILM integration

#### Utilities (`nejm_b2txt_utils/`)

- **General Utils**: `general_utils.py` - Shared utility functions
- **Package**: Installed via `setup.py` as `nejm_b2txt_utils`

#### Analysis (`analyses/`)

- **Jupyter Notebooks**: `figure_2.ipynb` and `figure_4.ipynb` for the paper's figures
## Common Development Tasks

### Training a Model

```bash
conda activate b2txt25
cd model_training
python train_model.py
```
### Running Evaluation Pipeline

1. Start the Redis server:
```bash
redis-server
```

2. Start the language model (separate terminal):
```bash
conda activate b2txt25_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```

3. Run the evaluation (separate terminal):
```bash
conda activate b2txt25
cd model_training
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/hdf5_data_final --eval_type test --gpu_number 1
```

4. Shut down Redis when finished:
```bash
redis-cli shutdown
```
### Building Language Model from Scratch

```bash
# Build SRILM (in language_model/srilm-1.7.3/)
export SRILM=$PWD
make MAKE_PIC=yes World

# Build Kaldi components (in language_model/runtime/server/x86/)
mkdir build && cd build
cmake .. && make -j8
```
## Data Structure

### Neural Data Format

- **File Type**: HDF5 files in `data/hdf5_data_final/`
- **Features**: 512 neural features per 20 ms bin, 64 channels per band (0-indexed ranges; see the slicing sketch below):
  - 0-63: ventral 6v threshold crossings
  - 64-127: area 4 threshold crossings
  - 128-191: 55b threshold crossings
  - 192-255: dorsal 6v threshold crossings
  - 256-319: ventral 6v spike band power
  - 320-383: area 4 spike band power
  - 384-447: 55b spike band power
  - 448-511: dorsal 6v spike band power
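A small helper like the following (a sketch based on the ranges above, not code from the repo) makes the band layout explicit when exploring the data:

```python
# Illustrative only: slice the 512-feature vector into its 8 named bands
# using the 0-indexed ranges listed above.
import numpy as np

BANDS = {
    "ventral_6v_tx": slice(0, 64),
    "area_4_tx": slice(64, 128),
    "55b_tx": slice(128, 192),
    "dorsal_6v_tx": slice(192, 256),
    "ventral_6v_sbp": slice(256, 320),
    "area_4_sbp": slice(320, 384),
    "55b_sbp": slice(384, 448),
    "dorsal_6v_sbp": slice(448, 512),
}

features = np.zeros((1000, 512))  # (time bins, features)
area4_tx = features[:, BANDS["area_4_tx"]]  # threshold crossings for area 4
```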
### Data Loading

Use `load_h5py_file()` in `model_training/evaluate_model_helpers.py` as a reference for loading the HDF5 data.
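For quick inspection outside that helper, a generic h5py walk like this works (the filename below is a placeholder, and the internal group/dataset names may differ — treat `load_h5py_file()` as the ground truth for the schema):

```python
# Generic HDF5 inspection sketch; the file name is a placeholder and the
# layout is discovered rather than assumed -- see load_h5py_file() for the real schema.
import h5py

with h5py.File("data/hdf5_data_final/example_session.hdf5", "r") as f:
    # Print every group and dataset path to discover the layout.
    f.visit(print)
```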
## Important Notes

- **GPU Requirements**: OPT 6.7B requires ~12.4 GB of VRAM; RTX 4090s are recommended
- **Memory Requirements**: the 3-gram LM needs ~60 GB of RAM; the 5-gram LM needs ~300 GB
- **Environment Isolation**: Always use the correct conda environment for each component
- **Redis Dependency**: Many scripts require a running Redis server
- **Build Dependencies**: CMake ≥ 3.14 and GCC ≥ 10.1 are required for the language model builds
## XLA Optimizations (TPU-Friendly Model)

The RNN model has been optimized for XLA compilation and TPU training while preserving the original model architecture. These optimizations improve compilation speed and reduce memory usage on TPUs.

### Applied XLA Optimizations

#### 1. Dynamic Shape Operations → Static Operations

**Problem**: The XLA compiler struggles with dynamic tensor shapes and indexing.
**Solution**: Replace dynamic operations with XLA-friendly alternatives.
```python
# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)  # Static stack
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)  # Static gather
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
```
#### 2. Matrix Operations → XLA Primitives

**Problem**: Complex einsum operations are less optimized than native XLA ops.
**Solution**: Use batch matrix multiplication (`torch.bmm`) for better XLA performance.
```python
# Before:
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases

# After (XLA-optimized):
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtype consistency
```
#### 3. Mixed Precision Dtype Consistency (Comprehensive Fix)

**Problem**: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline.
**Error**: `Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]`

**Root Cause Analysis**: The error occurred at dimension 7168 = 512 × 14, indicating patch processing with `patch_size=14`. The dtype mismatch cascaded through multiple layers:

1. Initial bmm operations in the day-specific transformations
2. Adversarial training residual connections between models
3. Patch processing operations (unfold, permute, reshape)
4. Gradient Reversal Layer (GRL) operations
5. Hidden state initialization in the adversarial training helper methods
**Comprehensive Solution**: Implement dtype consistency across the entire adversarial training data flow:

```python
# Fix 1: Basic bmm operations with dtype consistency
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

# Fix 2: Patch processing with explicit dtype preservation
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)
    x = x.permute(0, 3, 1, 2)
    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
    x_unfold = x_unfold.squeeze(2)
    x_unfold = x_unfold.permute(0, 2, 3, 1)
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)

# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output

# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (preserves input dtype, but check explicitly)
noisy_input = noisy_input.to(x_processed.dtype)

# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
    states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
# Ensure hidden states match the input dtype for mixed precision training
states = states.to(x_processed.dtype)

# In _noisy_forward_with_processed_input:
if states is None:
    states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
# Ensure hidden states match the input dtype for mixed precision training
states = states.to(x_processed.dtype)
```
**Key Implementation Details**:

- **GradientReversalFn**: Preserves the input dtype automatically (identity forward, gradient reversal backward)
- **Patch Processing**: Explicit dtype preservation prevents unfold operations from changing precision
- **Residual Connections**: All tensor arithmetic ensures matching dtypes
- **Helper Methods**: Hidden state initialization matches the processed input's dtype
- **Data Flow**: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
#### 4. Hidden State Initialization

**Problem**: Dynamic batch size allocation causes XLA recompilation.
**Solution**: Extract the batch size once and avoid `x.shape[0]` inside tensor creation.
```python
# Before:
if states is None:
    states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()

# After (XLA-friendly):
batch_size = x.size(0)  # Extract once
if states is None:
    states = self.h0.expand(2, batch_size, self.input_size).contiguous()
```
#### 5. Return Value Optimization

**Problem**: Complex dictionary returns cause XLA compilation issues.
**Solution**: Use tuples instead of dictionaries for cleaner XLA graphs.
```python
# Before (XLA-unfriendly):
return {
    'clean_logits': clean_logits,
    'noisy_logits': noisy_logits,
    'noise_output': noise_output
}

# After (XLA-friendly):
return clean_logits, noisy_logits, noise_output  # Simple tuple return
```
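Call sites unpack the tuple accordingly; a one-line sketch (the argument names in this call are assumptions, not necessarily the actual `forward()` signature):

```python
# Hypothetical call site after the tuple-return change.
clean_logits, noisy_logits, noise_output = model(x, day_idx, mode='full')
```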
### Files Modified for XLA Optimization

- **`model_training_nnn/rnn_model.py`**: Comprehensive XLA optimization with dtype consistency
  - **`GradientReversalFn`**: Added adversarial-training gradient reversal layer
  - **`NoiseModel.forward()`**: Dynamic indexing → static gather operations, plus comprehensive dtype consistency and patch-processing dtype preservation
  - **`CleanSpeechModel.forward()`**: Same optimizations, plus bmm for matrix ops, comprehensive dtype consistency, and patch-processing dtype preservation
  - **`NoisySpeechModel.forward()`**: Hidden state optimization (no day layers, simplified)
  - **`TripleGRUDecoder.forward()`**: Complex return values → tuple returns, plus comprehensive adversarial-training dtype fixes, residual-connection dtype consistency, and GRL dtype handling
  - **`TripleGRUDecoder._apply_preprocessing()`**: Static preprocessing operations, dtype consistency, and patch-processing dtype preservation
  - **`TripleGRUDecoder._clean_forward_with_processed_input()`**: Helper method with hidden-state dtype consistency for mixed precision
  - **`TripleGRUDecoder._noisy_forward_with_processed_input()`**: Helper method with hidden-state dtype consistency for mixed precision

**Specific Dtype Consistency Fixes Applied**:

1. **Basic Operations**: All `torch.bmm()` operations use `.to(x.dtype)` conversions
2. **Patch Processing**: Explicit dtype preservation through unfold/permute/reshape operations
3. **Adversarial Training**: Residual connections use `.to(x_processed.dtype)` conversions
4. **Gradient Reversal**: Dtype consistency after GRL operations
5. **Hidden States**: All hidden state initialization uses `.to(x_processed.dtype)` conversions
6. **Data Flow**: End-to-end dtype consistency in the NoiseModel → GRL → NoisySpeechModel pipeline
**Error Resolved**: `f32[32,7168] vs bf16[32,7168]` dtype mismatch in mixed precision TPU training

### Benefits of XLA Optimizations

1. **Faster Compilation**: Static shapes allow XLA to pre-compile optimized kernels
2. **Better Memory Usage**: Reduced dynamic allocation during training
3. **Improved TPU Utilization**: XLA primitives map directly to TPU matrix units
4. **Consistent Performance**: Eliminates recompilation caused by dynamic shapes
### Testing and Validation

Test scripts verify that the optimized model behaves consistently with the original:

- **`test_xla_model.py`**: Comprehensive model validation testing
- **`quick_test_xla.py`**: Fast verification of basic functionality

**Important**: These optimizations preserve the exact model architecture and mathematical operations. Only the implementation has been made XLA-friendly.

### Usage Notes

- All original model interfaces remain unchanged
- Both 'inference' and 'full' modes are supported
- Backward compatibility with existing training scripts is maintained
- TPU training should now show improved compilation times and memory efficiency
### Troubleshooting Dtype Issues in Mixed Precision Training

**Common Error Pattern**:
```
Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]
```

**Diagnosis Steps**:

1. **Identify the operation**: Use the tensor dimensions to work out which operation is failing
   - `7168 = 512 × 14`: patch processing with `patch_size=14`
   - `512`: the basic neural feature dimension
   - Other patterns may indicate different operations

2. **Check the data flow**: Trace the tensor through the adversarial training pipeline
   - Input → NoiseModel → residual connection → CleanSpeechModel
   - Input → NoiseModel → GRL → NoisySpeechModel

3. **Verify dtype consistency**: Ensure all operations maintain the input dtype
   - Use `.to(x.dtype)` for all operand tensors
   - Preserve dtype through complex operations (unfold, permute, reshape)
   - Match hidden state dtype to the input tensor's dtype
**Quick Fix Template**:
```python
# For any operation between tensors a and b:
result = operation(a, b.to(a.dtype))

# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)

# For hidden state initialization:
states = states.to(input_tensor.dtype)
```
## PyTorch XLA API Updates and Warnings

### Deprecated APIs (as of 2024)

**Important**: Several torch_xla APIs have been deprecated and should be updated in new code:

#### 1. Device API Changes
```python
# ❌ Deprecated (shows DeprecationWarning):
device = xm.xla_device()

# ✅ Modern API:
import torch_xla
device = torch_xla.device()
```
#### 2. Synchronization API Changes
```python
# ❌ Deprecated (shows DeprecationWarning):
xm.mark_step()

# ✅ Modern API:
import torch_xla
torch_xla.sync()
```
#### 3. Mixed Precision Environment Variables
```python
# ⚠️ Will be deprecated after PyTorch XLA 2.6:
os.environ['XLA_USE_BF16'] = '1'

# 💡 Recommended: convert the model to bf16 directly in code
model = model.to(torch.bfloat16)
```
### TPU Performance Warnings

#### Transparent Hugepages Warning
```
UserWarning: Transparent hugepages are not enabled. TPU runtime startup and
shutdown time should be significantly improved on TPU v5e and newer.
```

**Solution** (for TPU v5e and newer):
```bash
sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled"
```

**Note**: This warning appears in TPU environments and can be safely ignored if you don't have root access (e.g., on Kaggle or Colab).
### Updated Code Patterns

#### Modern XLA Synchronization Pattern
```python
import torch_xla
import torch_xla.core.xla_model as xm  # Still needed for other functions

# Modern pattern:
def train_step():
    # ... training code ...

    # Synchronize every N steps
    if step % sync_frequency == 0:
        torch_xla.sync()  # Instead of xm.mark_step()

# Legacy pattern (still works but deprecated):
def train_step_legacy():
    # ... training code ...

    # Old way (shows a deprecation warning)
    if step % sync_frequency == 0:
        xm.mark_step()
        xm.wait_device_ops()  # This is still the current API
```
#### Device Detection Pattern
```python
# Modern approach:
import torch
import torch_xla

try:
    device = torch_xla.device()
    print(f"Using XLA device: {device}")
except Exception:
    device = torch.device('cpu')
    print("Falling back to CPU")

# Legacy approach (shows warnings):
import torch_xla.core.xla_model as xm

try:
    device = xm.xla_device()  # DeprecationWarning
    print(f"Using XLA device: {device}")
except Exception:
    device = torch.device('cpu')
```
### Migration Guidelines

When updating existing code:

1. **Replace `xm.xla_device()`** with `torch_xla.device()`
2. **Replace `xm.mark_step()`** with `torch_xla.sync()`
3. **Keep `xm.wait_device_ops()`** (still the current API)
4. **Update imports** to include `torch_xla` directly
5. **Consider explicit bf16 conversion** instead of environment variables
### Backward Compatibility

The deprecated APIs still work but generate warnings. For production code:

- Update to the modern APIs to avoid warnings
- Test thoroughly, as synchronization behavior may differ slightly
- Legacy code will continue to function until the old APIs are removed in a future version

## TensorFlow TPU Implementation

The original PyTorch implementation has been converted to TensorFlow for optimal performance in TPU v5e-8 environments, particularly for the Brain-to-Text '25 Competition on Kaggle.
### Key TensorFlow Components (`model_training_nnn_tpu/`)

#### Core Files

- **`rnn_model_tf.py`**: TensorFlow implementation of the TripleGRUDecoder architecture
  - `NoiseModel`: 2-layer GRU for noise estimation with day-specific layers
  - `CleanSpeechModel`: 3-layer GRU for clean speech recognition with day-specific layers
  - `NoisySpeechModel`: 2-layer GRU for noisy speech recognition (no day layers)
  - `TripleGRUDecoder`: Main adversarial architecture combining all three models
  - `CTCLoss`: Custom CTC loss implementation for TPU compatibility
  - `create_tpu_strategy()`: Enhanced TPU connection function with robust environment detection
- **`trainer_tf.py`**: TensorFlow training pipeline with distributed TPU support
- **`dataset_tf.py`**: TensorFlow data loading with an augmentation pipeline optimized for TPU
- **`train_model_tf.py`**: Main training script entry point
- **`evaluate_model_tf.py`**: Evaluation pipeline for model performance analysis
### TPU v5e-8 Specific Optimizations

#### 1. Enhanced TPU Connection

The `create_tpu_strategy()` function provides robust TPU detection across different environments:

```python
def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8."""
    # Multi-environment TPU detection
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address

    # Enhanced error handling and debugging output
    # Falls back to the default strategy if the TPU connection fails
```

**Environment Variables Detected**:

- `COLAB_TPU_ADDR`: Google Colab TPU environment
- `TPU_NAME`: Generic TPU name specification
- `TPU_WORKER_ID`: Kaggle TPU environment indicator

**Troubleshooting TPU Connection Issues**:

- Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
- Solution: The function automatically detects and uses the appropriate TPU address for the environment
- Debugging: All TPU-related environment variables are printed during initialization (see the sketch below)
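A minimal way to dump those variables yourself when debugging a connection failure:

```python
# Print every TPU-related environment variable (debugging aid).
import os

for key, value in sorted(os.environ.items()):
    if 'TPU' in key:
        print(f"{key}={value}")
```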
#### 2. Mixed Precision Training

Configured for optimal TPU v5e-8 performance:
```python
from tensorflow import keras

def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance."""
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)
```
#### 3. XLA-Optimized Operations

- **Static Tensor Operations**: `tf.stack()` and `tf.gather()` instead of dynamic indexing
- **Efficient Matrix Operations**: `tf.linalg.matmul()` for batch matrix multiplication
- **TPU-Friendly GRU Layers**: Recurrent dropout disabled for better TPU performance
- **Patch Processing**: TensorFlow equivalent of PyTorch's unfold using `tf.image.extract_patches()` (see the sketch below)
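For reference, a hedged sketch of how `tf.image.extract_patches()` can emulate a 1-D unfold over the time axis — the shapes and parameter values here are illustrative assumptions, not the repo's actual code:

```python
# Illustrative 1-D "unfold" over time via tf.image.extract_patches.
# Treat the sequence as a 1-pixel-wide image: [batch, time, 1, features].
import tensorflow as tf

x = tf.random.normal([4, 100, 512])       # [batch, time, features] (example shape)
patch_size, patch_stride = 14, 4          # example values only

patches = tf.image.extract_patches(
    images=tf.expand_dims(x, 2),          # -> [batch, time, 1, features]
    sizes=[1, patch_size, 1, 1],
    strides=[1, patch_stride, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID',
)                                         # -> [batch, T_out, 1, patch_size * features]
patches = tf.squeeze(patches, 2)          # -> [batch, T_out, 7168] for these values
```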
### Key Architecture Differences from PyTorch

#### 1. Gradient Reversal Layer (GRL)
```python
@tf.custom_gradient
def gradient_reverse(x, lambd=1.0):
    """Gradient Reversal Layer for TensorFlow."""
    def grad(dy):
        return -lambd * dy  # Only return the gradient w.r.t. x
    return tf.identity(x), grad
```
#### 2. CTC Loss Implementation

Custom sparse tensor conversion for TPU compatibility:
```python
def dense_to_sparse(dense_tensor, sequence_lengths):
    """Convert a dense label tensor to a SparseTensor for CTC."""
    mask = tf.not_equal(dense_tensor, 0)  # 0 is the padding value
    indices = tf.where(mask)
    values = tf.gather_nd(dense_tensor, indices)
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)  # was missing from the original snippet
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
```
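The sparse labels then feed standard CTC. A usage sketch with `tf.nn.ctc_loss` — the shapes, label values, and `blank_index=0` below are assumptions for illustration, not necessarily the repo's settings:

```python
# Hedged sketch: wiring dense labels through dense_to_sparse into tf.nn.ctc_loss.
labels_dense = tf.constant([[5, 3, 9, 0, 0]], dtype=tf.int32)  # padded phoneme ids
sparse_labels = dense_to_sparse(labels_dense, tf.constant([3]))
logits = tf.random.normal([1, 50, 41])                         # [batch, time, classes]

loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,
    label_length=None,               # lengths come from the sparse labels
    logit_length=tf.constant([50]),
    logits_time_major=False,
    blank_index=0,
)
```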
#### 3. Day-Specific Layers

Using `add_weight()` for TPU-compatible variable management:
```python
for i in range(n_days):
    weight = self.add_weight(
        name=f'day_weight_{i}',
        shape=(neural_dim, neural_dim),
        initializer=tf.keras.initializers.Identity(),
        trainable=True,
    )
```
### Training on TPU v5e-8

#### Basic Training Command
```bash
# In a Kaggle TPU v5e-8 environment
python train_model_tf.py
```
#### Expected Output
```
🔍 Detecting TPU environment...
📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
✅ TPU initialized successfully!
🎉 Number of TPU cores: 8
Training on 8 TPU cores  # Should show 8 cores, not 1
```
### Performance Benefits

1. **Multi-Core Utilization**: A properly configured TPU strategy utilizes all 8 TPU v5e-8 cores
2. **Mixed Precision**: bfloat16 precision optimized for TPU matrix units
3. **XLA Compilation**: Static operations enable efficient XLA graph compilation
4. **Memory Efficiency**: Optimized for TPU memory constraints and batch processing
### Common Issues and Solutions

#### Issue: "Training on 1 TPU cores" instead of 8
**Cause**: The TPU connection fell back to the default strategy
**Solution**: Use the enhanced `create_tpu_strategy()` function with environment detection
**Check**: Verify that the TPU environment variables are properly set

#### Issue: CTC loss dtype errors
**Cause**: Mixed precision dtype mismatches
**Solution**: Explicit dtype casting in `CTCLoss.call()`

#### Issue: Gradient Reversal Layer errors
**Cause**: Incorrect gradient return format
**Solution**: Return only the gradient w.r.t. the input tensor, not the lambda parameter
## Competition Context

This codebase serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.