# b2txt25/TTA-E/evaluate_model.py
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
# ========== COMPREHENSIVE EVALUATION METRICS FUNCTIONS ==========
def calculate_phoneme_operations(true_seq, pred_seq):
"""Calculate insertions, deletions, and substitutions using dynamic programming."""
m, n = len(true_seq), len(pred_seq)
# DP matrix for edit distance calculation
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Initialize base cases
for i in range(m + 1):
dp[i][0] = i # deletions
for j in range(n + 1):
dp[0][j] = j # insertions
# Fill DP matrix
for i in range(1, m + 1):
for j in range(1, n + 1):
if true_seq[i-1] == pred_seq[j-1]:
dp[i][j] = dp[i-1][j-1] # match
else:
dp[i][j] = 1 + min(
dp[i-1][j], # deletion
dp[i][j-1], # insertion
dp[i-1][j-1] # substitution
)
# Backtrack to count operations
i, j = m, n
insertions = deletions = substitutions = 0
while i > 0 or j > 0:
if i > 0 and j > 0 and true_seq[i-1] == pred_seq[j-1]:
i -= 1
j -= 1
elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
substitutions += 1
i -= 1
j -= 1
elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
deletions += 1
i -= 1
else:
insertions += 1
j -= 1
return insertions, deletions, substitutions
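# Illustrative usage (a quick sanity check of what the counts mean, not part of the pipeline):
#   calculate_phoneme_operations(['AA', 'B', 'CH'], ['AA', 'CH'])  -> (0, 1, 0)   one deletion
#   calculate_phoneme_operations(['AA', 'B'], ['AA', 'K'])         -> (0, 0, 1)   one substitution
# The three counts always sum to the edit distance between the two sequences.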
def calculate_phoneme_level_metrics(true_seqs, pred_seqs):
"""Calculate phoneme-level precision, recall, and F1-score using alignment."""
all_true_phonemes = []
all_pred_phonemes = []
    # Position-wise alignment: truncate each pair of sequences to the shorter length so that every
    # kept index has both a true and a predicted phoneme. This is an approximation (no DP alignment),
    # but it keeps the two label lists the same length for sklearn.
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            all_true_phonemes.extend(true_seq[:min_len])
            all_pred_phonemes.extend(pred_seq[:min_len])
# Verify lengths match
if len(all_true_phonemes) != len(all_pred_phonemes):
print(f"Warning: After alignment, lengths still don't match: {len(all_true_phonemes)} vs {len(all_pred_phonemes)}")
# Take minimum length to avoid sklearn error
min_len = min(len(all_true_phonemes), len(all_pred_phonemes))
all_true_phonemes = all_true_phonemes[:min_len]
all_pred_phonemes = all_pred_phonemes[:min_len]
if len(all_true_phonemes) == 0:
# Handle empty case
return {
'phoneme_metrics': {
'phonemes': [],
'precision': np.array([]),
'recall': np.array([]),
'f1': np.array([]),
'support': np.array([])
},
'macro_avg': {
'precision': 0.0,
'recall': 0.0,
'f1': 0.0
},
'micro_avg': {
'precision': 0.0,
'recall': 0.0,
'f1': 0.0
}
}
# Get unique phonemes
unique_phonemes = sorted(list(set(all_true_phonemes + all_pred_phonemes)))
# Calculate metrics for each phoneme class
precision, recall, f1, support = precision_recall_fscore_support(
all_true_phonemes, all_pred_phonemes,
labels=unique_phonemes, average=None, zero_division=0
)
# Calculate macro averages
macro_precision = np.mean(precision)
macro_recall = np.mean(recall)
macro_f1 = np.mean(f1)
# Calculate micro averages
micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
all_true_phonemes, all_pred_phonemes, average='micro', zero_division=0
)
return {
'phoneme_metrics': {
'phonemes': unique_phonemes,
'precision': precision,
'recall': recall,
'f1': f1,
'support': support
},
'macro_avg': {
'precision': macro_precision,
'recall': macro_recall,
'f1': macro_f1
},
'micro_avg': {
'precision': micro_precision,
'recall': micro_recall,
'f1': micro_f1
}
}
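# Illustrative example (not executed here): the returned dictionary exposes per-class scores and
# both averaging schemes, e.g.
#   m = calculate_phoneme_level_metrics([['AA', 'B']], [['AA', 'K']])
#   m['phoneme_metrics']['phonemes']        # sorted list of phoneme classes that were scored
#   m['macro_avg']['f1'], m['micro_avg']['f1']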
def calculate_sequence_level_metrics(true_seqs, pred_seqs):
"""Calculate sequence-level metrics."""
exact_matches = 0
total_sequences = len(true_seqs)
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
if true_seq == pred_seq:
exact_matches += 1
exact_match_accuracy = exact_matches / total_sequences if total_sequences > 0 else 0
return {
'exact_match_accuracy': exact_match_accuracy,
'exact_matches': exact_matches,
'total_sequences': total_sequences
}
def calculate_confidence_intervals(errors, confidence=0.95):
"""Calculate confidence intervals for error rates."""
n = len(errors)
if n == 0:
return 0, 0, 0
mean_error = np.mean(errors)
std_error = np.std(errors, ddof=1) if n > 1 else 0
# Use t-distribution for small samples, normal for large samples
if n < 30:
from scipy import stats
t_val = stats.t.ppf((1 + confidence) / 2, n - 1)
margin = t_val * std_error / np.sqrt(n)
else:
# For large samples, use normal distribution
z_val = 1.96 if confidence == 0.95 else 2.576 # 95% or 99%
margin = z_val * std_error / np.sqrt(n)
return mean_error, mean_error - margin, mean_error + margin
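# Illustrative example (values are made up): a 95% interval around the mean per-trial PER.
#   mean_per, ci_low, ci_high = calculate_confidence_intervals([12.5, 0.0, 33.3, 9.1])
# With fewer than 30 samples the t-distribution branch is used, so scipy must be available.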
def calculate_phoneme_frequency_analysis(true_seqs, pred_seqs):
"""Analyze phoneme frequency and error patterns."""
true_phoneme_counts = Counter()
pred_phoneme_counts = Counter()
error_patterns = defaultdict(int) # Maps (true_phoneme, pred_phoneme) -> count
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
true_phoneme_counts.update(true_seq)
pred_phoneme_counts.update(pred_seq)
# For error pattern analysis, align sequences by taking minimum length
min_len = min(len(true_seq), len(pred_seq))
if min_len > 0:
for i in range(min_len):
true_p = true_seq[i]
pred_p = pred_seq[i]
if true_p != pred_p:
error_patterns[(true_p, pred_p)] += 1
return {
'true_phoneme_counts': true_phoneme_counts,
'pred_phoneme_counts': pred_phoneme_counts,
'error_patterns': dict(error_patterns)
}
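# Illustrative example (not executed here):
#   fa = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)
#   fa['error_patterns'].get(('AA', 'AE'), 0)  # how often true 'AA' was predicted as 'AE'
#                                              # at position-aligned (truncated) indices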
def generate_confusion_matrix_plot(true_seqs, pred_seqs, save_path=None):
"""Generate and save confusion matrix plot for phonemes."""
all_true = []
all_pred = []
# Align sequences properly to handle different lengths
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
min_len = min(len(true_seq), len(pred_seq))
if min_len > 0:
# Truncate both to same length for fair comparison
all_true.extend(true_seq[:min_len])
all_pred.extend(pred_seq[:min_len])
if len(all_true) == 0 or len(all_pred) == 0:
print("Warning: No aligned phonemes found for confusion matrix")
return None, []
unique_phonemes = sorted(list(set(all_true + all_pred)))
cm = confusion_matrix(all_true, all_pred, labels=unique_phonemes)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=unique_phonemes, yticklabels=unique_phonemes)
plt.title('Phoneme Confusion Matrix')
plt.xlabel('Predicted Phonemes')
plt.ylabel('True Phonemes')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Confusion matrix saved to: {save_path}")
plt.close()
return cm, unique_phonemes
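# Illustrative example (not executed here):
#   cm, labels = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, 'confusion.png')
# cm[i, j] counts how often labels[i] (true phoneme) was predicted as labels[j] (predicted phoneme).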
# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../data/t15_pretrained_lstm_baseline',
help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=1.0,
                    help='Weight for the GRU model in the ensemble (LSTM weight = 1 - gru_weight).')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=4,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
help='Maximum number of timesteps to cut from beginning in TTA.')
# TTA sample weight configuration
parser.add_argument('--tta_weights', type=str, default='2.0,1.0,0.0,0.0,0.0',
help='Comma-separated weights for each TTA augmentation type: original,noise,scale,shift,smooth. Set to 0 to disable.')
# Evaluation metrics configuration
parser.add_argument('--skip_confusion_matrix', action='store_true',
help='Skip confusion matrix generation to save time.')
parser.add_argument('--detailed_metrics', action=argparse.BooleanOptionalAction, default=True,
                    help='Calculate detailed phoneme-level and sequence-level metrics (disable with --no-detailed_metrics).')
parser.add_argument('--save_plots', action=argparse.BooleanOptionalAction, default=True,
                    help='Save confusion matrix and other plots (disable with --no-save_plots).')
# Memory management options
parser.add_argument('--max_trials_per_session', type=int, default=1000,
help='Maximum number of trials to load per session (for memory management).')
parser.add_argument('--batch_inference', action='store_true',
help='Process trials in batches to reduce memory usage during inference.')
parser.add_argument('--batch_size', type=int, default=10,
help='Batch size for inference when using --batch_inference.')
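# Example invocation (illustrative; the paths fall back to the argparse defaults above):
#   python evaluate_model.py --eval_type val --gru_weight 0.6 --gpu_number 0 \
#       --tta_weights 1.0,1.0,0.5,0.0,0.5 --batch_inference --batch_size 8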
args = parser.parse_args()
# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir
# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight
# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max
# Parse TTA weights
tta_weights_str = args.tta_weights.split(',')
if len(tta_weights_str) != 5:
raise ValueError("TTA weights must have exactly 5 values: original,noise,scale,shift,smooth")
tta_weights = {
'original': float(tta_weights_str[0]),
'noise': float(tta_weights_str[1]),
'scale': float(tta_weights_str[2]),
'shift': float(tta_weights_str[3]),
'smooth': float(tta_weights_str[4])
}
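# For example (illustrative), --tta_weights '1.0,0.5,0.5,0.0,0.0' keeps the original, noise, and
# scale passes with relative weights 1.0 / 0.5 / 0.5 and skips the shift and smooth passes entirely;
# the enabled weights are renormalized per trial before the TTA probabilities are averaged.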
# Calculate effective number of TTA samples based on enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
effective_tta_samples = len(enabled_augmentations)
print(f"Improved TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f} (balanced for better ensemble)")
print(f"LSTM weight: {lstm_weight:.2f}")
print(f"TTA augmentation weights:")
for aug_type, weight in tta_weights.items():
status = "✓ enabled" if weight > 0 else "✗ disabled"
print(f" - {aug_type}: {weight:.1f} ({status})")
print(f"Effective TTA samples: {effective_tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max cut: {tta_cut_max} timesteps")
print(f"Key improvements:")
print(f" - Customizable TTA sample weights")
print(f" - Probability-based ensemble instead of logit averaging")
print(f" - Geometric mean for more robust fusion")
print(f" - Direct PER evaluation without language model dependency")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()
# Define evaluation type
eval_type = args.eval_type
# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)
# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))
# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
if gpu_number >= torch.cuda.device_count():
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
device = f'cuda:{gpu_number}'
device = torch.device(device)
print(f'Using {device} for model inference.')
else:
if gpu_number >= 0:
print(f'GPU number {gpu_number} requested but not available.')
print('Using CPU for model inference.')
device = torch.device('cpu')
# Define GRU model
gru_model = GRUDecoder(
neural_dim=gru_model_args['model']['n_input_features'],
n_units=gru_model_args['model']['n_units'],
n_days=len(gru_model_args['dataset']['sessions']),
n_classes=gru_model_args['dataset']['n_classes'],
rnn_dropout=gru_model_args['model']['rnn_dropout'],
input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
n_layers=gru_model_args['model']['n_layers'],
patch_size=gru_model_args['model']['patch_size'],
patch_stride=gru_model_args['model']['patch_stride'],
)
# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state-dict keys;
# chaining the replaces avoids popping the same key twice
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
# Define LSTM model
lstm_model = LSTMDecoder(
neural_dim=lstm_model_args['model']['n_input_features'],
n_units=lstm_model_args['model']['n_units'],
n_days=len(lstm_model_args['dataset']['sessions']),
n_classes=lstm_model_args['dataset']['n_classes'],
rnn_dropout=lstm_model_args['model']['rnn_dropout'],
input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
n_layers=lstm_model_args['model']['n_layers'],
patch_size=lstm_model_args['model']['patch_size'],
patch_stride=lstm_model_args['model']['patch_stride'],
)
# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state-dict keys;
# chaining the replaces avoids popping the same key twice
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])
# Add models to device
gru_model.to(device)
lstm_model.to(device)
# Set models to eval mode
gru_model.eval()
lstm_model.eval()
print("Both models loaded successfully!")
print()
# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
device, gru_weight, lstm_weight, tta_weights, tta_noise_std,
tta_smooth_range, tta_scale_range, tta_cut_max):
"""
Run Customizable TTA-E (Test Time Augmentation + Ensemble) inference:
1. Apply selected data augmentations based on weights
2. Run both GRU and LSTM models on each augmented version
3. Use probability-based ensemble with customizable sample weights
4. Average predictions in probability space for better stability
"""
all_gru_probs = []
all_lstm_probs = []
sample_weights = []
# Get default smoothing parameters
default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']
# Generate augmented samples based on enabled augmentations
augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
for aug_type in augmentation_types:
if tta_weights[aug_type] <= 0:
continue # Skip disabled augmentations
x_augmented = x.clone()
if aug_type == 'original':
# Original data (baseline)
pass
elif aug_type == 'noise':
# Add Gaussian noise with varying intensity
noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
noise = torch.randn_like(x_augmented) * noise_scale
x_augmented = x_augmented + noise
elif aug_type == 'scale':
# Amplitude scaling with more variation
scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and tta_cut_max > 0:
            # Time shift (circular shift); bound the shift and guard against very short trials
            max_shift = max(1, min(tta_cut_max, x_augmented.shape[1] // 8))
            shift_amount = np.random.randint(1, max_shift + 1)
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
elif aug_type == 'smooth':
# Apply smoothing variation after augmentation
smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)
# Use autocast for efficiency - auto-detect device type
device_type = "cuda" if device.type == "cuda" else "cpu"
use_amp = gru_model_args.get('use_amp', False) and device.type == "cuda"
with torch.autocast(device_type=device_type, enabled=use_amp, dtype=torch.bfloat16 if device.type == "cuda" else torch.float32):
# Apply Gaussian smoothing
if aug_type == 'smooth':
# Use varied smoothing
x_smoothed = gauss_smooth(
inputs=x_augmented,
device=device,
smooth_kernel_std=varied_smooth_std,
smooth_kernel_size=default_smooth_size,
padding='valid',
)
else:
# Use default smoothing
x_smoothed = gauss_smooth(
inputs=x_augmented,
device=device,
smooth_kernel_std=default_smooth_std,
smooth_kernel_size=default_smooth_size,
padding='valid',
)
with torch.no_grad():
# Get GRU logits and convert to probabilities
gru_logits, _ = gru_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
gru_probs = torch.softmax(gru_logits, dim=-1)
# Get LSTM logits and convert to probabilities
lstm_logits, _ = lstm_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
lstm_probs = torch.softmax(lstm_logits, dim=-1)
all_gru_probs.append(gru_probs)
all_lstm_probs.append(lstm_probs)
sample_weights.append(tta_weights[aug_type])
if len(all_gru_probs) == 0:
raise ValueError("No TTA augmentations are enabled. Please set at least one weight > 0.")
# TTA fusion: Handle potentially different tensor shapes
if len(all_gru_probs) > 1:
# Find the minimum sequence length among all TTA samples
min_length = min([probs.shape[1] for probs in all_gru_probs + all_lstm_probs])
# Truncate all tensors to the minimum length
truncated_gru_probs = []
truncated_lstm_probs = []
for gru_probs, lstm_probs in zip(all_gru_probs, all_lstm_probs):
if gru_probs.shape[1] > min_length:
truncated_gru_probs.append(gru_probs[:, :min_length, :])
else:
truncated_gru_probs.append(gru_probs)
if lstm_probs.shape[1] > min_length:
truncated_lstm_probs.append(lstm_probs[:, :min_length, :])
else:
truncated_lstm_probs.append(lstm_probs)
# Weighted average probabilities across TTA samples
sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32, device=device)
sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum() # Normalize weights
weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0])
weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0])
for i, (gru_probs, lstm_probs, weight) in enumerate(zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor)):
weighted_gru_probs += weight * gru_probs
weighted_lstm_probs += weight * lstm_probs
avg_gru_probs = weighted_gru_probs
avg_lstm_probs = weighted_lstm_probs
else:
avg_gru_probs = all_gru_probs[0]
avg_lstm_probs = all_lstm_probs[0]
# Improved ensemble: geometric mean of probabilities (more robust than weighted average)
# Apply epsilon smoothing to avoid log(0)
epsilon = 1e-8
avg_gru_probs = avg_gru_probs + epsilon
avg_lstm_probs = avg_lstm_probs + epsilon
# Weighted geometric mean in log space
log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +
lstm_weight * torch.log(avg_lstm_probs))
ensemble_probs = torch.exp(log_ensemble_probs)
# Normalize to ensure proper probability distribution
ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)
# Convert back to logits for compatibility with downstream code
final_logits = torch.log(ensemble_probs + epsilon)
    # Return the fused logits as a float32 numpy array (upcast from bfloat16 when AMP was used)
    return final_logits.float().cpu().numpy()
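# Minimal sketch (illustrative only) of the weighted geometric-mean fusion used above, shown on a
# toy pair of probability vectors with gru_weight=0.6 / lstm_weight=0.4:
#   p_gru, p_lstm = torch.tensor([0.7, 0.2, 0.1]), torch.tensor([0.5, 0.4, 0.1])
#   fused = torch.exp(0.6 * torch.log(p_gru + 1e-8) + 0.4 * torch.log(p_lstm + 1e-8))
#   fused = fused / fused.sum()   # renormalize so the fused scores form a valid distribution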
# Load data for each session with memory management
test_data = {}
total_test_trials = 0
print("Loading data with memory management...")
for session in gru_model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
try:
# Try to load data normally first
print(f'Loading {session}...')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Successfully loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
        except MemoryError as e:  # numpy's _ArrayMemoryError is a subclass of MemoryError
print(f'Memory error loading {session}: {e}')
print(f'Attempting to load {session} in smaller chunks...')
# If memory error, try loading in chunks or skip this session
import h5py
try:
with h5py.File(eval_file, 'r') as h5file:
# Get the size first
input_features_shape = h5file['input_features'].shape
print(f'Session {session} has {input_features_shape[0]} trials with {input_features_shape[1]} features')
# If it's too large, skip or implement chunk loading
max_trials_per_session = args.max_trials_per_session # Limit trials per session
if input_features_shape[0] > max_trials_per_session:
print(f'Session {session} has {input_features_shape[0]} trials, limiting to {max_trials_per_session} for memory management')
# Could implement chunked loading here if needed
print(f'Skipping session {session} due to memory constraints')
continue
else:
# Try one more time with garbage collection
import gc
gc.collect()
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session} after GC.')
except Exception as e2:
print(f'Failed to load session {session}: {e2}')
print(f'Skipping session {session}')
continue
print(f'Total number of {eval_type} trials loaded: {total_test_trials}')
if total_test_trials == 0:
print("ERROR: No trials loaded! Check data paths and memory availability.")
sys.exit(1)
print()
# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
import gc
with tqdm(total=total_test_trials, desc=f'Customizable TTA-E inference ({effective_tta_samples} augmentation types)', unit='trial') as pbar:
for session, data in test_data.items():
data['logits'] = []
data['pred_seq'] = []
input_layer = gru_model_args['dataset']['sessions'].index(session)
# Process trials in batches if requested
if args.batch_inference:
batch_size = args.batch_size
num_trials = len(data['neural_features'])
for batch_start in range(0, num_trials, batch_size):
batch_end = min(batch_start + batch_size, num_trials)
batch_logits = []
for trial in range(batch_start, batch_end):
# Get neural input for the trial
neural_input = data['neural_features'][trial]
# Add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# Convert to torch tensor - use float32 to avoid dtype issues with smoothing
neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)
# Run TTA-E decoding step with customizable weights
ensemble_logits = runTTAEnsembleDecodingStep(
neural_input, input_layer, gru_model, lstm_model,
gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
)
batch_logits.append(ensemble_logits)
# Clear GPU memory
del neural_input
if device.type == 'cuda':
torch.cuda.empty_cache()
pbar.update(1)
# Add batch results
data['logits'].extend(batch_logits)
# Clear batch memory
del batch_logits
gc.collect()
else:
# Process all trials individually (original method)
for trial in range(len(data['neural_features'])):
# Get neural input for the trial
neural_input = data['neural_features'][trial]
# Add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# Convert to torch tensor - use float32 to avoid dtype issues with smoothing
neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)
# Run TTA-E decoding step with customizable weights
ensemble_logits = runTTAEnsembleDecodingStep(
neural_input, input_layer, gru_model, lstm_model,
gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
)
data['logits'].append(ensemble_logits)
# Clear memory periodically
if trial % 50 == 0:
del neural_input
if device.type == 'cuda':
torch.cuda.empty_cache()
gc.collect()
pbar.update(1)
# Clear session data memory after processing
gc.collect()
pbar.close()
# Convert logits to phoneme sequences and print them out
results = {
'session': [],
'block': [],
'trial': [],
'true_sentence': [],
'pred_phonemes': [],
}
for session, data in test_data.items():
data['pred_seq'] = []
for trial in range(len(data['logits'])):
logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
        # Greedy CTC decoding: collapse consecutive duplicates first, then drop blanks (class 0),
        # so that identical phonemes separated by a blank frame are not merged together
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
        pred_seq = [p for p in pred_seq if p != 0]
        # Convert class indices to phoneme strings
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
# Add to data
data['pred_seq'].append(pred_seq)
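        # Illustrative walk-through of the greedy collapse above (assumed frame-wise argmax ids):
        #   [0, 5, 5, 0, 12, 12, 0] -> collapse repeats -> [0, 5, 0, 12, 0]
        #   -> drop blanks (0) -> [5, 12] -> LOGIT_TO_PHONEME lookup gives the phoneme strings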
# Store results
results['session'].append(session)
results['block'].append(data['block_num'][trial])
results['trial'].append(data['trial_num'][trial])
if eval_type == 'val':
results['true_sentence'].append(data['sentence_label'][trial])
else:
results['true_sentence'].append(None)
results['pred_phonemes'].append(pred_seq)
# Print out the predicted sequences
block_num = data['block_num'][trial]
trial_num = data['trial_num'][trial]
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
if eval_type == 'val':
sentence_label = data['sentence_label'][trial]
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
print(f'Sentence label: {sentence_label}')
print(f'True sequence: {" ".join(true_seq)}')
print(f'Predicted Sequence: {" ".join(pred_seq)}')
print()
# If using the validation set, calculate comprehensive evaluation metrics
if eval_type == 'val':
print("=" * 80)
print("COMPREHENSIVE EVALUATION METRICS")
print("=" * 80)
# Record performance timing
eval_start_time = time.time()
# Initialize metrics storage
total_true_length = 0
total_edit_distance = 0
all_true_seqs = []
all_pred_seqs = []
per_trial_errors = []
per_session_metrics = defaultdict(list)
# Additional detailed metrics
results['edit_distance'] = []
results['num_phonemes'] = []
results['insertions'] = []
results['deletions'] = []
results['substitutions'] = []
results['per_trial_per'] = []
print(f"\nCalculating detailed metrics for {len(results['pred_phonemes'])} trials...")
for i in range(len(results['pred_phonemes'])):
# Get true phoneme sequence
session = results['session'][i]
trial_idx = None
for trial in range(len(test_data[session]['seq_class_ids'])):
if (test_data[session]['block_num'][trial] == results['block'][i] and
test_data[session]['trial_num'][trial] == results['trial'][i]):
trial_idx = trial
break
if trial_idx is not None:
true_seq = test_data[session]['seq_class_ids'][trial_idx][0:test_data[session]['seq_len'][trial_idx]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
pred_seq = results['pred_phonemes'][i]
all_true_seqs.append(true_seq)
all_pred_seqs.append(pred_seq)
# Calculate basic edit distance
ed = editdistance.eval(true_seq, pred_seq)
trial_per = 100 * ed / len(true_seq) if len(true_seq) > 0 else 0
# Calculate detailed phoneme operations
insertions, deletions, substitutions = calculate_phoneme_operations(true_seq, pred_seq)
# Store metrics
total_true_length += len(true_seq)
total_edit_distance += ed
per_trial_errors.append(trial_per)
per_session_metrics[session].append(trial_per)
results['edit_distance'].append(ed)
results['num_phonemes'].append(len(true_seq))
results['insertions'].append(insertions)
results['deletions'].append(deletions)
results['substitutions'].append(substitutions)
results['per_trial_per'].append(trial_per)
# Print detailed trial information
print(f'{results["session"][i]} - Block {results["block"][i]}, Trial {results["trial"][i]}')
print(f'True phonemes ({len(true_seq)}): {" ".join(true_seq)}')
print(f'Predicted phonemes ({len(pred_seq)}): {" ".join(pred_seq)}')
print(f'Operations: I={insertions}, D={deletions}, S={substitutions}, Total ED={ed}')
print(f'Trial PER: {ed} / {len(true_seq)} = {trial_per:.2f}%')
print()
# ========== BASIC METRICS ==========
aggregate_per = 100 * total_edit_distance / total_true_length if total_true_length > 0 else 0
print("=" * 50)
print("BASIC PHONEME ERROR RATE METRICS")
print("=" * 50)
print(f'Total true phoneme length: {total_true_length}')
print(f'Total edit distance: {total_edit_distance}')
print(f'Aggregate Phoneme Error Rate (PER): {aggregate_per:.2f}%')
# Calculate PER confidence intervals
per_mean, per_ci_lower, per_ci_upper = calculate_confidence_intervals(per_trial_errors)
print(f'Per-trial PER: {per_mean:.2f}% (95% CI: [{per_ci_lower:.2f}%, {per_ci_upper:.2f}%])')
# ========== PHONEME OPERATION ANALYSIS ==========
total_insertions = sum(results['insertions'])
total_deletions = sum(results['deletions'])
total_substitutions = sum(results['substitutions'])
print("\n" + "=" * 50)
print("PHONEME OPERATION BREAKDOWN")
print("=" * 50)
    denom = max(total_edit_distance, 1)  # guard against division by zero when there are no errors
    print(f'Total insertions: {total_insertions} ({100*total_insertions/denom:.1f}% of errors)')
    print(f'Total deletions: {total_deletions} ({100*total_deletions/denom:.1f}% of errors)')
    print(f'Total substitutions: {total_substitutions} ({100*total_substitutions/denom:.1f}% of errors)')
    print(f'Verification: I+D+S = {total_insertions + total_deletions + total_substitutions} (should equal {total_edit_distance})')
# ========== PHONEME-LEVEL CLASSIFICATION METRICS ==========
print("\n" + "=" * 50)
print("PHONEME-LEVEL CLASSIFICATION METRICS")
print("=" * 50)
phoneme_metrics = calculate_phoneme_level_metrics(all_true_seqs, all_pred_seqs)
print(f"Macro-averaged metrics:")
print(f" Precision: {phoneme_metrics['macro_avg']['precision']:.3f}")
print(f" Recall: {phoneme_metrics['macro_avg']['recall']:.3f}")
print(f" F1-Score: {phoneme_metrics['macro_avg']['f1']:.3f}")
print(f"\nMicro-averaged metrics:")
print(f" Precision: {phoneme_metrics['micro_avg']['precision']:.3f}")
print(f" Recall: {phoneme_metrics['micro_avg']['recall']:.3f}")
print(f" F1-Score: {phoneme_metrics['micro_avg']['f1']:.3f}")
# ========== SEQUENCE-LEVEL METRICS ==========
print("\n" + "=" * 50)
print("SEQUENCE-LEVEL METRICS")
print("=" * 50)
seq_metrics = calculate_sequence_level_metrics(all_true_seqs, all_pred_seqs)
print(f"Exact Match Accuracy: {seq_metrics['exact_match_accuracy']:.3f} ({seq_metrics['exact_matches']}/{seq_metrics['total_sequences']})")
# ========== PER-SESSION BREAKDOWN ==========
print("\n" + "=" * 50)
print("PER-SESSION PERFORMANCE BREAKDOWN")
print("=" * 50)
for session in sorted(per_session_metrics.keys()):
session_errors = per_session_metrics[session]
session_mean, session_ci_lower, session_ci_upper = calculate_confidence_intervals(session_errors)
print(f"{session}: {session_mean:.2f}% ± {session_mean - session_ci_lower:.2f}% "
f"(n={len(session_errors)}, range: {min(session_errors):.1f}%-{max(session_errors):.1f}%)")
# ========== STATISTICAL SUMMARY ==========
print("\n" + "=" * 50)
print("STATISTICAL SUMMARY")
print("=" * 50)
print(f"Mean trial PER: {np.mean(per_trial_errors):.2f}%")
print(f"Median trial PER: {np.median(per_trial_errors):.2f}%")
print(f"Std deviation: {np.std(per_trial_errors):.2f}%")
print(f"Min/Max trial PER: {min(per_trial_errors):.1f}% / {max(per_trial_errors):.1f}%")
print(f"Trials with 0% PER: {sum(1 for x in per_trial_errors if x == 0)} ({100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors):.1f}%)")
# ========== PHONEME FREQUENCY ANALYSIS ==========
if args.detailed_metrics:
print("\n" + "=" * 50)
print("PHONEME FREQUENCY AND ERROR PATTERN ANALYSIS")
print("=" * 50)
freq_analysis = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)
print("Top 10 most frequent true phonemes:")
for phoneme, count in freq_analysis['true_phoneme_counts'].most_common(10):
print(f" {phoneme}: {count} ({100*count/sum(freq_analysis['true_phoneme_counts'].values()):.1f}%)")
print("\nTop 10 most frequent predicted phonemes:")
for phoneme, count in freq_analysis['pred_phoneme_counts'].most_common(10):
print(f" {phoneme}: {count} ({100*count/sum(freq_analysis['pred_phoneme_counts'].values()):.1f}%)")
print("\nTop 10 most frequent error patterns (true -> predicted):")
sorted_errors = sorted(freq_analysis['error_patterns'].items(), key=lambda x: x[1], reverse=True)
for (true_p, pred_p), count in sorted_errors[:10]:
print(f" {true_p} -> {pred_p}: {count} errors")
# ========== GENERATE CONFUSION MATRIX ==========
if not args.skip_confusion_matrix and args.save_plots:
timestamp_for_plots = time.strftime("%Y%m%d_%H%M%S")
confusion_matrix_path = f'confusion_matrix_TTA-E_{eval_type}_{timestamp_for_plots}.png'
confusion_matrix_path = os.path.join(os.path.dirname(__file__), confusion_matrix_path)
print(f"\nGenerating confusion matrix...")
cm_result = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, confusion_matrix_path)
if cm_result[0] is None:
confusion_matrix_path = None
else:
print(f"\nSkipping confusion matrix generation (--skip_confusion_matrix or --no-save_plots specified)")
confusion_matrix_path = None
# ========== TIMING METRICS ==========
eval_end_time = time.time()
eval_duration = eval_end_time - eval_start_time
print("\n" + "=" * 50)
print("TIMING METRICS")
print("=" * 50)
print(f"Evaluation time: {eval_duration:.2f} seconds")
print(f"Time per trial: {eval_duration/len(results['pred_phonemes']):.3f} seconds")
print("\n" + "=" * 80)
print("EVALUATION SUMMARY COMPLETE")
print("=" * 80)
# Write comprehensive results to CSV files
timestamp = time.strftime("%Y%m%d_%H%M%S")
enabled_augs = '_'.join([k for k, v in tta_weights.items() if v > 0])
# Basic phoneme predictions CSV (for compatibility)
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{enabled_augs}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)
ids = [i for i in range(len(results['pred_phonemes']))]
phoneme_strings = [" ".join(phonemes) for phonemes in results['pred_phonemes']]
df_out = pd.DataFrame({'id': ids, 'phonemes': phoneme_strings})
df_out.to_csv(output_path, index=False)
# Comprehensive metrics CSV (if validation set)
if eval_type == 'val':
detailed_output_file = f'TTA-E_detailed_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
detailed_output_path = os.path.join(os.path.dirname(__file__), detailed_output_file)
# Create detailed DataFrame with all metrics
detailed_df = pd.DataFrame({
'id': ids,
'session': results['session'],
'block': results['block'],
'trial': results['trial'],
'true_sentence': results['true_sentence'],
'pred_phonemes': phoneme_strings,
'num_phonemes': results['num_phonemes'],
'edit_distance': results['edit_distance'],
'insertions': results['insertions'],
'deletions': results['deletions'],
'substitutions': results['substitutions'],
'trial_per': results['per_trial_per']
})
detailed_df.to_csv(detailed_output_path, index=False)
# Summary metrics CSV
summary_output_file = f'TTA-E_summary_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
summary_output_path = os.path.join(os.path.dirname(__file__), summary_output_file)
summary_data = {
'metric': [
'aggregate_per', 'mean_trial_per', 'median_trial_per', 'std_trial_per',
'min_trial_per', 'max_trial_per', 'zero_per_trials_count', 'zero_per_trials_percent',
'total_phonemes', 'total_edit_distance', 'total_insertions', 'total_deletions',
'total_substitutions', 'exact_match_accuracy', 'exact_matches',
'macro_precision', 'macro_recall', 'macro_f1',
'micro_precision', 'micro_recall', 'micro_f1',
'per_mean_95ci_lower', 'per_mean_95ci_upper', 'evaluation_time_seconds'
],
'value': [
aggregate_per, np.mean(per_trial_errors), np.median(per_trial_errors), np.std(per_trial_errors),
min(per_trial_errors), max(per_trial_errors),
sum(1 for x in per_trial_errors if x == 0),
100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors),
total_true_length, total_edit_distance, total_insertions, total_deletions,
total_substitutions, seq_metrics['exact_match_accuracy'], seq_metrics['exact_matches'],
phoneme_metrics['macro_avg']['precision'], phoneme_metrics['macro_avg']['recall'],
phoneme_metrics['macro_avg']['f1'], phoneme_metrics['micro_avg']['precision'],
phoneme_metrics['micro_avg']['recall'], phoneme_metrics['micro_avg']['f1'],
per_ci_lower, per_ci_upper, eval_duration
]
}
summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(summary_output_path, index=False)
print(f'\nComprehensive results saved to:')
print(f' Basic predictions: {output_path}')
print(f' Detailed metrics: {detailed_output_path}')
print(f' Summary metrics: {summary_output_path}')
if confusion_matrix_path:
print(f' Confusion matrix: {confusion_matrix_path}')
else:
print(f' Confusion matrix: Not generated')
else:
print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
print(f'Enabled augmentations: {", ".join(enabled_augs.split("_"))}')