139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Smoke-test script for the dual-GRU GRUDecoder architecture defined in rnn_model
|
|
"""
|
|
|
|
import torch
|
|
import sys
|
|
import os
|
|
|
|
# Add this script's own directory to the import path so rnn_model can be imported
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from rnn_model import GRUDecoder
|
|
|
|
def test_dual_gru_architecture():
    """Smoke-test the dual-GRU ``GRUDecoder`` on synthetic data.

    The checks run in this order:

    1. A plain forward pass must return logits of shape
       ``(batch_size, seq_len, n_classes)``.
    2. A forward pass with ``return_state=True`` must additionally return a
       ``(regression_states, residual_states)`` pair.
    3. A decoder built with patching enabled must return one time step per
       patch, i.e. ``(seq_len - patch_size) // patch_stride + 1`` steps.
    4. A backward pass must leave gradients on the regression GRU, the
       residual GRU and the day-specific weights.

    Progress is printed to stdout as each check completes.

    Returns:
        bool: ``True`` when every check passes.

    Raises:
        AssertionError: If an output shape is wrong or a sub-module received
            no gradient.
    """
    # Model parameters (matching the original configuration)
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    # Named once so the printed summary and the constructor calls cannot drift.
    n_layers_regression = 2
    n_layers_residual = 3
    patch_size = 14
    patch_stride = 4

    print("=== Dual-GRU Architecture Test ===")
    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print(f"Regression GRU layers: {n_layers_regression}")
    print(f"Residual GRU layers: {n_layers_residual}")
    print(f"Classes: {n_classes}")
    print()

    # Create model instance
    model = GRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=n_layers_regression,
        n_layers_residual=n_layers_residual,
        patch_size=0,  # Start without patching for simpler test
        patch_stride=0,
    )

    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Inference-style checks run in eval mode so dropout is inactive; train
    # mode is re-enabled further down for the gradient check.
    model.eval()
    expected_shape = (batch_size, seq_len, n_classes)

    # Test forward pass without states
    print("Testing forward pass (without states)...")
    with torch.no_grad():
        logits = model(x, day_idx)
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected shape: [{batch_size}, {seq_len}, {n_classes}]")
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass successful")
    print()

    # Test forward pass with state return
    print("Testing forward pass (with state return)...")
    with torch.no_grad():
        logits, states = model(x, day_idx, return_state=True)
    regression_states, residual_states = states
    print(f"Output logits shape: {logits.shape}")
    print(f"Regression states shape: {regression_states.shape}")
    print(f"Residual states shape: {residual_states.shape}")
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass with states successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = GRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=n_layers_regression,
        n_layers_residual=n_layers_residual,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    model_with_patches.eval()

    with torch.no_grad():
        logits_patches = model_with_patches(x, day_idx)

    # Number of full sliding windows of patch_size taken every patch_stride.
    # NOTE(review): assumes GRUDecoder emits exactly one output step per
    # patch; confirm against rnn_model if this assertion fails.
    expected_patches = (seq_len - patch_size) // patch_stride + 1
    print(f"Output logits shape (with patches): {logits_patches.shape}")
    print(f"Expected patches: {expected_patches}")
    assert logits_patches.shape == (batch_size, expected_patches, n_classes), (
        f"Patch shape mismatch! Got {logits_patches.shape}"
    )
    print("✓ Patch processing successful")
    print()

    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)
    logits = model(x_grad, day_idx)
    loss = logits.sum()
    loss.backward()

    # Check if gradients exist
    regression_grad_exists = any(
        p.grad is not None for p in model.gru_regression.parameters()
    )
    residual_grad_exists = any(
        p.grad is not None for p in model.gru_residual.parameters()
    )
    day_grad_exists = any(p.grad is not None for p in model.day_weights)

    print(f"Regression GRU gradients: {'✓' if regression_grad_exists else '✗'}")
    print(f"Residual GRU gradients: {'✓' if residual_grad_exists else '✗'}")
    print(f"Day-specific layer gradients: {'✓' if day_grad_exists else '✗'}")

    # Fail loudly instead of reporting success when a branch got no gradient.
    assert regression_grad_exists, "No gradient reached the regression GRU"
    assert residual_grad_exists, "No gradient reached the residual GRU"
    assert day_grad_exists, "No gradient reached the day-specific weights"
    print("✓ Gradient flow test successful")
    print()

    print("=== All tests passed! ===")
    print()

    # Print architecture summary
    print("=== Architecture Summary ===")
    print("Data Flow:")
    print("1. Input → Day-specific layers (512 → 512)")
    print("2. Day output → Regression GRU (2 layers, 512 hidden)")
    print("3. Residual = Day output - Regression output")
    print("4. Residual → Residual GRU (3 layers, 768 hidden)")
    print("5. Residual GRU output → Linear classifier (768 → 41)")
    print()

    return True
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    test_dual_gru_architecture()