final version? maybe
This commit is contained in:
52
model_training_nnn/quick_test_xla.py
Normal file
52
model_training_nnn/quick_test_xla.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick XLA Model Test
|
||||
"""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from rnn_model import TripleGRUDecoder
|
||||
|
||||
def quick_test():
|
||||
print("Quick XLA model test...")
|
||||
|
||||
# Small model for fast testing
|
||||
model = TripleGRUDecoder(
|
||||
neural_dim=64, # Smaller
|
||||
n_units=128, # Smaller
|
||||
n_days=3, # Smaller
|
||||
n_classes=10, # Smaller
|
||||
rnn_dropout=0.0,
|
||||
input_dropout=0.0,
|
||||
patch_size=4, # Smaller
|
||||
patch_stride=1
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Small test data
|
||||
batch_size, seq_len = 2, 20
|
||||
features = torch.randn(batch_size, seq_len, 64)
|
||||
day_indices = torch.tensor([0, 1])
|
||||
|
||||
print(f"Input shape: {features.shape}")
|
||||
print(f"Day indices: {day_indices}")
|
||||
|
||||
# Test inference
|
||||
with torch.no_grad():
|
||||
result = model(features, day_indices, mode='inference')
|
||||
print(f"Inference result shape: {result.shape}")
|
||||
print("✓ Inference mode works")
|
||||
|
||||
# Test full mode
|
||||
clean, noisy, noise = model(features, day_indices, mode='full')
|
||||
print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
|
||||
print("✓ Full mode works")
|
||||
|
||||
print("🎉 Quick test passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
quick_test()
|
Reference in New Issue
Block a user