remove quick test script for TensorFlow implementation fixes

This commit is contained in:
Zchen
2025-10-16 23:05:53 +08:00
parent 7efa33d730
commit 9453b70fad

View File

@@ -1,161 +0,0 @@
#!/usr/bin/env python3
"""
Quick test to verify TensorFlow implementation fixes
This tests the core fixes without requiring external dependencies
"""
try:
import tensorflow as tf
print("✅ TensorFlow imported successfully")
except ImportError as e:
print(f"❌ TensorFlow import failed: {e}")
exit(1)
def test_gradient_reversal():
"""Test gradient reversal layer fix"""
print("\n=== Testing Gradient Reversal Fix ===")
try:
# Import our fixed gradient reversal function
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model_tf import gradient_reverse
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# Test forward pass (should be identity)
y = gradient_reverse(x, lambd=0.5)
# Check forward pass
if tf.reduce_all(tf.equal(x, y)):
print("✅ Gradient reversal forward pass works")
# Test gradient computation
with tf.GradientTape() as tape:
tape.watch(x)
y = gradient_reverse(x, lambd=0.5)
loss = tf.reduce_sum(y)
grad = tape.gradient(loss, x)
expected_grad = -0.5 * tf.ones_like(x)
if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
print("✅ Gradient reversal gradients work correctly")
return True
else:
print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
return False
else:
print("❌ Gradient reversal forward pass failed")
return False
except Exception as e:
print(f"❌ Gradient reversal test failed: {e}")
return False
def test_ctc_loss():
"""Test CTC loss fix"""
print("\n=== Testing CTC Loss Fix ===")
try:
from rnn_model_tf import CTCLoss
ctc_loss = CTCLoss(blank_index=0, reduction='none')
# Create simple test data
batch_size = 2
time_steps = 5
n_classes = 4
logits = tf.random.normal((batch_size, time_steps, n_classes))
labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
label_lengths = tf.constant([2, 3], dtype=tf.int32)
loss_input = {
'labels': labels,
'input_lengths': input_lengths,
'label_lengths': label_lengths
}
loss = ctc_loss(loss_input, logits)
if tf.reduce_all(tf.math.is_finite(loss)) and loss.shape == (batch_size,):
print("✅ CTC loss computation works")
return True
else:
print(f"❌ CTC loss failed: shape {loss.shape}, finite: {tf.reduce_all(tf.math.is_finite(loss))}")
return False
except Exception as e:
print(f"❌ CTC loss test failed: {e}")
return False
def test_basic_model():
"""Test basic model creation"""
print("\n=== Testing Basic Model Creation ===")
try:
from rnn_model_tf import TripleGRUDecoder
model = TripleGRUDecoder(
neural_dim=64, # Smaller for testing
n_units=32,
n_days=2,
n_classes=10,
rnn_dropout=0.1,
input_dropout=0.1,
patch_size=2,
patch_stride=1
)
# Test forward pass
batch_size = 2
time_steps = 10
x = tf.random.normal((batch_size, time_steps, 64))
day_idx = tf.constant([0, 1], dtype=tf.int32)
# Test inference mode
logits = model(x, day_idx, mode='inference', training=False)
expected_time_steps = (time_steps - 2) // 1 + 1
if logits.shape == (batch_size, expected_time_steps, 10):
print("✅ Basic model inference works")
return True
else:
print(f"❌ Model output shape incorrect: {logits.shape}")
return False
except Exception as e:
print(f"❌ Basic model test failed: {e}")
return False
def main():
"""Run all tests"""
print("🧪 Testing TensorFlow Implementation Fixes")
print("=" * 50)
tests = [
test_gradient_reversal,
test_ctc_loss,
test_basic_model
]
passed = 0
total = len(tests)
for test in tests:
if test():
passed += 1
print("\n" + "=" * 50)
print(f"📊 Test Results: {passed}/{total} tests passed")
if passed == total:
print("🎉 All fixes working correctly!")
return 0
else:
print("❌ Some fixes still need work")
return 1
if __name__ == "__main__":
exit(main())