#!/usr/bin/env python3
"""
Quick XLA Model Test

Smoke test for ``TripleGRUDecoder``: builds a deliberately tiny model, then runs
one forward pass in ``'inference'`` mode and one in ``'full'`` mode on random
input, printing the resulting shapes.
"""
import os
import sys

import torch

# Put this file's directory on the import path so ``rnn_model`` can be imported
# from it regardless of how this script is invoked.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder  # noqa: E402  (relies on the path tweak above)


def quick_test(
    *,
    neural_dim=64,
    n_units=128,
    n_days=3,
    n_classes=10,
    patch_size=4,
    patch_stride=1,
    batch_size=2,
    seq_len=20,
    device=None,
):
    """Run a fast forward-pass smoke test of ``TripleGRUDecoder``.

    All sizes default to deliberately small values so the test runs quickly.
    This only checks that both forward modes execute without raising; output
    shapes are printed, not validated.

    Args:
        neural_dim: Input feature dimension. Used both to build the model and as
            the last axis of the random input, so the two always agree.
        n_units: Hidden size passed to the model.
        n_days: Number of days the model is built for; generated day indices
            are kept in ``[0, n_days)``.
        n_classes: Number of output classes passed to the model.
        patch_size: ``patch_size`` passed to the model.
        patch_stride: ``patch_stride`` passed to the model.
        batch_size: Number of random sequences in the test batch.
        seq_len: Length of each random input sequence.
        device: Optional device (e.g. an XLA device) to move the model and
            inputs to. ``None`` leaves everything on torch's default device.
    """
    print("Quick XLA model test...")

    # Small model for fast testing; dropout disabled so the passes are
    # deterministic given the input.
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    model.eval()

    # Small random test data. Its last axis must match the model's neural_dim.
    features = torch.randn(batch_size, seq_len, neural_dim)
    # One day index per batch element, wrapped so it stays below n_days.
    # With the defaults (batch_size=2, n_days=3) this is tensor([0, 1]).
    day_indices = torch.arange(batch_size) % n_days

    if device is not None:
        model = model.to(device)
        features = features.to(device)
        day_indices = day_indices.to(device)

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    # Forward-only check: no gradients are needed for either pass.
    with torch.no_grad():
        # Test inference
        result = model(features, day_indices, mode='inference')
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Test full mode: returns three tensors, unpacked as clean/noisy/noise.
        clean, noisy, noise = model(features, day_indices, mode='full')
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")


if __name__ == "__main__":
    quick_test()