final version? maybe

This commit is contained in:
Zchen
2025-10-12 23:36:16 +08:00
parent 6cfc568f9a
commit 0d2a0aa8fa
5 changed files with 375 additions and 51 deletions

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""
Quick XLA Model Test
"""
import torch
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model import TripleGRUDecoder
def quick_test():
print("Quick XLA model test...")
# Small model for fast testing
model = TripleGRUDecoder(
neural_dim=64, # Smaller
n_units=128, # Smaller
n_days=3, # Smaller
n_classes=10, # Smaller
rnn_dropout=0.0,
input_dropout=0.0,
patch_size=4, # Smaller
patch_stride=1
)
model.eval()
# Small test data
batch_size, seq_len = 2, 20
features = torch.randn(batch_size, seq_len, 64)
day_indices = torch.tensor([0, 1])
print(f"Input shape: {features.shape}")
print(f"Day indices: {day_indices}")
# Test inference
with torch.no_grad():
result = model(features, day_indices, mode='inference')
print(f"Inference result shape: {result.shape}")
print("✓ Inference mode works")
# Test full mode
clean, noisy, noise = model(features, day_indices, mode='full')
print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
print("✓ Full mode works")
print("🎉 Quick test passed!")
if __name__ == "__main__":
quick_test()