#!/usr/bin/env python3
"""
Test Script for TensorFlow Brain-to-Text Implementation
Validates model architecture, data pipeline, and training functionality

Usage: python test_tensorflow_implementation.py [--full_test]

This script runs comprehensive tests to ensure the TensorFlow implementation
is working correctly before starting full training runs.
"""

import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from omegaconf import OmegaConf
import tempfile
import shutil

# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model_tf import (
    TripleGRUDecoder,
    NoiseModel,
    CleanSpeechModel,
    NoisySpeechModel,
    CTCLoss,
    create_tpu_strategy,
    configure_mixed_precision
)
from dataset_tf import BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices
from trainer_tf import BrainToTextDecoderTrainerTF


class TensorFlowImplementationTester:
    """Comprehensive tester for TensorFlow brain-to-text implementation"""

    def __init__(self, use_tpu: bool = False, verbose: bool = True):
        """Initialize tester"""
        self.use_tpu = use_tpu
        self.verbose = verbose
        self.passed_tests = 0
        self.total_tests = 0

        # Create test configuration
        self.config = self._create_test_config()

        # Initialize strategy
        if use_tpu:
            self.strategy = create_tpu_strategy()
            if self.verbose:
                print(f"Using TPU strategy with {self.strategy.num_replicas_in_sync} cores")
        else:
            self.strategy = tf.distribute.get_strategy()
            if self.verbose:
                print("Using default strategy (CPU/GPU)")

    def _create_test_config(self):
        """Create minimal test configuration"""
        return {
            'model': {
                'n_input_features': 512,
                'n_units': 64,  # Smaller for testing
                'rnn_dropout': 0.1,
                'patch_size': 4,
                'patch_stride': 2,
                'input_network': {
                    'input_layer_dropout': 0.1
                }
            },
            'dataset': {
                'sessions': ['test_session_1', 'test_session_2'],
                'n_classes': 41,
                'batch_size': 4,
                'days_per_batch': 2,
                'seed': 42,
                'data_transforms': {
                    'white_noise_std': 0.1,
                    'constant_offset_std': 0.05,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 2,
                    'smooth_data': True,
                    'smooth_kernel_std': 1.0,
                    'smooth_kernel_size': 50
                }
            },
            'num_training_batches': 10,
            'lr_max': 0.001,
            'lr_min': 0.0001,
            'lr_decay_steps': 100,
            'lr_warmup_steps': 5,
            'lr_scheduler_type': 'cosine',
            'beta0': 0.9,
            'beta1': 0.999,
            'epsilon': 1e-7,
            'weight_decay': 0.001,
            'seed': 42,
            'grad_norm_clip_value': 1.0,
            'batches_per_train_log': 2,
            'batches_per_val_step': 5,
            'output_dir': tempfile.mkdtemp(),
            'checkpoint_dir': tempfile.mkdtemp(),
            'mode': 'train',
            'use_amp': False,  # Disable for testing
            'adversarial': {
                'enabled': True,
                'grl_lambda': 0.5,
                'noisy_loss_weight': 0.2,
                'noise_l2_weight': 0.001,
                'warmup_steps': 2
            }
        }

    def log_test(self, test_name: str, passed: bool, details: str = ""):
        """Log test result"""
        self.total_tests += 1
        if passed:
            self.passed_tests += 1
            status = "PASS"
        else:
            status = "FAIL"

        if self.verbose:
            print(f"[{status}] {test_name}")
            if details:
                print(f"    {details}")

    def test_model_architecture(self):
        """Test individual model components"""
        print("\n=== Testing Model Architecture ===")

        with self.strategy.scope():
            # Test NoiseModel
            try:
                noise_model = NoiseModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                # Test forward pass
                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                output, states = noise_model(x, day_idx, training=False)

                expected_time_steps = (time_steps - 4) // 2 + 1
                expected_features = 512 * 4
                assert output.shape == (batch_size, expected_time_steps, expected_features)
                assert len(states) == 2  # Two GRU layers

                self.log_test("NoiseModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("NoiseModel forward pass", False, str(e))

            # Test CleanSpeechModel
            try:
                clean_model = CleanSpeechModel(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                output = clean_model(x, day_idx, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("CleanSpeechModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("CleanSpeechModel forward pass", False, str(e))

            # Test NoisySpeechModel
            try:
                # First calculate expected dimensions from NoiseModel test
                expected_time_steps = (20 - 4) // 2 + 1
                expected_features = 512 * 4

                noisy_model = NoisySpeechModel(
                    neural_dim=expected_features,  # Takes processed input
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1
                )

                # Use processed input (same as noise model output)
                processed_input = tf.random.normal((batch_size, expected_time_steps, expected_features))

                output = noisy_model(processed_input, training=False)
                assert output.shape == (batch_size, expected_time_steps, 41)

                self.log_test("NoisySpeechModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("NoisySpeechModel forward pass", False, str(e))

    def test_triple_gru_decoder(self):
        """Test the complete TripleGRUDecoder"""
        print("\n=== Testing TripleGRUDecoder ===")

        with self.strategy.scope():
            try:
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )

                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)

                # Test inference mode
                clean_logits = model(x, day_idx, mode='inference', training=False)
                expected_time_steps = (time_steps - 4) // 2 + 1
                assert clean_logits.shape == (batch_size, expected_time_steps, 41)

                self.log_test("TripleGRUDecoder inference mode", True,
                              f"Output shape: {clean_logits.shape}")

                # Test full mode (adversarial training)
                clean_logits, noisy_logits, noise_output = model(
                    x, day_idx, mode='full', grl_lambda=0.5, training=True
                )
                assert clean_logits.shape == (batch_size, expected_time_steps, 41)
                assert noisy_logits.shape == (batch_size, expected_time_steps, 41)
                assert noise_output.shape[0] == batch_size

                self.log_test("TripleGRUDecoder full mode", True,
                              f"Clean: {clean_logits.shape}, Noisy: {noisy_logits.shape}")
            except Exception as e:
                self.log_test("TripleGRUDecoder", False, str(e))

    def test_ctc_loss(self):
        """Test CTC loss function"""
        print("\n=== Testing CTC Loss ===")

        try:
            ctc_loss = CTCLoss(blank_index=0, reduction='none')

            batch_size = 2
            time_steps = 10
            n_classes = 41

            # Create test data
            logits = tf.random.normal((batch_size, time_steps, n_classes))
            labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32)
            input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
            label_lengths = tf.constant([3, 2], dtype=tf.int32)

            loss_input = {
                'labels': labels,
                'input_lengths': input_lengths,
                'label_lengths': label_lengths
            }

            loss = ctc_loss(loss_input, logits)

            assert loss.shape == (batch_size,)
            assert tf.reduce_all(tf.math.is_finite(loss))

            self.log_test("CTC Loss computation", True,
                          f"Loss shape: {loss.shape}, values finite: {tf.reduce_all(tf.math.is_finite(loss))}")
        except Exception as e:
            self.log_test("CTC Loss computation", False, str(e))
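
    # Note on the expected shapes asserted above: the patched sequence length
    # follows the strided-patching formula
    #     T_out = (T_in - patch_size) // patch_stride + 1
    # so with the values used in these tests (T_in=20, patch_size=4,
    # patch_stride=2) the models should emit (20 - 4) // 2 + 1 = 9 time steps,
    # while the flattened feature dimension becomes
    # neural_dim * patch_size = 512 * 4 = 2048.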
    def test_data_augmentation(self):
        """Test data augmentation functions"""
        print("\n=== Testing Data Augmentation ===")

        try:
            batch_size = 2
            time_steps = 100
            features = 512

            x = tf.random.normal((batch_size, time_steps, features))
            n_time_steps = tf.constant([time_steps, time_steps], dtype=tf.int32)

            # Test Gaussian smoothing
            smoothed = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0)
            assert smoothed.shape == x.shape

            self.log_test("Gaussian smoothing", True,
                          f"Input: {x.shape}, Output: {smoothed.shape}")

            # Test full transform pipeline
            transform_args = self.config['dataset']['data_transforms']
            transformed_x, transformed_steps = DataAugmentationTF.transform_data(
                x, n_time_steps, transform_args, training=True
            )

            # Check that shapes are reasonable
            assert transformed_x.shape[0] == batch_size
            assert transformed_x.shape[2] == features
            assert len(transformed_steps) == batch_size

            self.log_test("Data augmentation pipeline", True,
                          f"Original: {x.shape}, Transformed: {transformed_x.shape}")
        except Exception as e:
            self.log_test("Data augmentation", False, str(e))

    def test_gradient_reversal(self):
        """Test gradient reversal layer"""
        print("\n=== Testing Gradient Reversal ===")

        try:
            from rnn_model_tf import gradient_reverse

            x = tf.random.normal((2, 10, 64))

            # Test forward pass (should be identity)
            y = gradient_reverse(x, lambd=0.5)
            assert tf.reduce_all(tf.equal(x, y))

            # Test gradient reversal in context
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = gradient_reverse(x, lambd=0.5)
                loss = tf.reduce_sum(y)

            grad = tape.gradient(loss, x)
            expected_grad = -0.5 * tf.ones_like(x)

            # Check if gradients are reversed and scaled
            assert tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6)

            self.log_test("Gradient reversal layer", True,
                          "Forward pass identity, gradients properly reversed")
        except Exception as e:
            self.log_test("Gradient reversal layer", False, str(e))

    def test_mixed_precision(self):
        """Test mixed precision configuration"""
        print("\n=== Testing Mixed Precision ===")

        try:
            # Configure mixed precision
            configure_mixed_precision()
            policy = tf.keras.mixed_precision.global_policy()
            assert policy.name == 'mixed_bfloat16'

            # Test model with mixed precision
            with self.strategy.scope():
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=32,
                    n_days=2,
                    n_classes=41
                )

                x = tf.random.normal((1, 10, 512))
                day_idx = tf.constant([0], dtype=tf.int32)

                logits = model(x, day_idx, mode='inference', training=False)

                # Check that compute dtype is bfloat16 but variables are float32
                assert policy.compute_dtype == 'bfloat16'
                assert policy.variable_dtype == 'float32'

            self.log_test("Mixed precision configuration", True,
                          f"Policy: {policy.name}")
        except Exception as e:
            self.log_test("Mixed precision configuration", False, str(e))

    def test_training_step(self):
        """Test a complete training step"""
        print("\n=== Testing Training Step ===")

        try:
            with self.strategy.scope():
                # Create model
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=32,
                    n_days=2,
                    n_classes=41,
                    patch_size=4,
                    patch_stride=2
                )

                # Create optimizer and loss
                optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
                ctc_loss = CTCLoss(blank_index=0, reduction='none')

                # Create dummy batch
                batch_size = 2
                time_steps = 20
                batch = {
                    'input_features': tf.random.normal((batch_size, time_steps, 512)),
                    'seq_class_ids': tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32),
                    'n_time_steps': tf.constant([time_steps, time_steps], dtype=tf.int32),
                    'phone_seq_lens': tf.constant([3, 2], dtype=tf.int32),
                    'day_indices': tf.constant([0, 1], dtype=tf.int32)
                }

                # Training step
                with tf.GradientTape() as tape:
                    # Apply minimal transforms
                    features = batch['input_features']
                    n_time_steps = batch['n_time_steps']

                    # Calculate adjusted lengths
                    adjusted_lens = tf.cast(
                        (tf.cast(n_time_steps, tf.float32) - 4) / 2 + 1, tf.int32
                    )

                    # Forward pass
                    clean_logits = model(features, batch['day_indices'],
                                         mode='inference', training=True)

                    # Loss
                    loss_input = {
                        'labels': batch['seq_class_ids'],
                        'input_lengths': adjusted_lens,
                        'label_lengths': batch['phone_seq_lens']
                    }
                    loss = ctc_loss(loss_input, clean_logits)
                    loss = tf.reduce_mean(loss)

                # Gradients
                gradients = tape.gradient(loss, model.trainable_variables)

                # Check gradients exist and are finite
                grad_finite = all(
                    bool(tf.reduce_all(tf.math.is_finite(g)))
                    for g in gradients if g is not None
                )

                # Apply gradients
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                self.log_test("Training step",
                              grad_finite and bool(tf.math.is_finite(loss)),
                              f"Loss: {float(loss):.4f}, Gradients finite: {grad_finite}")
        except Exception as e:
            self.log_test("Training step", False, str(e))

    def test_full_training_loop(self):
        """Test a minimal training loop"""
        print("\n=== Testing Full Training Loop ===")

        if not hasattr(self, '_full_test') or not self._full_test:
            self.log_test("Full training loop", True, "Skipped (use --full_test to enable)")
            return

        try:
            # Create temporary directories
            temp_output = tempfile.mkdtemp()
            temp_checkpoint = tempfile.mkdtemp()

            # Minimal config for quick test
            config = self.config.copy()
            config['output_dir'] = temp_output
            config['checkpoint_dir'] = temp_checkpoint
            config['num_training_batches'] = 5
            config['batches_per_val_step'] = 3

            # Would need actual data files for this test; for now only the
            # config is prepared and trainer construction stays disabled:
            # trainer = BrainToTextDecoderTrainerTF(config)

            self.log_test("Full training loop", True,
                          "Config prepared (trainer initialization requires real data files)")

            # Cleanup
            shutil.rmtree(temp_output, ignore_errors=True)
            shutil.rmtree(temp_checkpoint, ignore_errors=True)
        except Exception as e:
            self.log_test("Full training loop", False, str(e))

    def run_all_tests(self, full_test: bool = False):
        """Run all tests"""
        self._full_test = full_test

        print("TensorFlow Brain-to-Text Implementation Test Suite")
        print("=" * 60)

        if self.use_tpu:
            print("Running tests on TPU")
        else:
            print("Running tests on CPU/GPU")

        # Run tests
        self.test_model_architecture()
        self.test_triple_gru_decoder()
        self.test_ctc_loss()
        self.test_data_augmentation()
        self.test_gradient_reversal()
        self.test_mixed_precision()
        self.test_training_step()
        self.test_full_training_loop()

        # Summary
        print("\n" + "=" * 60)
        print(f"TEST SUMMARY: {self.passed_tests}/{self.total_tests} tests passed")

        if self.passed_tests == self.total_tests:
            print("🎉 All tests passed! TensorFlow implementation is ready.")
            return True
        else:
            print("❌ Some tests failed. Please review the implementation.")
            return False


def main():
    """Main test function"""
    parser = argparse.ArgumentParser(description='Test TensorFlow Brain-to-Text Implementation')
    parser.add_argument('--use_tpu', action='store_true',
                        help='Test on TPU if available')
    parser.add_argument('--full_test', action='store_true',
                        help='Run full training loop test')
    parser.add_argument('--quiet', action='store_true',
                        help='Reduce output verbosity')
    args = parser.parse_args()

    # Set TensorFlow logging level
    if args.quiet:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        tf.get_logger().setLevel('ERROR')
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Run tests
    tester = TensorFlowImplementationTester(
        use_tpu=args.use_tpu,
        verbose=not args.quiet
    )
    success = tester.run_all_tests(full_test=args.full_test)

    # Cleanup temporary directories
    shutil.rmtree(tester.config['output_dir'], ignore_errors=True)
    shutil.rmtree(tester.config['checkpoint_dir'], ignore_errors=True)

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()