52 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			52 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| """
 | |
| Quick XLA Model Test
 | |
| """
 | |
| 
 | |
| import torch
 | |
| import sys
 | |
| import os
 | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 | |
| 
 | |
| from rnn_model import TripleGRUDecoder
 | |
| 
 | |
def quick_test(
    *,
    batch_size=2,
    seq_len=20,
    neural_dim=64,
    n_units=128,
    n_days=3,
    n_classes=10,
    patch_size=4,
    patch_stride=1,
):
    """Run a fast forward-pass smoke test of ``TripleGRUDecoder``.

    Builds a small model with both dropouts set to 0.0 and switches it to
    eval mode. It then feeds the model random features under
    ``torch.no_grad()``, once with ``mode='inference'`` and once with
    ``mode='full'``, and prints the resulting shapes. The test "passes" if
    neither call raises. Output values and shapes are printed, not asserted.

    Every argument defaults to the value this script originally hard-coded,
    so a bare ``quick_test()`` call behaves exactly as before.

    Args:
        batch_size: Number of sequences in the random input batch.
        seq_len: Number of time steps per input sequence.
        neural_dim: Input feature width. It is used both for the model's
            ``neural_dim`` and for the last axis of the random input, so the
            two cannot drift apart.
        n_units: Forwarded to ``TripleGRUDecoder`` unchanged.
        n_days: Forwarded to ``TripleGRUDecoder``; also bounds the
            generated day indices.
        n_classes: Forwarded to ``TripleGRUDecoder`` unchanged.
        patch_size: Forwarded to ``TripleGRUDecoder`` unchanged.
        patch_stride: Forwarded to ``TripleGRUDecoder`` unchanged.
    """
    print("Quick XLA model test...")

    # Small model for fast testing; dropout disabled so the forward passes
    # are not perturbed by random masking.
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    model.eval()

    # Small random test data. The feature width is tied to neural_dim
    # rather than repeating the literal.
    features = torch.randn(batch_size, seq_len, neural_dim)
    # One day index per sequence, cycled through the available days. This
    # is derived from batch_size so the two cannot get out of sync. With the
    # defaults it yields tensor([0, 1]), the original hard-coded value.
    # NOTE(review): presumably valid day indices are in [0, n_days); the
    # modulo keeps them in that range. Confirm against TripleGRUDecoder.
    day_indices = torch.arange(batch_size) % n_days

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    # Forward passes only; no gradients are needed for a smoke test.
    with torch.no_grad():
        result = model(features, day_indices, mode='inference')
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # 'full' mode is unpacked into three tensors, as the original did.
        clean, noisy, noise = model(features, day_indices, mode='full')
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")
 | |
| 
 | |
if __name__ == "__main__":
    # Script entry point: run the smoke test only when this file is executed
    # directly, never as a side effect of importing it.
    quick_test()
