final version? maybe

Zchen
2025-10-12 23:36:16 +08:00
parent 6cfc568f9a
commit 0d2a0aa8fa
5 changed files with 375 additions and 51 deletions

View File

@@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""
Quick XLA Model Test
"""
import torch
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder


def quick_test():
    print("Quick XLA model test...")

    # Small model for fast testing
    model = TripleGRUDecoder(
        neural_dim=64,   # Smaller
        n_units=128,     # Smaller
        n_days=3,        # Smaller
        n_classes=10,    # Smaller
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=4,    # Smaller
        patch_stride=1
    )
    model.eval()

    # Small test data
    batch_size, seq_len = 2, 20
    features = torch.randn(batch_size, seq_len, 64)
    day_indices = torch.tensor([0, 1])

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    # Test inference mode
    with torch.no_grad():
        result = model(features, day_indices, mode='inference')
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Test full mode
        clean, noisy, noise = model(features, day_indices, mode='full')
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")


if __name__ == "__main__":
    quick_test()
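
The quick test above runs on CPU. To confirm the model actually compiles under XLA, a smoke test along these lines could be used; this is a sketch only, assuming torch_xla is installed (xm.xla_device() and xm.mark_step() are standard torch_xla calls, not part of this commit):

# Hypothetical XLA smoke test -- assumes torch_xla is installed (e.g. on a TPU VM).
# Sketch only, not part of this commit.
import torch
import torch_xla.core.xla_model as xm

from rnn_model import TripleGRUDecoder

device = xm.xla_device()  # XLA device handle
model = TripleGRUDecoder(neural_dim=64, n_units=128, n_days=3, n_classes=10,
                         rnn_dropout=0.0, input_dropout=0.0,
                         patch_size=4, patch_stride=1).to(device).eval()

features = torch.randn(2, 20, 64, device=device)
day_indices = torch.tensor([0, 1], device=device)

with torch.no_grad():
    logits = model(features, day_indices, mode='inference')
    xm.mark_step()  # force XLA to compile and execute the pending graph
print(logits.shape)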

View File

@@ -56,28 +56,37 @@ class NoiseModel(nn.Module):
         self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.input_size)))
 
     def forward(self, x, day_idx, states=None):
-        # Apply day-specific transformation
-        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
-
-        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        batch_size = x.size(0)
+
+        # Stack all day weights and biases upfront for static indexing
+        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x = torch.bmm(x, day_weights) + day_biases
         x = self.day_layer_activation(x)
 
+        # XLA-friendly conditional dropout
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled (keep conditional for now, optimize later)
         if self.patch_size > 0:
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
+            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization - avoid dynamic allocation
         if states is None:
-            states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()
+            states = self.h0.expand(2, batch_size, self.input_size).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
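
The rewrite above is behavior-preserving: torch.einsum("btd,bdk->btk", x, w) on per-sample stacked weights equals torch.bmm on index_select-gathered weights, but the new form avoids the Python-level loop over day_idx that XLA cannot trace statically. A minimal standalone check (illustrative sizes, not from the repo):

# Minimal equivalence check for the gather/bmm rewrite (illustrative sizes).
import torch

n_days, neural_dim, batch, T = 5, 8, 3, 11
weights = [torch.randn(neural_dim, neural_dim) for _ in range(n_days)]
biases = [torch.randn(1, neural_dim) for _ in range(n_days)]
x = torch.randn(batch, T, neural_dim)
day_idx = torch.tensor([0, 3, 1])

# Old path: per-sample Python indexing + einsum
w_old = torch.stack([weights[int(i)] for i in day_idx], dim=0)
b_old = torch.cat([biases[int(i)] for i in day_idx], dim=0).unsqueeze(1)
out_old = torch.einsum("btd,bdk->btk", x, w_old) + b_old

# New path: static stack + index_select + bmm
all_w = torch.stack(weights, dim=0)                         # [n_days, D, D]
all_b = torch.stack([b.squeeze(0) for b in biases], dim=0)  # [n_days, D]
w_new = torch.index_select(all_w, 0, day_idx)               # [batch, D, D]
b_new = torch.index_select(all_b, 0, day_idx).unsqueeze(1)  # [batch, 1, D]
out_new = torch.bmm(x, w_new) + b_new

assert torch.allclose(out_old, out_new, atol=1e-5)
print("gather/bmm path matches einsum path")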
@@ -146,11 +155,19 @@ class CleanSpeechModel(nn.Module):
         self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
 
     def forward(self, x, day_idx, states=None, return_state=False):
-        # Apply day-specific transformation
-        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
-
-        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        batch_size = x.size(0)
+
+        # Stack all day weights and biases upfront for static indexing
+        all_day_weights = torch.stack(list(self.day_weights), dim=0)  # [n_days, neural_dim, neural_dim]
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)  # [n_days, neural_dim]
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)  # [batch_size, neural_dim, neural_dim]
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)  # [batch_size, 1, neural_dim]
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x = torch.bmm(x, day_weights) + day_biases
         x = self.day_layer_activation(x)
 
         if self.input_dropout > 0:
@@ -163,11 +180,11 @@ class CleanSpeechModel(nn.Module):
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
+            x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization
        if states is None:
-            states = self.h0.expand(3, x.shape[0], self.n_units).contiguous()
+            states = self.h0.expand(3, batch_size, self.n_units).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
@@ -235,10 +252,11 @@ class NoisySpeechModel(nn.Module):
     def forward(self, x, states=None, return_state=False):
         # Note: NoisySpeechModel doesn't need day-specific layers as it processes noise
+        batch_size = x.size(0)
 
-        # Initialize hidden states
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.h0.expand(2, x.shape[0], self.n_units).contiguous()
+            states = self.h0.expand(2, batch_size, self.n_units).contiguous()
 
         # GRU forward pass
         output, hidden_states = self.gru(x, states)
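
All three hidden-state changes follow the same pattern: the learned initial state h0 of shape [1, 1, H] is expanded to [num_layers, batch, H] from a precomputed batch_size, so no shape inside the graph depends on x.shape. A self-contained sketch of the pattern (class and sizes are illustrative, not from the repo):

# Sketch of the learned-initial-state pattern shared by the three models (illustrative).
import torch
import torch.nn as nn

class TinyGRU(nn.Module):
    def __init__(self, input_size=16, hidden_size=32, num_layers=2):
        super().__init__()
        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # Learned initial hidden state, broadcast over layers and batch
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, hidden_size)))

    def forward(self, x, states=None):
        batch_size = x.size(0)
        if states is None:
            # expand() creates a view; contiguous() materializes the layout nn.GRU expects
            states = self.h0.expand(self.num_layers, batch_size, self.hidden_size).contiguous()
        return self.gru(x, states)

out, h_n = TinyGRU()(torch.randn(4, 10, 16))
print(out.shape, h_n.shape)  # torch.Size([4, 10, 32]) torch.Size([2, 4, 32])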
@@ -329,30 +347,39 @@ class TripleGRUDecoder(nn.Module):
         self.training_mode = 'full'  # 'full', 'inference'
 
     def _apply_preprocessing(self, x, day_idx):
-        '''Apply day-specific transformation and patch processing to match what models expect'''
-        # Apply day-specific transformation (same as in each model)
-        day_weights = torch.stack([self.clean_speech_model.day_weights[i] for i in day_idx], dim=0)
-        day_biases = torch.cat([self.clean_speech_model.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
-
-        x_processed = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
+        '''XLA-friendly preprocessing with static operations'''
+        batch_size = x.size(0)
+
+        # XLA-friendly day-specific transformation using gather instead of dynamic indexing
+        all_day_weights = torch.stack(list(self.clean_speech_model.day_weights), dim=0)
+        all_day_biases = torch.stack([bias.squeeze(0) for bias in self.clean_speech_model.day_biases], dim=0)
+
+        # XLA-friendly gather operation
+        day_weights = torch.index_select(all_day_weights, 0, day_idx)
+        day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)
+
+        # Use bmm (batch matrix multiply) which is highly optimized in XLA
+        x_processed = torch.bmm(x, day_weights) + day_biases
         x_processed = self.clean_speech_model.day_layer_activation(x_processed)
 
-        # Apply patch processing if enabled (same as in each model)
+        # Apply patch processing if enabled
         if self.patch_size > 0:
             x_processed = x_processed.unsqueeze(1)
             x_processed = x_processed.permute(0, 3, 1, 2)
             x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
-            x_processed = x_unfold.reshape(x_processed.size(0), x_unfold.size(1), -1)
+            x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
 
         return x_processed
 
     def _clean_forward_with_processed_input(self, x_processed, day_idx, states=None):
         '''Forward pass for CleanSpeechModel with already processed input (bypasses day layers and patching)'''
-        # Initialize hidden states
+        batch_size = x_processed.size(0)
+
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.clean_speech_model.h0.expand(3, x_processed.shape[0], self.clean_speech_model.n_units).contiguous()
+            states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
 
         # GRU forward pass (skip preprocessing since input is already processed)
         output, hidden_states = self.clean_speech_model.gru(x_processed, states)
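
The patch-processing block repeated in these hunks is easiest to read as a shape trace: [B, T, D] is lifted to [B, D, 1, T], unfolded into sliding windows, and flattened so each output step carries patch_size * D features, giving T' = (T - patch_size) // patch_stride + 1 steps. A standalone trace with illustrative sizes:

# Shape trace of the unfold-based patch processing (illustrative sizes).
import torch

B, T, D, patch_size, patch_stride = 2, 20, 6, 4, 1
x = torch.randn(B, T, D)                             # [B, T, D]
x = x.unsqueeze(1)                                   # [B, 1, T, D]
x = x.permute(0, 3, 1, 2)                            # [B, D, 1, T]
x_unfold = x.unfold(3, patch_size, patch_stride)     # [B, D, 1, T', patch_size]
x_unfold = x_unfold.squeeze(2)                       # [B, D, T', patch_size]
x_unfold = x_unfold.permute(0, 2, 3, 1)              # [B, T', patch_size, D]
x = x_unfold.reshape(B, x_unfold.size(1), -1)        # [B, T', patch_size * D]
print(x.shape)  # T' = (20 - 4) // 1 + 1 = 17, so torch.Size([2, 17, 24])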
@@ -363,9 +390,11 @@ class TripleGRUDecoder(nn.Module):
     def _noisy_forward_with_processed_input(self, x_processed, states=None):
         '''Forward pass for NoisySpeechModel with already processed input'''
-        # Initialize hidden states
+        batch_size = x_processed.size(0)
+
+        # XLA-friendly hidden state initialization
         if states is None:
-            states = self.noisy_speech_model.h0.expand(2, x_processed.shape[0], self.noisy_speech_model.n_units).contiguous()
+            states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
 
         # GRU forward pass (NoisySpeechModel doesn't have day layers anyway)
         output, hidden_states = self.noisy_speech_model.gru(x_processed, states)
@@ -407,23 +436,10 @@ class TripleGRUDecoder(nn.Module):
             noisy_logits = self._noisy_forward_with_processed_input(noise_output,
                                                                     states['noisy'] if states else None)
 
+            # XLA-friendly return - use tuple instead of dict for better compilation
             if return_state:
-                return_states = {
-                    'noise': noise_hidden,
-                    'clean': None,  # CleanSpeechModel doesn't return hidden states in this call
-                    'noisy': None   # NoisySpeechModel doesn't return hidden states in this call
-                }
-                return {
-                    'clean_logits': clean_logits,
-                    'noisy_logits': noisy_logits,
-                    'noise_output': noise_output
-                }, return_states
-            return {
-                'clean_logits': clean_logits,
-                'noisy_logits': noisy_logits,
-                'noise_output': noise_output
-            }
+                return (clean_logits, noisy_logits, noise_output), noise_hidden
+            return clean_logits, noisy_logits, noise_output
 
         elif mode == 'inference':
             # Inference mode: only noise model + clean speech model
@@ -440,13 +456,9 @@ class TripleGRUDecoder(nn.Module):
             clean_logits = self._clean_forward_with_processed_input(denoised_input, day_idx,
                                                                     states['clean'] if states else None)
 
+            # XLA-friendly return - use tuple for consistency
             if return_state:
-                return_states = {
-                    'noise': noise_hidden,
-                    'clean': None
-                }
-                return clean_logits, return_states
+                return clean_logits, noise_hidden
             return clean_logits
 
         else:
View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""
XLA Model Verification Script
Verifies that the XLA-optimized model's outputs remain consistent with the original model
"""
import torch
import torch.nn as nn
import sys
import os

# Add the model training directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder


def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10):
    """Create synthetic test data matching expected model inputs"""
    # Create random neural features
    features = torch.randn(batch_size, seq_len, neural_dim)
    # Create random day indices (should be valid indices < n_days)
    day_indices = torch.randint(0, n_days, (batch_size,))
    return features, day_indices


def test_model_consistency():
    """Test that XLA-optimized model produces consistent outputs"""
    print("Testing XLA-optimized TripleGRUDecoder consistency...")

    # Model parameters (matching typical configuration)
    neural_dim = 512
    n_units = 768
    n_days = 10
    n_classes = 40  # Typical phoneme count
    batch_size = 4
    seq_len = 100
    patch_size = 14
    patch_stride = 1

    # Create model
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,   # Disable dropout for consistent testing
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride
    )

    # Set to eval mode for consistent results
    model.eval()

    # Create test data
    features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days)

    print("Test data shapes:")
    print(f"  Features: {features.shape}")
    print(f"  Day indices: {day_indices.shape}")
    print(f"  Day indices values: {day_indices.tolist()}")

    # Test inference mode (most commonly used)
    print("\n=== Testing Inference Mode ===")
    with torch.no_grad():
        try:
            # Run inference mode
            clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference')
            print(f"Clean logits shape: {clean_logits.shape}")
            print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]")
            print("✓ Inference mode successful")

            # Test with return_state=True
            clean_logits_with_state, noise_hidden = model(features, day_indices, states=None, return_state=True, mode='inference')

            # Verify consistency
            assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent outputs with/without return_state"
            print("✓ return_state consistency verified")
        except Exception as e:
            print(f"✗ Inference mode failed: {e}")
            raise

    # Test full mode (training)
    print("\n=== Testing Full Mode ===")
    with torch.no_grad():
        try:
            # Run full mode
            clean_logits, noisy_logits, noise_output = model(features, day_indices, states=None, return_state=False, mode='full')
            print(f"Clean logits shape: {clean_logits.shape}")
            print(f"Noisy logits shape: {noisy_logits.shape}")
            print(f"Noise output shape: {noise_output.shape}")
            print("✓ Full mode successful")

            # Test with return_state=True
            (clean_logits_with_state, noisy_logits_with_state, noise_output_with_state), noise_hidden = model(
                features, day_indices, states=None, return_state=True, mode='full')

            # Verify consistency
            assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent clean logits"
            assert torch.allclose(noisy_logits, noisy_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noisy logits"
            assert torch.allclose(noise_output, noise_output_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noise output"
            print("✓ return_state consistency verified")
        except Exception as e:
            print(f"✗ Full mode failed: {e}")
            raise

    # Test multiple runs for consistency
    print("\n=== Testing Multiple Run Consistency ===")
    with torch.no_grad():
        try:
            # Run same input multiple times
            results = []
            for i in range(3):
                result = model(features, day_indices, states=None, return_state=False, mode='inference')
                results.append(result)

            # Verify all runs produce identical results
            for i in range(1, len(results)):
                assert torch.allclose(results[0], results[i], rtol=1e-7, atol=1e-8), f"Inconsistent results between runs 0 and {i}"
            print("✓ Multiple runs produce identical results")
        except Exception as e:
            print(f"✗ Multiple run consistency failed: {e}")
            raise

    # Test different batch sizes
    print("\n=== Testing Different Batch Sizes ===")
    with torch.no_grad():
        try:
            for test_batch_size in [1, 2, 8]:
                test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days)
                result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference')
                expected_shape = (test_batch_size, (seq_len - patch_size) // patch_stride + 1, n_classes)
                assert result.shape == expected_shape, f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}"
                print(f"✓ Batch size {test_batch_size}: {result.shape}")
        except Exception as e:
            print(f"✗ Batch size testing failed: {e}")
            raise

    print("\n🎉 All tests passed! XLA-optimized model is working correctly.")
    return True


if __name__ == "__main__":
    test_model_consistency()
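
For the default parameters used here, the expected-shape assertion works out to (seq_len - patch_size) // patch_stride + 1 = (100 - 14) // 1 + 1 = 87 time steps, so the inference result for the default batch should be [4, 87, 40].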