161 lines
4.6 KiB
Python
161 lines
4.6 KiB
Python
![]() |
#!/usr/bin/env python3
"""
Quick test to verify TensorFlow implementation fixes.

Exercises the gradient reversal helper, the CTC loss wrapper and basic model
construction from the sibling ``rnn_model_tf`` module using small synthetic
tensors only. Requires TensorFlow and ``rnn_model_tf``; no datasets or other
external resources are needed.
"""

import sys

try:
    import tensorflow as tf
    print("✅ TensorFlow imported successfully")
except ImportError as e:
    print(f"❌ TensorFlow import failed: {e}")
    # sys.exit is always available; the bare ``exit`` builtin is only injected
    # by the ``site`` module and is absent under ``python -S`` / frozen apps.
    sys.exit(1)


def test_gradient_reversal() -> bool:
    """Test the gradient reversal layer fix in ``rnn_model_tf``.

    Checks two properties of ``gradient_reverse(x, lambd=0.5)``:

    * forward pass: the output equals the input element-wise;
    * backward pass: the gradient of ``reduce_sum(y)`` with respect to ``x``
      equals ``-0.5`` everywhere, i.e. the upstream all-ones gradient scaled
      by ``-lambd``.

    Returns:
        True if both checks pass. False otherwise, including when the import
        or any TensorFlow call raises; in that case the traceback is printed.
    """
    import os
    import sys
    import traceback

    print("\n=== Testing Gradient Reversal Fix ===")

    try:
        # Make the sibling rnn_model_tf module importable regardless of the
        # current working directory, without growing sys.path on repeat calls.
        here = os.path.dirname(os.path.abspath(__file__))
        if here not in sys.path:
            sys.path.append(here)

        from rnn_model_tf import gradient_reverse

        lambd = 0.5
        x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

        # Forward pass should be the identity.
        y = gradient_reverse(x, lambd=lambd)
        if not tf.reduce_all(tf.equal(x, y)):
            print("❌ Gradient reversal forward pass failed")
            return False
        print("✅ Gradient reversal forward pass works")

        # Backward pass: d(sum(y))/dy is all ones, so the reversed gradient
        # reaching x should be -lambd everywhere.
        with tf.GradientTape() as tape:
            tape.watch(x)  # x is a constant, so the tape must watch it explicitly
            y = gradient_reverse(x, lambd=lambd)
            loss = tf.reduce_sum(y)

        grad = tape.gradient(loss, x)
        expected_grad = -lambd * tf.ones_like(x)

        # tape.gradient returns None when loss is not connected to x; treat
        # that as an incorrect gradient rather than crashing on the subtraction.
        if grad is not None and tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
            print("✅ Gradient reversal gradients work correctly")
            return True

        print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
        return False

    except Exception as e:
        print(f"❌ Gradient reversal test failed: {e}")
        traceback.print_exc()
        return False


def test_ctc_loss() -> bool:
    """Test the CTC loss fix in ``rnn_model_tf``.

    Builds random logits of shape (batch=2, time=5, classes=4) plus
    zero-padded integer labels with explicit per-example lengths. It then
    calls ``CTCLoss(blank_index=0, reduction='none')`` with a dict of
    labels/lengths as the first argument and the logits as the second.

    Returns:
        True if the loss has shape ``(batch_size,)`` and every entry is
        finite. False otherwise, including when anything raises; in that case
        the traceback is printed.
    """
    import os
    import sys
    import traceback

    print("\n=== Testing CTC Loss Fix ===")

    try:
        # Make the sibling rnn_model_tf module importable on its own, rather
        # than relying on an earlier test having already patched sys.path.
        here = os.path.dirname(os.path.abspath(__file__))
        if here not in sys.path:
            sys.path.append(here)

        from rnn_model_tf import CTCLoss

        ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Create simple test data
        batch_size = 2
        time_steps = 5
        n_classes = 4

        logits = tf.random.normal((batch_size, time_steps, n_classes))
        # Rows are padded with 0 beyond label_lengths. 0 is also the blank
        # index, so CTCLoss presumably relies on label_lengths to ignore the
        # padding — TODO confirm in rnn_model_tf.
        labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
        input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
        label_lengths = tf.constant([2, 3], dtype=tf.int32)

        loss_input = {
            'labels': labels,
            'input_lengths': input_lengths,
            'label_lengths': label_lengths,
        }

        loss = ctc_loss(loss_input, logits)

        all_finite = tf.reduce_all(tf.math.is_finite(loss))
        if all_finite and loss.shape == (batch_size,):
            print("✅ CTC loss computation works")
            return True

        print(f"❌ CTC loss failed: shape {loss.shape}, finite: {all_finite}")
        return False

    except Exception as e:
        print(f"❌ CTC loss test failed: {e}")
        traceback.print_exc()
        return False


def test_basic_model() -> bool:
    """Test that ``TripleGRUDecoder`` builds and runs an inference forward pass.

    Uses a deliberately small configuration. The expected number of output
    time steps follows the patching arithmetic
    ``(time_steps - patch_size) // patch_stride + 1``, and the expected
    output shape is ``(batch, expected_time_steps, n_classes)``.

    Returns:
        True if the inference output has the expected shape. False otherwise,
        including when anything raises; in that case the traceback is printed.
    """
    import os
    import sys
    import traceback

    print("\n=== Testing Basic Model Creation ===")

    try:
        # Make the sibling rnn_model_tf module importable on its own, rather
        # than relying on an earlier test having already patched sys.path.
        here = os.path.dirname(os.path.abspath(__file__))
        if here not in sys.path:
            sys.path.append(here)

        from rnn_model_tf import TripleGRUDecoder

        # Single source of truth for the values that the input shape and the
        # expected output shape below both depend on.
        neural_dim = 64  # Smaller for testing
        n_days = 2
        n_classes = 10
        patch_size = 2
        patch_stride = 1

        model = TripleGRUDecoder(
            neural_dim=neural_dim,
            n_units=32,
            n_days=n_days,
            n_classes=n_classes,
            rnn_dropout=0.1,
            input_dropout=0.1,
            patch_size=patch_size,
            patch_stride=patch_stride,
        )

        # Test forward pass
        batch_size = 2
        time_steps = 10
        x = tf.random.normal((batch_size, time_steps, neural_dim))
        # One day index per example; values are kept below n_days.
        day_idx = tf.constant([0, 1], dtype=tf.int32)

        # Test inference mode
        logits = model(x, day_idx, mode='inference', training=False)
        expected_time_steps = (time_steps - patch_size) // patch_stride + 1

        if logits.shape == (batch_size, expected_time_steps, n_classes):
            print("✅ Basic model inference works")
            return True

        print(f"❌ Model output shape incorrect: {logits.shape}")
        return False

    except Exception as e:
        print(f"❌ Basic model test failed: {e}")
        traceback.print_exc()
        return False


def main():
    """Run each fix test in order and report an aggregate result.

    Every test function is called exactly once, in declaration order, and its
    return value is interpreted as pass (truthy) or fail (falsy).

    Returns:
        0 when every test passed, 1 when at least one failed; suitable as a
        process exit status.
    """
    banner = "=" * 50
    print("🧪 Testing TensorFlow Implementation Fixes")
    print(banner)

    suite = (test_gradient_reversal, test_ctc_loss, test_basic_model)
    # Run everything first, then count, so each test's own output is printed
    # before the summary.
    outcomes = [bool(check()) for check in suite]
    n_ok = sum(outcomes)
    n_all = len(suite)

    print("\n" + banner)
    print(f"📊 Test Results: {n_ok}/{n_all} tests passed")

    if n_ok != n_all:
        print("❌ Some fixes still need work")
        return 1

    print("🎉 All fixes working correctly!")
    return 0


if __name__ == "__main__":
    import sys

    # Use sys.exit: the bare ``exit`` builtin is added by the ``site`` module
    # for interactive use and is missing under ``python -S`` or frozen apps.
    sys.exit(main())