Files
b2txt25/model_training_nnn/test_triple_gru.py

233 lines
7.6 KiB
Python
Raw Normal View History

2025-10-12 09:11:32 +08:00
#!/usr/bin/env python3
"""
Test script for the triple-GRU adversarial architecture
"""
import torch
import sys
import os
# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model import TripleGRUDecoder, NoiseModel, CleanSpeechModel, NoisySpeechModel
def test_individual_models():
    """Smoke-test NoiseModel, CleanSpeechModel and NoisySpeechModel one at a time.

    A random batch is built, each sub-model is run forward under
    ``torch.no_grad()`` and the resulting tensor shapes are printed.
    NoisySpeechModel is fed the NoiseModel output rather than the raw input.
    No shapes are asserted: the function only fails if a constructor or a
    forward pass raises.

    Returns:
        bool: always ``True``.
    """
    print("=== Testing Individual Models ===")

    # Hyper-parameters shared by every sub-model under test.
    neural_dim, n_units, n_days, n_classes = 512, 768, 5, 41
    batch_size, seq_len = 4, 100

    # Synthetic features plus one random day index per batch element.
    features = torch.randn(batch_size, seq_len, neural_dim)
    days = torch.randint(0, n_days, (batch_size,))
    print(f"Input shape: {features.shape}")
    print(f"Day indices: {days}", end="\n\n")

    # 1) Noise estimator: its forward pass yields a pair (output, hidden).
    print("1. Testing NoiseModel...")
    estimator = NoiseModel(neural_dim, n_units, n_days)
    with torch.no_grad():
        estimated_noise, estimator_state = estimator(features, days)
    print(f" Noise output shape: {estimated_noise.shape}")
    print(f" Noise hidden shape: {estimator_state.shape}")
    print(" ✓ NoiseModel working", end="\n\n")

    # 2) Clean-speech recogniser on the raw synthetic input.
    print("2. Testing CleanSpeechModel...")
    clean_head = CleanSpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        clean_scores = clean_head(features, days)
    print(f" Clean logits shape: {clean_scores.shape}")
    print(" ✓ CleanSpeechModel working", end="\n\n")

    # 3) Noisy-speech recogniser, driven by the estimator's output only.
    print("3. Testing NoisySpeechModel...")
    noisy_head = NoisySpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        noisy_scores = noisy_head(estimated_noise)  # Use noise output as input
    print(f" Noisy logits shape: {noisy_scores.shape}")
    print(" ✓ NoisySpeechModel working", end="\n\n")

    return True
def test_triple_gru_architecture():
    """Exercise the complete TripleGRUDecoder end to end.

    Runs, in order:

    1. A ``mode='full'`` forward pass. Asserts that the clean/noisy logits are
       ``(batch_size, seq_len, n_classes)`` and the noise output is
       ``(batch_size, seq_len, neural_dim)``.
    2. A ``mode='inference'`` forward pass. Asserts the logits are
       ``(batch_size, seq_len, n_classes)``.
    3. A second decoder with patching (``patch_size=14``,
       ``patch_stride=4``). Asserts the clean logits' time axis has
       ``(seq_len - patch_size) // patch_stride + 1`` steps.
    4. Gradient flow. A cross-entropy loss is back-propagated from each head,
       and the test asserts that the clean and noisy speech models received
       gradients. The noise model's gradient status is printed but not
       asserted, because whether it is reached depends on how the decoder
       routes gradients to it.

    Returns:
        bool: ``True``. Any failed check raises ``AssertionError``.
    """
    print("=== Triple-GRU Architecture Test ===")

    # Model parameters
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100
    # Patching configuration used by the second decoder in step 3.
    patch_size = 14
    patch_stride = 4

    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print(f"Days: {n_days}")
    print(f"Classes: {n_classes}")
    print()

    # Create model instance
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=0,  # Start without patching
        patch_stride=0,
    )
    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))
    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Expected logits shape without patching; torch.Size compares equal to a tuple.
    logits_shape = (batch_size, seq_len, n_classes)
    noise_shape = (batch_size, seq_len, neural_dim)

    # Test full training mode
    print("Testing full training mode...")
    with torch.no_grad():
        outputs = model(x, day_idx, mode='full')
    clean_shape = outputs['clean_logits'].shape
    noisy_shape = outputs['noisy_logits'].shape
    noise_out_shape = outputs['noise_output'].shape
    print(f"Clean logits shape: {clean_shape}")
    print(f"Noisy logits shape: {noisy_shape}")
    print(f"Noise output shape: {noise_out_shape}")
    # Verify shapes, reporting the actual value on mismatch.
    assert clean_shape == logits_shape, (
        f"Clean logits shape mismatch: {tuple(clean_shape)} != {logits_shape}"
    )
    assert noisy_shape == logits_shape, (
        f"Noisy logits shape mismatch: {tuple(noisy_shape)} != {logits_shape}"
    )
    assert noise_out_shape == noise_shape, (
        f"Noise output shape mismatch: {tuple(noise_out_shape)} != {noise_shape}"
    )
    print("✓ Full training mode successful")
    print()

    # Test inference mode
    print("Testing inference mode...")
    with torch.no_grad():
        logits = model(x, day_idx, mode='inference')
    print(f"Inference logits shape: {logits.shape}")
    assert logits.shape == logits_shape, (
        f"Inference logits shape mismatch: {tuple(logits.shape)} != {logits_shape}"
    )
    print("✓ Inference mode successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    with torch.no_grad():
        outputs_patches = model_with_patches(x, day_idx, mode='full')
    # One output step per sliding window of `patch_size` frames moved by `patch_stride`.
    expected_patches = (seq_len - patch_size) // patch_stride + 1
    patched_clean_shape = outputs_patches['clean_logits'].shape
    print("Output shapes with patches:")
    print(f" Clean logits: {patched_clean_shape}")
    print(f" Expected patches: {expected_patches}")
    # Check the time axis really matches the window count instead of only printing it.
    expected_patched_shape = (batch_size, expected_patches, n_classes)
    assert patched_clean_shape == expected_patched_shape, (
        f"Patched clean logits shape mismatch: "
        f"{tuple(patched_clean_shape)} != {expected_patched_shape}"
    )
    print("✓ Patch processing successful")
    print()

    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)
    # Forward pass in training mode
    outputs = model(x_grad, day_idx, mode='full')
    # Calculate losses against random frame-level targets
    target = torch.randint(0, n_classes, (batch_size, seq_len))
    clean_loss = torch.nn.functional.cross_entropy(
        outputs['clean_logits'].reshape(-1, n_classes),
        target.reshape(-1),
    )
    noisy_loss = torch.nn.functional.cross_entropy(
        outputs['noisy_logits'].reshape(-1, n_classes),
        target.reshape(-1),
    )
    # Both losses come from the same forward pass, so keep the graph alive
    # for the second backward call.
    clean_loss.backward(retain_graph=True)
    noisy_loss.backward()

    # Check gradients: a sub-model "has gradients" if any parameter received one.
    clean_has_grad = any(p.grad is not None for p in model.clean_speech_model.parameters())
    noisy_has_grad = any(p.grad is not None for p in model.noisy_speech_model.parameters())
    noise_has_grad = any(p.grad is not None for p in model.noise_model.parameters())
    print(f"Clean model gradients: {'✓' if clean_has_grad else '✗'}")
    print(f"Noisy model gradients: {'✓' if noisy_has_grad else '✗'}")
    print(f"Noise model gradients: {'✓' if noise_has_grad else '✗'}")
    # Each head produced the logits of its own loss, so it must have gradients.
    assert clean_has_grad, "Clean speech model received no gradients"
    assert noisy_has_grad, "Noisy speech model received no gradients"
    print("✓ Gradient flow test successful")
    print()

    return True
def test_adversarial_training_simulation():
    """Smoke-test ``TripleGRUDecoder.apply_gradient_combination``.

    Two random ``(41, 768)`` tensors are passed as stand-in "clean" and "noisy"
    gradients with ``learning_rate=1e-3``, and the function reports whether the
    call completed.

    Returns:
        bool: ``True`` if ``apply_gradient_combination`` ran without raising.
        ``False`` if it raised. The error is printed rather than propagated so
        that the remaining tests still run, but the return value lets the
        caller detect the failure.
    """
    print("=== Adversarial Training Simulation ===")
    # Simple test to verify gradient combination logic; same sizes as the other tests.
    model = TripleGRUDecoder(
        neural_dim=512,
        n_units=768,
        n_days=5,
        n_classes=41,
        rnn_dropout=0.1,
    )

    # Create fake gradients. Shape is (n_classes, n_units); presumably this
    # matches an output-layer weight. TODO confirm against rnn_model.
    fake_clean_grad = torch.randn(41, 768)
    fake_noisy_grad = torch.randn(41, 768)

    print("Testing gradient combination...")
    combination_ok = True
    try:
        model.apply_gradient_combination(fake_clean_grad, fake_noisy_grad, learning_rate=1e-3)
    except Exception as exc:  # broad on purpose: report any failure and keep the suite running
        print(f"✗ Gradient combination failed: {exc}")
        combination_ok = False
    else:
        print("✓ Gradient combination mechanism working")
    print()
    return combination_ok
if __name__ == "__main__":
    print("Starting comprehensive tests for Triple-GRU architecture...\n")

    # Run all tests. Assertion errors propagate and abort the script; a test
    # that catches its own errors signals failure by returning a falsy value,
    # which is collected here so the exit status reflects it.
    all_tests = (
        test_individual_models,
        test_triple_gru_architecture,
        test_adversarial_training_simulation,
    )
    failed_tests = []
    for test_fn in all_tests:
        if not test_fn():
            failed_tests.append(test_fn.__name__)

    print("=== All tests completed! ===")
    print()

    # Print architecture summary
    print("=== Triple-GRU Architecture Summary ===")
    print("Training Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Clean Recognition")
    print("3. Noise → NoisySpeechModel → Noisy Recognition")
    print("4. Gradient Combination: -Clean_grad + Noisy_grad → NoiseModel")
    print()
    print("Inference Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Final Recognition")
    print()

    # Non-zero exit status so CI notices failures that were reported but not raised.
    if failed_tests:
        print(f"=== FAILED: {', '.join(failed_tests)} ===")
        sys.exit(1)