52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
#!/usr/bin/env python3
|
|
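"""
Quick XLA Model Test

Smoke-test TripleGRUDecoder's 'inference' and 'full' forward modes on a
small random batch and print the resulting output shapes.
"""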
|
|
|
|
import torch
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from rnn_model import TripleGRUDecoder
|
|
|
|
def _require_finite(label, tensor):
    """Raise RuntimeError if *tensor* contains any NaN or Inf values."""
    if not torch.isfinite(tensor).all():
        raise RuntimeError(f"{label} output contains NaN or Inf values")


def quick_test(batch_size: int = 2, seq_len: int = 20) -> None:
    """Run a fast forward-pass smoke test of a small TripleGRUDecoder.

    Builds a deliberately tiny model, feeds it a random batch, and runs the
    ``'inference'`` and ``'full'`` forward modes, printing the output shapes.
    Tensors are created on torch's default device; no XLA device is
    selected in this function despite the script's title.

    Args:
        batch_size: Number of sequences in the random test batch.
        seq_len: Number of time steps per sequence.

    Raises:
        RuntimeError: If any forward-mode output contains NaN or Inf values.
    """
    print("Quick XLA model test...")

    # Small model for fast testing. Keep the sizes that the inputs depend on
    # in one place so the test tensors always match the model configuration.
    neural_dim = 64
    n_days = 3

    model = TripleGRUDecoder(
        neural_dim=neural_dim,  # Smaller
        n_units=128,  # Smaller
        n_days=n_days,  # Smaller
        n_classes=10,  # Smaller
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=4,  # Smaller
        patch_stride=1,
    )
    model.eval()

    # Small random test data.
    features = torch.randn(batch_size, seq_len, neural_dim)
    # One day index per sample, cycling through the valid range [0, n_days)
    # so any batch size works; for the default batch_size=2 this is [0, 1].
    day_indices = torch.arange(batch_size) % n_days

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    # Neither pass needs gradients, so run both under no_grad to avoid
    # building an autograd graph.
    with torch.no_grad():
        # Test inference
        result = model(features, day_indices, mode='inference')
        _require_finite("inference", result)
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Test full mode
        clean, noisy, noise = model(features, day_indices, mode='full')
        for label, tensor in (("clean", clean), ("noisy", noisy), ("noise", noise)):
            _require_finite(f"full-mode {label}", tensor)
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, never on import.
    quick_test()