remove quick test script for TensorFlow implementation fixes
This commit is contained in:
@@ -1,161 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Quick test to verify TensorFlow implementation fixes
|
|
||||||
This tests the core fixes without requiring external dependencies
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
import tensorflow as tf
|
|
||||||
print("✅ TensorFlow imported successfully")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"❌ TensorFlow import failed: {e}")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
def test_gradient_reversal():
|
|
||||||
"""Test gradient reversal layer fix"""
|
|
||||||
print("\n=== Testing Gradient Reversal Fix ===")
|
|
||||||
try:
|
|
||||||
# Import our fixed gradient reversal function
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from rnn_model_tf import gradient_reverse
|
|
||||||
|
|
||||||
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
|
|
||||||
|
|
||||||
# Test forward pass (should be identity)
|
|
||||||
y = gradient_reverse(x, lambd=0.5)
|
|
||||||
|
|
||||||
# Check forward pass
|
|
||||||
if tf.reduce_all(tf.equal(x, y)):
|
|
||||||
print("✅ Gradient reversal forward pass works")
|
|
||||||
|
|
||||||
# Test gradient computation
|
|
||||||
with tf.GradientTape() as tape:
|
|
||||||
tape.watch(x)
|
|
||||||
y = gradient_reverse(x, lambd=0.5)
|
|
||||||
loss = tf.reduce_sum(y)
|
|
||||||
|
|
||||||
grad = tape.gradient(loss, x)
|
|
||||||
expected_grad = -0.5 * tf.ones_like(x)
|
|
||||||
|
|
||||||
if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
|
|
||||||
print("✅ Gradient reversal gradients work correctly")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
print("❌ Gradient reversal forward pass failed")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Gradient reversal test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_ctc_loss():
|
|
||||||
"""Test CTC loss fix"""
|
|
||||||
print("\n=== Testing CTC Loss Fix ===")
|
|
||||||
try:
|
|
||||||
from rnn_model_tf import CTCLoss
|
|
||||||
|
|
||||||
ctc_loss = CTCLoss(blank_index=0, reduction='none')
|
|
||||||
|
|
||||||
# Create simple test data
|
|
||||||
batch_size = 2
|
|
||||||
time_steps = 5
|
|
||||||
n_classes = 4
|
|
||||||
|
|
||||||
logits = tf.random.normal((batch_size, time_steps, n_classes))
|
|
||||||
labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
|
|
||||||
input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
|
|
||||||
label_lengths = tf.constant([2, 3], dtype=tf.int32)
|
|
||||||
|
|
||||||
loss_input = {
|
|
||||||
'labels': labels,
|
|
||||||
'input_lengths': input_lengths,
|
|
||||||
'label_lengths': label_lengths
|
|
||||||
}
|
|
||||||
|
|
||||||
loss = ctc_loss(loss_input, logits)
|
|
||||||
|
|
||||||
if tf.reduce_all(tf.math.is_finite(loss)) and loss.shape == (batch_size,):
|
|
||||||
print("✅ CTC loss computation works")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"❌ CTC loss failed: shape {loss.shape}, finite: {tf.reduce_all(tf.math.is_finite(loss))}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ CTC loss test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_basic_model():
|
|
||||||
"""Test basic model creation"""
|
|
||||||
print("\n=== Testing Basic Model Creation ===")
|
|
||||||
try:
|
|
||||||
from rnn_model_tf import TripleGRUDecoder
|
|
||||||
|
|
||||||
model = TripleGRUDecoder(
|
|
||||||
neural_dim=64, # Smaller for testing
|
|
||||||
n_units=32,
|
|
||||||
n_days=2,
|
|
||||||
n_classes=10,
|
|
||||||
rnn_dropout=0.1,
|
|
||||||
input_dropout=0.1,
|
|
||||||
patch_size=2,
|
|
||||||
patch_stride=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test forward pass
|
|
||||||
batch_size = 2
|
|
||||||
time_steps = 10
|
|
||||||
x = tf.random.normal((batch_size, time_steps, 64))
|
|
||||||
day_idx = tf.constant([0, 1], dtype=tf.int32)
|
|
||||||
|
|
||||||
# Test inference mode
|
|
||||||
logits = model(x, day_idx, mode='inference', training=False)
|
|
||||||
expected_time_steps = (time_steps - 2) // 1 + 1
|
|
||||||
|
|
||||||
if logits.shape == (batch_size, expected_time_steps, 10):
|
|
||||||
print("✅ Basic model inference works")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"❌ Model output shape incorrect: {logits.shape}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Basic model test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all tests"""
|
|
||||||
print("🧪 Testing TensorFlow Implementation Fixes")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
tests = [
|
|
||||||
test_gradient_reversal,
|
|
||||||
test_ctc_loss,
|
|
||||||
test_basic_model
|
|
||||||
]
|
|
||||||
|
|
||||||
passed = 0
|
|
||||||
total = len(tests)
|
|
||||||
|
|
||||||
for test in tests:
|
|
||||||
if test():
|
|
||||||
passed += 1
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print(f"📊 Test Results: {passed}/{total} tests passed")
|
|
||||||
|
|
||||||
if passed == total:
|
|
||||||
print("🎉 All fixes working correctly!")
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
print("❌ Some fixes still need work")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
exit(main())
|
|
Reference in New Issue
Block a user