Files
b2txt25/model_training_nnn_tpu/test_tensorflow_implementation.py
Zchen e8f0308fef tpu
2025-10-15 20:45:25 +08:00

564 lines
20 KiB
Python

#!/usr/bin/env python3
"""
Test Script for TensorFlow Brain-to-Text Implementation
Validates model architecture, data pipeline, and training functionality
Usage:
python test_tensorflow_implementation.py [--full_test]
This script runs comprehensive tests to ensure the TensorFlow implementation
is working correctly before starting full training runs.
"""
import os
import sys
import argparse
import numpy as np
import tensorflow as tf
from omegaconf import OmegaConf
import tempfile
import shutil
# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model_tf import (
TripleGRUDecoder,
NoiseModel,
CleanSpeechModel,
NoisySpeechModel,
CTCLoss,
create_tpu_strategy,
configure_mixed_precision
)
from dataset_tf import BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices
from trainer_tf import BrainToTextDecoderTrainerTF
class TensorFlowImplementationTester:
    """Comprehensive tester for TensorFlow brain-to-text implementation.

    Each ``test_*`` method runs one group of checks and records every result
    through :meth:`log_test`; :meth:`run_all_tests` runs them all and prints a
    summary. Exceptions raised inside a check are caught and recorded as
    failures, so one broken component does not abort the whole suite.
    """

    def __init__(self, use_tpu: bool = False, verbose: bool = True):
        """Initialize tester.

        Args:
            use_tpu: If True, build a TPU strategy via ``create_tpu_strategy``;
                otherwise use TensorFlow's current default strategy.
            verbose: If True, print every test result (failures are printed
                regardless).
        """
        self.use_tpu = use_tpu
        self.verbose = verbose
        self.passed_tests = 0
        self.total_tests = 0
        # Overwritten by run_all_tests(); defined here so test_full_training_loop
        # can also be called on its own.
        self._full_test = False
        # Initialize strategy first: if it fails, no temporary directories have
        # been allocated yet, so nothing leaks.
        if use_tpu:
            self.strategy = create_tpu_strategy()
            if self.verbose:
                print(f"Using TPU strategy with {self.strategy.num_replicas_in_sync} cores")
        else:
            self.strategy = tf.distribute.get_strategy()
            if self.verbose:
                print("Using default strategy (CPU/GPU)")
        # Create test configuration. This allocates two temporary directories
        # (output_dir / checkpoint_dir) that the caller must remove.
        self.config = self._create_test_config()

    def _create_test_config(self):
        """Create minimal test configuration.

        Returns:
            A plain dict mirroring the training config layout with small sizes.
            ``output_dir`` and ``checkpoint_dir`` are freshly created temporary
            directories; removing them is the caller's responsibility.
        """
        return {
            'model': {
                'n_input_features': 512,
                'n_units': 64,  # Smaller for testing
                'rnn_dropout': 0.1,
                'patch_size': 4,
                'patch_stride': 2,
                'input_network': {
                    'input_layer_dropout': 0.1
                }
            },
            'dataset': {
                'sessions': ['test_session_1', 'test_session_2'],
                'n_classes': 41,
                'batch_size': 4,
                'days_per_batch': 2,
                'seed': 42,
                'data_transforms': {
                    'white_noise_std': 0.1,
                    'constant_offset_std': 0.05,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 2,
                    'smooth_data': True,
                    'smooth_kernel_std': 1.0,
                    'smooth_kernel_size': 50
                }
            },
            'num_training_batches': 10,
            'lr_max': 0.001,
            'lr_min': 0.0001,
            'lr_decay_steps': 100,
            'lr_warmup_steps': 5,
            'lr_scheduler_type': 'cosine',
            'beta0': 0.9,
            'beta1': 0.999,
            'epsilon': 1e-7,
            'weight_decay': 0.001,
            'seed': 42,
            'grad_norm_clip_value': 1.0,
            'batches_per_train_log': 2,
            'batches_per_val_step': 5,
            'output_dir': tempfile.mkdtemp(),
            'checkpoint_dir': tempfile.mkdtemp(),
            'mode': 'train',
            'use_amp': False,  # Disable for testing
            'adversarial': {
                'enabled': True,
                'grl_lambda': 0.5,
                'noisy_loss_weight': 0.2,
                'noise_l2_weight': 0.001,
                'warmup_steps': 2
            }
        }

    @staticmethod
    def _expect(condition, message: str) -> None:
        """Raise ``AssertionError(message)`` unless *condition* is truthy.

        Used instead of bare ``assert``. Bare asserts are stripped under
        ``python -O``, which would make every check pass. They also produce an
        empty error string, leaving FAIL lines without details. *condition* may
        be a Python bool or a scalar boolean tensor (coerced with ``bool``).
        """
        if not bool(condition):
            raise AssertionError(message)

    @staticmethod
    def _format_error(error: Exception) -> str:
        """Describe an exception for the log as ``"<TypeName>: <message>"``."""
        return f"{type(error).__name__}: {error}"

    def log_test(self, test_name: str, passed: bool, details: str = ""):
        """Record one test result and print it.

        Args:
            test_name: Human-readable name of the check.
            passed: Whether the check succeeded; coerced with ``bool`` so a
                scalar boolean tensor is accepted as well.
            details: Optional extra line printed under the result.

        Passing results are printed only in verbose mode. Failures are always
        printed, so a ``--quiet`` run still shows what went wrong.
        """
        passed = bool(passed)
        self.total_tests += 1
        if passed:
            self.passed_tests += 1
            status = "PASS"
        else:
            status = "FAIL"
        if self.verbose or not passed:
            print(f"[{status}] {test_name}")
            if details:
                print(f" {details}")

    def test_model_architecture(self):
        """Test individual model components.

        Runs forward passes through NoiseModel, CleanSpeechModel and
        NoisySpeechModel on a small random batch and checks the output shapes.
        Each component is logged as a separate test.
        """
        print("\n=== Testing Model Architecture ===")
        batch_size = 2
        time_steps = 20
        neural_dim = 512
        patch_size = 4
        patch_stride = 2
        n_classes = 41
        # Expected shapes after patching, as asserted below:
        # (T - patch_size) // patch_stride + 1 steps, neural_dim * patch_size features.
        expected_time_steps = (time_steps - patch_size) // patch_stride + 1
        expected_features = neural_dim * patch_size
        with self.strategy.scope():
            # Shared inputs are built once, outside the per-model try blocks, so a
            # failure in one component test cannot cascade into NameErrors later.
            try:
                x = tf.random.normal((batch_size, time_steps, neural_dim))
                day_idx = tf.constant([0, 1], dtype=tf.int32)
            except Exception as e:
                self.log_test("Model architecture inputs", False, self._format_error(e))
                return

            # Test NoiseModel
            try:
                noise_model = NoiseModel(
                    neural_dim=neural_dim,
                    n_units=64,
                    n_days=2,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=patch_size,
                    patch_stride=patch_stride
                )
                output, states = noise_model(x, day_idx, training=False)
                self._expect(
                    output.shape == (batch_size, expected_time_steps, expected_features),
                    f"Unexpected NoiseModel output shape: {output.shape}"
                )
                # Two GRU layers
                self._expect(len(states) == 2, f"Expected 2 GRU states, got {len(states)}")
                self.log_test("NoiseModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("NoiseModel forward pass", False, self._format_error(e))

            # Test CleanSpeechModel
            try:
                clean_model = CleanSpeechModel(
                    neural_dim=neural_dim,
                    n_units=64,
                    n_days=2,
                    n_classes=n_classes,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=patch_size,
                    patch_stride=patch_stride
                )
                output = clean_model(x, day_idx, training=False)
                self._expect(
                    output.shape == (batch_size, expected_time_steps, n_classes),
                    f"Unexpected CleanSpeechModel output shape: {output.shape}"
                )
                self.log_test("CleanSpeechModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("CleanSpeechModel forward pass", False, self._format_error(e))

            # Test NoisySpeechModel
            try:
                noisy_model = NoisySpeechModel(
                    neural_dim=expected_features,  # Takes processed input
                    n_units=64,
                    n_days=2,
                    n_classes=n_classes,
                    rnn_dropout=0.1
                )
                # Use processed input (same shape as the NoiseModel output)
                processed_input = tf.random.normal(
                    (batch_size, expected_time_steps, expected_features)
                )
                output = noisy_model(processed_input, training=False)
                self._expect(
                    output.shape == (batch_size, expected_time_steps, n_classes),
                    f"Unexpected NoisySpeechModel output shape: {output.shape}"
                )
                self.log_test("NoisySpeechModel forward pass", True,
                              f"Output shape: {output.shape}")
            except Exception as e:
                self.log_test("NoisySpeechModel forward pass", False, self._format_error(e))

    def test_triple_gru_decoder(self):
        """Test the complete TripleGRUDecoder in 'inference' and 'full' modes."""
        print("\n=== Testing TripleGRUDecoder ===")
        with self.strategy.scope():
            try:
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=64,
                    n_days=2,
                    n_classes=41,
                    rnn_dropout=0.1,
                    input_dropout=0.1,
                    patch_size=4,
                    patch_stride=2
                )
                batch_size = 2
                time_steps = 20
                x = tf.random.normal((batch_size, time_steps, 512))
                day_idx = tf.constant([0, 1], dtype=tf.int32)
                expected_time_steps = (time_steps - 4) // 2 + 1

                # Test inference mode (returns only the clean logits)
                clean_logits = model(x, day_idx, mode='inference', training=False)
                self._expect(
                    clean_logits.shape == (batch_size, expected_time_steps, 41),
                    f"Unexpected inference logits shape: {clean_logits.shape}"
                )
                self.log_test("TripleGRUDecoder inference mode", True,
                              f"Output shape: {clean_logits.shape}")

                # Test full mode (adversarial training): clean, noisy and noise outputs
                clean_logits, noisy_logits, noise_output = model(
                    x, day_idx, mode='full', grl_lambda=0.5, training=True
                )
                self._expect(
                    clean_logits.shape == (batch_size, expected_time_steps, 41),
                    f"Unexpected clean logits shape: {clean_logits.shape}"
                )
                self._expect(
                    noisy_logits.shape == (batch_size, expected_time_steps, 41),
                    f"Unexpected noisy logits shape: {noisy_logits.shape}"
                )
                self._expect(
                    noise_output.shape[0] == batch_size,
                    f"Unexpected noise output batch size: {noise_output.shape}"
                )
                self.log_test("TripleGRUDecoder full mode", True,
                              f"Clean: {clean_logits.shape}, Noisy: {noisy_logits.shape}")
            except Exception as e:
                self.log_test("TripleGRUDecoder", False, self._format_error(e))

    def test_ctc_loss(self):
        """Test that CTCLoss returns one finite loss value per batch element."""
        print("\n=== Testing CTC Loss ===")
        try:
            ctc_loss = CTCLoss(blank_index=0, reduction='none')
            batch_size = 2
            time_steps = 10
            n_classes = 41
            # Create test data; labels are zero-padded and label_lengths give
            # the number of real labels in each row.
            logits = tf.random.normal((batch_size, time_steps, n_classes))
            labels = tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32)
            input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
            label_lengths = tf.constant([3, 2], dtype=tf.int32)
            loss_input = {
                'labels': labels,
                'input_lengths': input_lengths,
                'label_lengths': label_lengths
            }
            loss = ctc_loss(loss_input, logits)
            self._expect(loss.shape == (batch_size,), f"Unexpected loss shape: {loss.shape}")
            all_finite = bool(tf.reduce_all(tf.math.is_finite(loss)))
            self._expect(all_finite, f"Non-finite CTC loss values: {loss.numpy()}")
            self.log_test("CTC Loss computation", True,
                          f"Loss shape: {loss.shape}, values finite: {all_finite}")
        except Exception as e:
            self.log_test("CTC Loss computation", False, self._format_error(e))

    def test_data_augmentation(self):
        """Test Gaussian smoothing and the full DataAugmentationTF transform pipeline."""
        print("\n=== Testing Data Augmentation ===")
        try:
            batch_size = 2
            time_steps = 100
            features = 512
            x = tf.random.normal((batch_size, time_steps, features))
            n_time_steps = tf.constant([time_steps, time_steps], dtype=tf.int32)

            # Test Gaussian smoothing (must preserve the input shape)
            smoothed = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0)
            self._expect(smoothed.shape == x.shape,
                         f"Smoothing changed shape: {x.shape} -> {smoothed.shape}")
            self.log_test("Gaussian smoothing", True,
                          f"Input: {x.shape}, Output: {smoothed.shape}")

            # Test full transform pipeline with the transforms from the test config
            transform_args = self.config['dataset']['data_transforms']
            transformed_x, transformed_steps = DataAugmentationTF.transform_data(
                x, n_time_steps, transform_args, training=True
            )
            # Only batch and feature dims are checked; the time dim is left
            # unchecked (transforms such as random_cut may presumably shorten it).
            self._expect(transformed_x.shape[0] == batch_size,
                         f"Unexpected batch dim: {transformed_x.shape}")
            self._expect(transformed_x.shape[2] == features,
                         f"Unexpected feature dim: {transformed_x.shape}")
            self._expect(len(transformed_steps) == batch_size,
                         f"Expected {batch_size} lengths, got {len(transformed_steps)}")
            self.log_test("Data augmentation pipeline", True,
                          f"Original: {x.shape}, Transformed: {transformed_x.shape}")
        except Exception as e:
            self.log_test("Data augmentation", False, self._format_error(e))

    def test_gradient_reversal(self):
        """Test that gradient_reverse is the identity forward and scales gradients by -lambd."""
        print("\n=== Testing Gradient Reversal ===")
        try:
            from rnn_model_tf import gradient_reverse
            x = tf.random.normal((2, 10, 64))

            # Forward pass should be the identity
            y = gradient_reverse(x, lambd=0.5)
            self._expect(tf.reduce_all(tf.equal(x, y)), "Forward pass is not the identity")

            # d(sum(y))/dx is all ones for the identity, so the reversed and
            # scaled gradient should be -lambd everywhere.
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = gradient_reverse(x, lambd=0.5)
                loss = tf.reduce_sum(y)
            grad = tape.gradient(loss, x)
            self._expect(grad is not None, "No gradient flowed through gradient_reverse")
            expected_grad = -0.5 * tf.ones_like(x)
            self._expect(tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6),
                         "Gradients are not reversed and scaled by lambd")
            self.log_test("Gradient reversal layer", True,
                          "Forward pass identity, gradients properly reversed")
        except Exception as e:
            self.log_test("Gradient reversal layer", False, self._format_error(e))

    def test_mixed_precision(self):
        """Test mixed precision configuration.

        ``configure_mixed_precision`` changes the process-wide Keras policy. The
        policy that was active before this test is restored afterwards so later
        tests (e.g. ``test_training_step``) do not silently run in bfloat16.
        """
        print("\n=== Testing Mixed Precision ===")
        previous_policy = tf.keras.mixed_precision.global_policy()
        try:
            # Configure mixed precision
            configure_mixed_precision()
            policy = tf.keras.mixed_precision.global_policy()
            self._expect(policy.name == 'mixed_bfloat16', f"Unexpected policy: {policy.name}")

            # The model must still run a forward pass under the mixed policy
            with self.strategy.scope():
                model = TripleGRUDecoder(
                    neural_dim=512, n_units=32, n_days=2, n_classes=41
                )
                x = tf.random.normal((1, 10, 512))
                day_idx = tf.constant([0], dtype=tf.int32)
                logits = model(x, day_idx, mode='inference', training=False)
            self._expect(logits.shape[0] == 1 and logits.shape[-1] == 41,
                         f"Unexpected logits shape: {logits.shape}")

            # The policy should compute in bfloat16 while keeping float32 variables
            self._expect(policy.compute_dtype == 'bfloat16',
                         f"Unexpected compute dtype: {policy.compute_dtype}")
            self._expect(policy.variable_dtype == 'float32',
                         f"Unexpected variable dtype: {policy.variable_dtype}")
            self.log_test("Mixed precision configuration", True,
                          f"Policy: {policy.name}")
        except Exception as e:
            self.log_test("Mixed precision configuration", False, self._format_error(e))
        finally:
            # NOTE(review): only the Keras global policy is restored; any other
            # state configure_mixed_precision() may change is not undone here.
            tf.keras.mixed_precision.set_global_policy(previous_policy)

    def test_training_step(self):
        """Test a complete training step: forward, CTC loss, gradients, optimizer update."""
        print("\n=== Testing Training Step ===")
        patch_size = 4
        patch_stride = 2
        try:
            with self.strategy.scope():
                # Create model
                model = TripleGRUDecoder(
                    neural_dim=512,
                    n_units=32,
                    n_days=2,
                    n_classes=41,
                    patch_size=patch_size,
                    patch_stride=patch_stride
                )
                # Create optimizer and loss
                optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001)
                ctc_loss = CTCLoss(blank_index=0, reduction='none')

                # Create dummy batch
                batch_size = 2
                time_steps = 20
                batch = {
                    'input_features': tf.random.normal((batch_size, time_steps, 512)),
                    'seq_class_ids': tf.constant([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=tf.int32),
                    'n_time_steps': tf.constant([time_steps, time_steps], dtype=tf.int32),
                    'phone_seq_lens': tf.constant([3, 2], dtype=tf.int32),
                    'day_indices': tf.constant([0, 1], dtype=tf.int32)
                }

                # Training step
                with tf.GradientTape() as tape:
                    features = batch['input_features']
                    n_time_steps = batch['n_time_steps']
                    # Output lengths after patching: (T - patch_size) / stride + 1,
                    # truncated to int by the cast.
                    adjusted_lens = tf.cast(
                        (tf.cast(n_time_steps, tf.float32) - patch_size) / patch_stride + 1,
                        tf.int32
                    )
                    # Forward pass
                    clean_logits = model(features, batch['day_indices'],
                                         mode='inference', training=True)
                    # Loss
                    loss_input = {
                        'labels': batch['seq_class_ids'],
                        'input_lengths': adjusted_lens,
                        'label_lengths': batch['phone_seq_lens']
                    }
                    loss = tf.reduce_mean(ctc_loss(loss_input, clean_logits))

                # Gradients
                gradients = tape.gradient(loss, model.trainable_variables)
                present = [g for g in gradients if g is not None]
                # Require at least one real gradient: all() over an empty
                # sequence would otherwise report a disconnected model as fine.
                grad_finite = bool(present) and all(
                    bool(tf.reduce_all(tf.math.is_finite(g))) for g in present
                )

                # Apply gradients
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                loss_finite = bool(tf.math.is_finite(loss))
                self.log_test(
                    "Training step", grad_finite and loss_finite,
                    f"Loss: {float(loss):.4f}, Gradients finite: {grad_finite} "
                    f"({len(present)}/{len(gradients)} non-None)"
                )
        except Exception as e:
            self.log_test("Training step", False, self._format_error(e))

    def test_full_training_loop(self):
        """Test a minimal end-to-end training loop (only with ``--full_test``).

        Building ``BrainToTextDecoderTrainerTF`` needs real data files, which this
        script does not provide. The check is therefore recorded as an explicit
        skip rather than claiming an initialization that never ran.
        """
        print("\n=== Testing Full Training Loop ===")
        if not getattr(self, '_full_test', False):
            self.log_test("Full training loop", True, "Skipped (use --full_test to enable)")
            return
        # TODO: once test data files are available, copy self.config with fresh
        # temporary output_dir/checkpoint_dir, num_training_batches=5 and
        # batches_per_val_step=3, then construct BrainToTextDecoderTrainerTF and
        # run a few batches.
        self.log_test("Full training loop", True,
                      "Skipped (trainer initialization requires real data files)")

    def run_all_tests(self, full_test: bool = False):
        """Run all tests in order and print a summary.

        Args:
            full_test: Enables the (currently data-dependent) full training loop test.

        Returns:
            True if every recorded test passed, otherwise False.
        """
        self._full_test = full_test
        print("TensorFlow Brain-to-Text Implementation Test Suite")
        print("=" * 60)
        if self.use_tpu:
            print("Running tests on TPU")
        else:
            print("Running tests on CPU/GPU")

        # Run tests
        self.test_model_architecture()
        self.test_triple_gru_decoder()
        self.test_ctc_loss()
        self.test_data_augmentation()
        self.test_gradient_reversal()
        self.test_mixed_precision()
        self.test_training_step()
        self.test_full_training_loop()

        # Summary
        print("\n" + "=" * 60)
        print(f"TEST SUMMARY: {self.passed_tests}/{self.total_tests} tests passed")
        if self.passed_tests == self.total_tests:
            print("🎉 All tests passed! TensorFlow implementation is ready.")
            return True
        else:
            print("❌ Some tests failed. Please review the implementation.")
            return False
def main():
    """Main test function.

    Parses the command-line flags, runs the test suite and exits with status 0
    when every test passed and 1 otherwise. The temporary directories created
    by the tester's config are removed even if the suite raises unexpectedly.
    """
    parser = argparse.ArgumentParser(description='Test TensorFlow Brain-to-Text Implementation')
    parser.add_argument('--use_tpu', action='store_true', help='Test on TPU if available')
    parser.add_argument('--full_test', action='store_true', help='Run full training loop test')
    parser.add_argument('--quiet', action='store_true', help='Reduce output verbosity')
    args = parser.parse_args()

    # Set TensorFlow logging level.
    # NOTE(review): tensorflow is already imported at module load, so setting
    # TF_CPP_MIN_LOG_LEVEL here probably has limited effect on native logging;
    # consider setting it before the import if quieter output is needed.
    if args.quiet:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        tf.get_logger().setLevel('ERROR')
    else:
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # Run tests
    tester = TensorFlowImplementationTester(
        use_tpu=args.use_tpu,
        verbose=not args.quiet
    )
    try:
        success = tester.run_all_tests(full_test=args.full_test)
    finally:
        # Cleanup the temporary directories allocated by the tester's config
        shutil.rmtree(tester.config['output_dir'], ignore_errors=True)
        shutil.rmtree(tester.config['checkpoint_dir'], ignore_errors=True)

    sys.exit(0 if success else 1)
if __name__ == "__main__":
    # Only run the suite when executed as a script, not when imported.
    main()