Files
b2txt25/model_training_nnn/test_xla_model.py

154 lines
5.8 KiB
Python
Raw Normal View History

2025-10-12 23:36:16 +08:00
#!/usr/bin/env python3
"""
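XLA Model Verification Script
Verify that the XLA-optimized TripleGRUDecoder produces self-consistent outputs:
the same results with and without return_state, stable results across repeated
runs, and the expected output shape for several batch sizes.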
"""
import torch
import torch.nn as nn
import sys
import os
# Add the model training directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model import TripleGRUDecoder
def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10):
    """Build a random ``(features, day_indices)`` pair for exercising the decoder.

    Args:
        batch_size: Number of trials in the batch.
        seq_len: Number of time steps per trial.
        neural_dim: Number of neural feature channels per time step.
        n_days: Exclusive upper bound for the sampled day indices.

    Returns:
        ``features``: standard-normal tensor of shape
        ``(batch_size, seq_len, neural_dim)`` in torch's default float dtype.
        ``day_indices``: int64 tensor of shape ``(batch_size,)`` with values
        drawn uniformly from ``[0, n_days)``.
    """
    # Draw the features before the day indices so a seeded RNG yields the
    # same pair regardless of how this helper is written.
    feature_shape = (batch_size, seq_len, neural_dim)
    features = torch.randn(*feature_shape)
    day_indices = torch.randint(low=0, high=n_days, size=(batch_size,))
    return features, day_indices
def _run_check(title, failure_label, check, *args):
    """Print a section header, then run ``check(*args)`` with autograd disabled.

    On any exception, print "✗ <failure_label> failed: <error>" and re-raise
    the original exception unchanged so the caller still fails.
    """
    print(f"\n=== {title} ===")
    with torch.no_grad():
        try:
            check(*args)
        except Exception as e:
            print(f"✗ {failure_label} failed: {e}")
            raise


def _check_inference_mode(model, features, day_indices, expected_shape):
    """Run inference mode; check output shape and return_state consistency."""
    clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference')
    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]")
    # The shape was previously only checked in the batch-size loop; check it
    # here too, for the main test batch.
    assert clean_logits.shape == expected_shape, (
        f"Unexpected inference shape: {clean_logits.shape} vs {expected_shape}"
    )
    print("✓ Inference mode successful")

    # The returned hidden state is not inspected. Requesting it must not
    # change the logits.
    clean_logits_with_state, _ = model(features, day_indices, states=None, return_state=True, mode='inference')
    assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent outputs with/without return_state"
    print("✓ return_state consistency verified")


def _check_full_mode(model, features, day_indices):
    """Run full (training) mode; check all three outputs are unaffected by return_state."""
    clean_logits, noisy_logits, noise_output = model(features, day_indices, states=None, return_state=False, mode='full')
    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Noisy logits shape: {noisy_logits.shape}")
    print(f"Noise output shape: {noise_output.shape}")
    print("✓ Full mode successful")

    (clean_with_state, noisy_with_state, noise_output_with_state), _ = model(
        features, day_indices, states=None, return_state=True, mode='full')
    assert torch.allclose(clean_logits, clean_with_state, rtol=1e-5, atol=1e-6), "Inconsistent clean logits"
    assert torch.allclose(noisy_logits, noisy_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noisy logits"
    assert torch.allclose(noise_output, noise_output_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noise output"
    print("✓ return_state consistency verified")


def _check_repeatability(model, features, day_indices, n_runs=3):
    """Run inference ``n_runs`` times on the same input; compare each run to run 0."""
    results = [
        model(features, day_indices, states=None, return_state=False, mode='inference')
        for _ in range(n_runs)
    ]
    reference = results[0]
    for i, result in enumerate(results[1:], start=1):
        assert torch.allclose(reference, result, rtol=1e-7, atol=1e-8), f"Inconsistent results between runs 0 and {i}"
    print("✓ Multiple runs produce identical results")


def _check_batch_sizes(model, batch_sizes, seq_len, neural_dim, n_days, n_patches, n_classes):
    """Check the inference output shape for every batch size in ``batch_sizes``."""
    for test_batch_size in batch_sizes:
        test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days)
        result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference')
        expected_shape = (test_batch_size, n_patches, n_classes)
        assert result.shape == expected_shape, f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}"
        print(f"✓ Batch size {test_batch_size}: {result.shape}")


def test_model_consistency():
    """Smoke-test the XLA-optimized TripleGRUDecoder on seeded synthetic data.

    Builds a decoder with dropout disabled, puts it in eval mode, and checks,
    in order:
      1. inference mode runs, produces the expected output shape, and gives
         the same logits with and without ``return_state``;
      2. full mode runs and its three outputs are the same with and without
         ``return_state``;
      3. three repeated inference calls on the same input agree within
         rtol=1e-7 / atol=1e-8;
      4. the inference output shape is correct for batch sizes 1, 2 and 8.

    Each section prints a progress report. On failure it prints a ✗ line and
    re-raises the original exception.

    Returns:
        True when every check passes.
    """
    # NOTE(review): pytest collects this function (``test_`` prefix) and may
    # warn about the non-None return. The return is kept for existing callers.
    print("Testing XLA-optimized TripleGRUDecoder consistency...")
    # Seed the global RNG so the random inputs, and therefore any failure,
    # can be reproduced exactly.
    torch.manual_seed(0)

    # Model parameters (matching typical configuration)
    neural_dim = 512
    n_units = 768
    n_days = 10
    n_classes = 40  # Typical phoneme count
    batch_size = 4
    seq_len = 100
    patch_size = 14
    patch_stride = 1
    # Expected number of output time steps: one per window of `patch_size`
    # frames taken every `patch_stride` frames. This is assumed from the
    # shape check; confirm against TripleGRUDecoder.
    n_patches = (seq_len - patch_size) // patch_stride + 1

    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,  # Disable dropout for consistent testing
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    # Set to eval mode for consistent results
    model.eval()

    features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days)
    print("Test data shapes:")
    print(f" Features: {features.shape}")
    print(f" Day indices: {day_indices.shape}")
    print(f" Day indices values: {day_indices.tolist()}")

    _run_check("Testing Inference Mode", "Inference mode",
               _check_inference_mode, model, features, day_indices,
               (batch_size, n_patches, n_classes))
    _run_check("Testing Full Mode", "Full mode",
               _check_full_mode, model, features, day_indices)
    _run_check("Testing Multiple Run Consistency", "Multiple run consistency",
               _check_repeatability, model, features, day_indices)
    _run_check("Testing Different Batch Sizes", "Batch size testing",
               _check_batch_sizes, model, (1, 2, 8), seq_len, neural_dim,
               n_days, n_patches, n_classes)

    print("\n🎉 All tests passed! XLA-optimized model is working correctly.")
    return True
if __name__ == "__main__":
    # Script entry point: run every check. A failing check re-raises its
    # exception, so the process ends with a traceback and a non-zero exit status.
    test_model_consistency()