161 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			161 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
"""
Quick test to verify TensorFlow implementation fixes.

Exercises the core fixes in ``rnn_model_tf`` (gradient reversal, CTC loss and
basic model construction) without requiring any dependency beyond TensorFlow.
"""

import sys

try:
    import tensorflow as tf
    print("✅ TensorFlow imported successfully")
except ImportError as e:
    print(f"❌ TensorFlow import failed: {e}")
    # Use sys.exit rather than the bare ``exit`` helper: ``exit`` is injected by
    # the ``site`` module for interactive use and is missing under ``python -S``.
    sys.exit(1)
def test_gradient_reversal():
    """Check ``gradient_reverse``: identity forward, reversed/scaled gradient backward.

    The forward output must equal the input. For the backward pass the test
    uses ``loss = sum(y)`` (upstream gradient of all ones) and expects the input
    gradient to be ``-lambd * ones`` with ``lambd = 0.5``.

    Returns:
        bool: True if both the forward and the gradient checks pass; False on
        any mismatch, a missing (``None``) gradient, or any exception raised
        while importing or running the layer.
    """
    print("\n=== Testing Gradient Reversal Fix ===")
    try:
        # Make the sibling ``rnn_model_tf`` module importable regardless of the
        # working directory, without appending the same entry on every call.
        import os
        import sys
        script_dir = os.path.dirname(os.path.abspath(__file__))
        if script_dir not in sys.path:
            sys.path.append(script_dir)

        from rnn_model_tf import gradient_reverse

        lambd = 0.5
        x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

        # Forward pass should be the identity.
        y = gradient_reverse(x, lambd=lambd)
        if not tf.reduce_all(tf.equal(x, y)):
            print("❌ Gradient reversal forward pass failed")
            return False
        print("✅ Gradient reversal forward pass works")

        # Backward pass: d(sum(y))/dx is expected to be -lambd everywhere.
        with tf.GradientTape() as tape:
            tape.watch(x)  # x is a constant, so it must be watched explicitly
            y = gradient_reverse(x, lambd=lambd)
            loss = tf.reduce_sum(y)
        grad = tape.gradient(loss, x)
        expected_grad = -lambd * tf.ones_like(x)

        # tape.gradient returns None when no differentiable path connects loss
        # to x; report that directly instead of failing on ``None - tensor``.
        if grad is None:
            print("❌ Gradient reversal produced no gradient (got None)")
            return False

        if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
            print("✅ Gradient reversal gradients work correctly")
            return True
        print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
        return False

    except Exception as e:
        print(f"❌ Gradient reversal test failed: {e}")
        return False
def test_ctc_loss():
    """Smoke-test ``CTCLoss`` with ``reduction='none'`` on a tiny random batch.

    Feeds random logits of shape (batch, time, classes) plus zero-padded
    integer label rows with explicit input/label lengths. Passes when the
    per-example loss has shape ``(batch_size,)`` and every entry is finite.

    Returns:
        bool: True on success; False on a shape/finiteness mismatch or if any
        exception is raised.
    """
    print("\n=== Testing CTC Loss Fix ===")
    try:
        from rnn_model_tf import CTCLoss

        criterion = CTCLoss(blank_index=0, reduction='none')

        n_batch, n_steps, n_vocab = 2, 5, 4

        scores = tf.random.normal((n_batch, n_steps, n_vocab))
        targets = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
        # Every sequence spans the full time axis; label lengths count only the
        # non-padding entries of each target row.
        seq_lens = tf.constant([n_steps, n_steps], dtype=tf.int32)
        target_lens = tf.constant([2, 3], dtype=tf.int32)

        per_example = criterion(
            {
                'labels': targets,
                'input_lengths': seq_lens,
                'label_lengths': target_lens,
            },
            scores,
        )

        finite = tf.reduce_all(tf.math.is_finite(per_example))
        if not (finite and per_example.shape == (n_batch,)):
            print(f"❌ CTC loss failed: shape {per_example.shape}, finite: {finite}")
            return False
        print("✅ CTC loss computation works")
        return True

    except Exception as e:
        print(f"❌ CTC loss test failed: {e}")
        return False
def test_basic_model():
    """Build a small ``TripleGRUDecoder`` and check its inference output shape.

    The expected number of output time steps is derived from the same
    ``patch_size``/``patch_stride`` values passed to the model, using
    ``(time_steps - patch_size) // patch_stride + 1``. The check therefore stays
    correct if the test configuration is changed.

    Returns:
        bool: True if the logits have shape
        ``(batch_size, expected_time_steps, n_classes)``; False on a shape
        mismatch or if any exception is raised.
    """
    print("\n=== Testing Basic Model Creation ===")
    try:
        from rnn_model_tf import TripleGRUDecoder

        # Single source of truth for the config: the expected output shape is
        # computed from these values instead of duplicated literals.
        neural_dim = 64  # Smaller for testing
        n_classes = 10
        patch_size = 2
        patch_stride = 1

        model = TripleGRUDecoder(
            neural_dim=neural_dim,
            n_units=32,
            n_days=2,
            n_classes=n_classes,
            rnn_dropout=0.1,
            input_dropout=0.1,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )

        # Test forward pass
        batch_size = 2
        time_steps = 10
        x = tf.random.normal((batch_size, time_steps, neural_dim))
        # presumably one day index per batch element, each < n_days — TODO confirm
        day_idx = tf.constant([0, 1], dtype=tf.int32)

        # Test inference mode
        logits = model(x, day_idx, mode='inference', training=False)
        expected_time_steps = (time_steps - patch_size) // patch_stride + 1
        expected_shape = (batch_size, expected_time_steps, n_classes)

        if logits.shape == expected_shape:
            print("✅ Basic model inference works")
            return True
        print(f"❌ Model output shape incorrect: {logits.shape}, expected {expected_shape}")
        return False

    except Exception as e:
        print(f"❌ Basic model test failed: {e}")
        return False
def main():
    """Run each test function once and print an aggregate summary.

    Returns:
        int: Exit status for the process — 0 when every test returned a truthy
        result, 1 otherwise.
    """
    print("🧪 Testing TensorFlow Implementation Fixes")
    print("=" * 50)

    suite = (test_gradient_reversal, test_ctc_loss, test_basic_model)

    # Every test is invoked exactly once, in order; truthy results are counted.
    n_ok = sum(1 for check in suite if check())
    n_all = len(suite)

    print("\n" + "=" * 50)
    print(f"📊 Test Results: {n_ok}/{n_all} tests passed")

    if n_ok != n_all:
        print("❌ Some fixes still need work")
        return 1
    print("🎉 All fixes working correctly!")
    return 0
if __name__ == "__main__":
    import sys

    # sys.exit is always available; the bare ``exit`` helper is added by the
    # ``site`` module and is missing when Python runs with ``-S``.
    sys.exit(main())