#!/usr/bin/env python3
"""
Test script for the triple-GRU adversarial architecture
"""

import torch
import sys
import os

# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder, NoiseModel, CleanSpeechModel, NoisySpeechModel


def test_individual_models():
    """Test each individual model first"""
    print("=== Testing Individual Models ===")

    # Model parameters
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    # Create synthetic input
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test NoiseModel
    print("1. Testing NoiseModel...")
    noise_model = NoiseModel(neural_dim, n_units, n_days)
    with torch.no_grad():
        noise_out, noise_hidden = noise_model(x, day_idx)
    print(f"  Noise output shape: {noise_out.shape}")
    print(f"  Noise hidden shape: {noise_hidden.shape}")
    print("  ✓ NoiseModel working")
    print()

    # Test CleanSpeechModel
    print("2. Testing CleanSpeechModel...")
    clean_model = CleanSpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        clean_logits = clean_model(x, day_idx)
    print(f"  Clean logits shape: {clean_logits.shape}")
    print("  ✓ CleanSpeechModel working")
    print()

    # Test NoisySpeechModel
    print("3. Testing NoisySpeechModel...")
    noisy_model = NoisySpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        noisy_logits = noisy_model(noise_out)  # Use noise output as input
    print(f"  Noisy logits shape: {noisy_logits.shape}")
    print("  ✓ NoisySpeechModel working")
    print()

    return True


def test_triple_gru_architecture():
    """Test the complete triple-GRU architecture"""
    print("=== Triple-GRU Architecture Test ===")

    # Model parameters
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print(f"Days: {n_days}")
    print(f"Classes: {n_classes}")
    print()

    # Create model instance
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=0,  # Start without patching
        patch_stride=0
    )

    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test full training mode
    print("Testing full training mode...")
    with torch.no_grad():
        outputs = model(x, day_idx, mode='full')

    print(f"Clean logits shape: {outputs['clean_logits'].shape}")
    print(f"Noisy logits shape: {outputs['noisy_logits'].shape}")
    print(f"Noise output shape: {outputs['noise_output'].shape}")

    # Verify shapes
    assert outputs['clean_logits'].shape == (batch_size, seq_len, n_classes), \
        f"Clean logits shape mismatch: {outputs['clean_logits'].shape}"
    assert outputs['noisy_logits'].shape == (batch_size, seq_len, n_classes), \
        f"Noisy logits shape mismatch: {outputs['noisy_logits'].shape}"
    assert outputs['noise_output'].shape == (batch_size, seq_len, neural_dim), \
        f"Noise output shape mismatch: {outputs['noise_output'].shape}"
    print("✓ Full training mode successful")
    print()

    # Test inference mode
    print("Testing inference mode...")
    with torch.no_grad():
        logits = model(x, day_idx, mode='inference')
    print(f"Inference logits shape: {logits.shape}")
    assert logits.shape == (batch_size, seq_len, n_classes), \
        f"Inference logits shape mismatch: {logits.shape}"
    print("✓ Inference mode successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=14,
        patch_stride=4
    )

    with torch.no_grad():
        outputs_patches = model_with_patches(x, day_idx, mode='full')
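
    # With sliding-window patching, the time dimension should shrink to
    # (seq_len - patch_size) // patch_stride + 1 = (100 - 14) // 4 + 1 = 22 patches
    # (assuming standard non-padded striding; exact behavior is defined in rnn_model).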
    expected_patches = (seq_len - 14) // 4 + 1
    print("Output shapes with patches:")
    print(f"  Clean logits: {outputs_patches['clean_logits'].shape}")
    print(f"  Expected patches: {expected_patches}")
    print("✓ Patch processing successful")
    print()
    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)

    # Forward pass in training mode
    outputs = model(x_grad, day_idx, mode='full')

    # Calculate losses against random targets
    target = torch.randint(0, n_classes, (batch_size, seq_len))
    clean_loss = torch.nn.functional.cross_entropy(
        outputs['clean_logits'].reshape(-1, n_classes),
        target.reshape(-1)
    )
    noisy_loss = torch.nn.functional.cross_entropy(
        outputs['noisy_logits'].reshape(-1, n_classes),
        target.reshape(-1)
    )

    # Backward passes: retain_graph=True keeps the forward graph alive, since
    # both losses backpropagate through the shared NoiseModel
    clean_loss.backward(retain_graph=True)
    noisy_loss.backward()

    # Check that gradients reached all three sub-models
    clean_has_grad = any(p.grad is not None for p in model.clean_speech_model.parameters())
    noisy_has_grad = any(p.grad is not None for p in model.noisy_speech_model.parameters())
    noise_has_grad = any(p.grad is not None for p in model.noise_model.parameters())

    print(f"Clean model gradients: {'✓' if clean_has_grad else '✗'}")
    print(f"Noisy model gradients: {'✓' if noisy_has_grad else '✗'}")
    print(f"Noise model gradients: {'✓' if noise_has_grad else '✗'}")
    print("✓ Gradient flow test successful")
    print()

    return True


def test_adversarial_training_simulation():
    """Simulate adversarial training with gradient combination"""
    print("=== Adversarial Training Simulation ===")

    # Simple test to verify the gradient combination logic
    model = TripleGRUDecoder(512, 768, 5, 41, rnn_dropout=0.1)

    # Create fake gradients shaped like the output projection (n_classes, n_units)
    fake_clean_grad = torch.randn(41, 768)
    fake_noisy_grad = torch.randn(41, 768)

    print("Testing gradient combination...")
    try:
        model.apply_gradient_combination(fake_clean_grad, fake_noisy_grad, learning_rate=1e-3)
        print("✓ Gradient combination mechanism working")
    except Exception as e:
        print(f"✗ Gradient combination failed: {e}")

    print()
    return True
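

# Illustrative sketch only: the real update logic lives in
# TripleGRUDecoder.apply_gradient_combination (rnn_model.py). Based on the
# data-flow summary printed in __main__ below, the NoiseModel is assumed to be
# updated with the sign-flipped clean gradient plus the noisy gradient, so it
# learns to hurt clean recognition while helping noisy recognition.
def _sketch_combine_gradients(clean_grad, noisy_grad):
    """Hypothetical helper showing the adversarial combination rule."""
    return -clean_grad + noisy_grad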


if __name__ == "__main__":
    print("Starting comprehensive tests for Triple-GRU architecture...\n")

    # Run all tests
    test_individual_models()
    test_triple_gru_architecture()
    test_adversarial_training_simulation()

    print("=== All tests completed! ===")
    print()

    # Print architecture summary
    print("=== Triple-GRU Architecture Summary ===")
    print("Training Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Clean Recognition")
    print("3. Noise → NoisySpeechModel → Noisy Recognition")
    print("4. Gradient Combination: -Clean_grad + Noisy_grad → NoiseModel")
    print()
    print("Inference Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Final Recognition")
    print()
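

# Minimal sketch (assumption) of the inference-mode data flow summarized above,
# wired from the sub-models directly; the actual mode='inference' path in
# rnn_model.py may differ in details such as patching and hidden-state handling.
def _sketch_inference(model, x, day_idx):
    with torch.no_grad():
        noise, _ = model.noise_model(x, day_idx)             # 1. estimate noise
        return model.clean_speech_model(x - noise, day_idx)  # 2. denoise, then recognize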