#!/usr/bin/env python3
"""
Quick smoke tests verifying the TensorFlow implementation fixes.

Exercises three pieces of the sibling ``rnn_model_tf`` module:
the ``gradient_reverse`` function, the ``CTCLoss`` wrapper, and basic
``TripleGRUDecoder`` construction plus an inference forward pass.

Requires TensorFlow and ``rnn_model_tf``; no datasets or checkpoints are
loaded. When run directly, exits with status 0 if every test passes and 1
otherwise (also 1 if TensorFlow cannot be imported).
"""

import os
import sys
import traceback

try:
    import tensorflow as tf
    print("✅ TensorFlow imported successfully")
except ImportError as e:
    print(f"❌ TensorFlow import failed: {e}")
    # sys.exit rather than the site-provided exit(): the latter is only a
    # convenience for the interactive interpreter and may be missing
    # (e.g. under `python -S` or in frozen executables).
    sys.exit(1)


def _ensure_local_imports():
    """Make this script's directory importable so ``rnn_model_tf`` resolves.

    Called at the start of every test so each test works on its own,
    independent of execution order. Idempotent: the directory is appended to
    ``sys.path`` at most once.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if this_dir not in sys.path:
        sys.path.append(this_dir)


def _report_exception(label, exc):
    """Print a one-line failure message for *label*, then the traceback."""
    print(f"❌ {label} test failed: {exc}")
    # Full traceback (to stderr) so failures inside rnn_model_tf are
    # diagnosable rather than reduced to a single message line.
    traceback.print_exc()


def test_gradient_reversal():
    """Test the gradient reversal fix.

    Checks that ``gradient_reverse(x, lambd=0.5)`` returns ``x`` unchanged
    in the forward pass. It also checks that the gradient of
    ``reduce_sum(y)`` w.r.t. ``x`` equals ``-0.5`` everywhere.

    Returns:
        True if both checks pass, False otherwise.
    """
    print("\n=== Testing Gradient Reversal Fix ===")
    try:
        _ensure_local_imports()
        from rnn_model_tf import gradient_reverse

        lambd = 0.5
        x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

        # Forward pass should be the identity.
        y = gradient_reverse(x, lambd=lambd)
        if not tf.reduce_all(tf.equal(x, y)):
            print("❌ Gradient reversal forward pass failed")
            return False
        print("✅ Gradient reversal forward pass works")

        # x is a constant, not a Variable, so the tape must watch it.
        with tf.GradientTape() as tape:
            tape.watch(x)
            y = gradient_reverse(x, lambd=lambd)
            loss = tf.reduce_sum(y)
        grad = tape.gradient(loss, x)

        # d(sum(y))/dy is 1 everywhere, so the reversed gradient is -lambd.
        expected_grad = -lambd * tf.ones_like(x)

        # tape.gradient returns None when there is no differentiable path;
        # report that directly instead of crashing on `None - Tensor`.
        if grad is None:
            print(f"❌ Gradient reversal gradients incorrect: got None, expected {expected_grad}")
            return False

        if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
            print("✅ Gradient reversal gradients work correctly")
            return True
        print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
        return False
    except Exception as e:
        _report_exception("Gradient reversal", e)
        return False


def test_ctc_loss():
    """Test the CTC loss fix.

    Feeds random logits of shape (batch_size, time_steps, n_classes) to
    ``CTCLoss(blank_index=0, reduction='none')``. Label rows are padded
    with 0 (the blank index) beyond their ``label_lengths``.

    Returns:
        True if the loss has shape (batch_size,) and every value is finite,
        False otherwise.
    """
    print("\n=== Testing CTC Loss Fix ===")
    try:
        _ensure_local_imports()
        from rnn_model_tf import CTCLoss

        ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Small synthetic batch.
        batch_size = 2
        time_steps = 5
        n_classes = 4

        logits = tf.random.normal((batch_size, time_steps, n_classes))
        labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
        input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
        label_lengths = tf.constant([2, 3], dtype=tf.int32)

        loss_input = {
            'labels': labels,
            'input_lengths': input_lengths,
            'label_lengths': label_lengths,
        }

        loss = ctc_loss(loss_input, logits)
        all_finite = tf.reduce_all(tf.math.is_finite(loss))

        if all_finite and loss.shape == (batch_size,):
            print("✅ CTC loss computation works")
            return True
        print(f"❌ CTC loss failed: shape {loss.shape}, finite: {all_finite}")
        return False
    except Exception as e:
        _report_exception("CTC loss", e)
        return False


def test_basic_model():
    """Test basic ``TripleGRUDecoder`` creation and an inference pass.

    Returns:
        True if the inference logits have shape
        (batch_size, expected_time_steps, n_classes), False otherwise.
    """
    print("\n=== Testing Basic Model Creation ===")
    try:
        _ensure_local_imports()
        from rnn_model_tf import TripleGRUDecoder

        # Small dimensions keep the test fast. Named so the expected output
        # shape below stays in sync with what is passed to the model.
        neural_dim = 64
        n_classes = 10
        patch_size = 2
        patch_stride = 1

        model = TripleGRUDecoder(
            neural_dim=neural_dim,
            n_units=32,
            n_days=2,
            n_classes=n_classes,
            rnn_dropout=0.1,
            input_dropout=0.1,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )

        batch_size = 2
        time_steps = 10
        x = tf.random.normal((batch_size, time_steps, neural_dim))
        day_idx = tf.constant([0, 1], dtype=tf.int32)

        logits = model(x, day_idx, mode='inference', training=False)

        # Expected output length if the model takes windows of patch_size
        # frames every patch_stride frames over the time axis.
        expected_time_steps = (time_steps - patch_size) // patch_stride + 1

        if logits.shape == (batch_size, expected_time_steps, n_classes):
            print("✅ Basic model inference works")
            return True
        print(f"❌ Model output shape incorrect: {logits.shape}")
        return False
    except Exception as e:
        _report_exception("Basic model", e)
        return False


def main():
    """Run all tests and print a summary.

    Returns:
        0 if every test passed, 1 otherwise (suitable as an exit status).
    """
    print("🧪 Testing TensorFlow Implementation Fixes")
    print("=" * 50)

    tests = [
        test_gradient_reversal,
        test_ctc_loss,
        test_basic_model,
    ]

    # Every test runs; each returns True on success.
    passed = sum(1 for test in tests if test())
    total = len(tests)

    print("\n" + "=" * 50)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All fixes working correctly!")
        return 0
    print("❌ Some fixes still need work")
    return 1


if __name__ == "__main__":
    sys.exit(main())