Files
b2txt25/model_training_nnn/quick_test_xla.py
2025-10-12 23:36:16 +08:00

52 lines
1.3 KiB
Python

#!/usr/bin/env python3
"""
Quick XLA Model Test
"""
import torch
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model import TripleGRUDecoder
def quick_test(batch_size=2, seq_len=20, neural_dim=64, n_days=3, device=None):
    """Smoke-test TripleGRUDecoder's 'inference' and 'full' forward passes.

    Builds a deliberately small TripleGRUDecoder, feeds it a batch of random
    features, and runs one forward pass with ``mode='inference'`` and one with
    ``mode='full'`` under ``torch.no_grad()``, printing the resulting shapes.
    The test "passes" when neither forward call raises; no output values or
    shapes are asserted.

    All arguments default to the values the original hard-coded version used,
    so ``quick_test()`` behaves exactly as before.

    Args:
        batch_size: Number of sequences in the random test batch.
        seq_len: Number of time steps per random sequence.
        neural_dim: Feature width. Used for both the model's ``neural_dim``
            and the last axis of the random input, so the two cannot drift
            apart.
        n_days: Number of days the model is constructed with. Day indices are
            assigned round-robin (``i % n_days``), so they stay below
            ``n_days`` for any ``batch_size``. This presumes the model expects
            indices in ``[0, n_days)``; confirm against TripleGRUDecoder.
        device: Optional torch device, e.g. an XLA device obtained by the
            caller. When ``None``, nothing is moved and the model and inputs
            stay wherever PyTorch creates them by default.
    """
    print("Quick XLA model test...")

    # Small model for fast testing.
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=128,
        n_days=n_days,
        n_classes=10,
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=4,
        patch_stride=1,
    )
    model.eval()

    # Random test batch. The feature axis is tied to neural_dim, and the day
    # indices are derived from batch_size, so changing one knob cannot cause a
    # shape or index mismatch. With the defaults this yields tensor([0, 1]),
    # the same as the original hard-coded value.
    features = torch.randn(batch_size, seq_len, neural_dim)
    day_indices = torch.arange(batch_size) % n_days

    if device is not None:
        # Keep the model and every input on the same device. Day indices are
        # moved too; NOTE(review): assumes the model indexes with them
        # on-device, as a training loop would.
        model = model.to(device)
        features = features.to(device)
        day_indices = day_indices.to(device)

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    with torch.no_grad():
        # Inference mode: a single output tensor.
        result = model(features, day_indices, mode='inference')
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Full mode: unpacked into three tensors (clean, noisy, noise).
        clean, noisy, noise = model(features, day_indices, mode='full')
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")
if __name__ == "__main__":
    # Script entry point: run the smoke test only when this file is executed
    # directly, not when it is imported.
    quick_test()