# b2txt25/TTA-E/evaluate_model.py
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# Add the parent directory to the path so that model_training and model_training_lstm
# can be imported as packages (the imports below are package-qualified)
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
# ========== COMPREHENSIVE EVALUATION METRICS FUNCTIONS ==========
def calculate_phoneme_operations(true_seq, pred_seq):
"""Calculate insertions, deletions, and substitutions using dynamic programming."""
m, n = len(true_seq), len(pred_seq)
# DP matrix for edit distance calculation
dp = [[0] * (n + 1) for _ in range(m + 1)]
# Initialize base cases
for i in range(m + 1):
dp[i][0] = i # deletions
for j in range(n + 1):
dp[0][j] = j # insertions
# Fill DP matrix
for i in range(1, m + 1):
for j in range(1, n + 1):
if true_seq[i-1] == pred_seq[j-1]:
dp[i][j] = dp[i-1][j-1] # match
else:
dp[i][j] = 1 + min(
dp[i-1][j], # deletion
dp[i][j-1], # insertion
dp[i-1][j-1] # substitution
)
# Backtrack to count operations
i, j = m, n
insertions = deletions = substitutions = 0
while i > 0 or j > 0:
if i > 0 and j > 0 and true_seq[i-1] == pred_seq[j-1]:
i -= 1
j -= 1
elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
substitutions += 1
i -= 1
j -= 1
elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
deletions += 1
i -= 1
else:
insertions += 1
j -= 1
return insertions, deletions, substitutions
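# Illustrative sanity check (not executed): for true_seq = ['HH', 'AH', 'L', 'OW'] and
# pred_seq = ['HH', 'L', 'OW', 'Z'], the minimal alignment deletes 'AH' and inserts 'Z', so
# calculate_phoneme_operations returns (insertions=1, deletions=1, substitutions=0), and
# insertions + deletions + substitutions equals editdistance.eval(true_seq, pred_seq) == 2.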
def calculate_phoneme_level_metrics(true_seqs, pred_seqs):
"""Calculate phoneme-level precision, recall, and F1-score using alignment."""
all_true_phonemes = []
all_pred_phonemes = []
# Align sequences properly to handle different lengths
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
# A full DP alignment (as in edit distance) would be more precise; for simplicity,
# compare position-by-position and truncate both sequences to the shorter length
min_len = min(len(true_seq), len(pred_seq))
# Take only the overlapping part for fair comparison
if min_len > 0:
if len(true_seq) == len(pred_seq):
# Same length - direct alignment
all_true_phonemes.extend(true_seq)
all_pred_phonemes.extend(pred_seq)
else:
# Different lengths - align by truncating to shorter length
truncated_true = true_seq[:min_len]
truncated_pred = pred_seq[:min_len]
all_true_phonemes.extend(truncated_true)
all_pred_phonemes.extend(truncated_pred)
# Verify lengths match
if len(all_true_phonemes) != len(all_pred_phonemes):
print(f"Warning: After alignment, lengths still don't match: {len(all_true_phonemes)} vs {len(all_pred_phonemes)}")
# Take minimum length to avoid sklearn error
min_len = min(len(all_true_phonemes), len(all_pred_phonemes))
all_true_phonemes = all_true_phonemes[:min_len]
all_pred_phonemes = all_pred_phonemes[:min_len]
if len(all_true_phonemes) == 0:
# Handle empty case
return {
'phoneme_metrics': {
'phonemes': [],
'precision': np.array([]),
'recall': np.array([]),
'f1': np.array([]),
'support': np.array([])
},
'macro_avg': {
'precision': 0.0,
'recall': 0.0,
'f1': 0.0
},
'micro_avg': {
'precision': 0.0,
'recall': 0.0,
'f1': 0.0
}
}
# Get unique phonemes
unique_phonemes = sorted(list(set(all_true_phonemes + all_pred_phonemes)))
# Calculate metrics for each phoneme class
precision, recall, f1, support = precision_recall_fscore_support(
all_true_phonemes, all_pred_phonemes,
labels=unique_phonemes, average=None, zero_division=0
)
# Calculate macro averages
macro_precision = np.mean(precision)
macro_recall = np.mean(recall)
macro_f1 = np.mean(f1)
# Calculate micro averages
micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
all_true_phonemes, all_pred_phonemes, average='micro', zero_division=0
)
return {
'phoneme_metrics': {
'phonemes': unique_phonemes,
'precision': precision,
'recall': recall,
'f1': f1,
'support': support
},
'macro_avg': {
'precision': macro_precision,
'recall': macro_recall,
'f1': macro_f1
},
'micro_avg': {
'precision': micro_precision,
'recall': micro_recall,
'f1': micro_f1
}
}
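# Worked example of the macro/micro distinction (illustrative, assuming position-wise alignment):
# with aligned phonemes true = ['AH', 'AH', 'B'] and pred = ['AH', 'B', 'B'],
# micro precision = micro recall = 2/3 (overall accuracy over positions), while the
# per-class precisions are AH: 1.0 and B: 0.5, giving a macro precision of 0.75.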
def calculate_sequence_level_metrics(true_seqs, pred_seqs):
"""Calculate sequence-level metrics."""
exact_matches = 0
total_sequences = len(true_seqs)
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
if true_seq == pred_seq:
exact_matches += 1
exact_match_accuracy = exact_matches / total_sequences if total_sequences > 0 else 0
return {
'exact_match_accuracy': exact_match_accuracy,
'exact_matches': exact_matches,
'total_sequences': total_sequences
}
def calculate_confidence_intervals(errors, confidence=0.95):
"""Calculate confidence intervals for error rates."""
n = len(errors)
if n == 0:
return 0, 0, 0
mean_error = np.mean(errors)
std_error = np.std(errors, ddof=1) if n > 1 else 0
# Use t-distribution for small samples, normal for large samples
if n < 30:
from scipy import stats
t_val = stats.t.ppf((1 + confidence) / 2, n - 1)
margin = t_val * std_error / np.sqrt(n)
else:
# For large samples, use the normal approximation
from scipy import stats
z_val = stats.norm.ppf((1 + confidence) / 2)
margin = z_val * std_error / np.sqrt(n)
return mean_error, mean_error - margin, mean_error + margin
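# Example (illustrative): for per-trial errors [10.0, 20.0, 30.0] (n=3), the mean is 20.0 and the
# sample std is 10.0; with t(0.975, df=2) ~= 4.30 the margin is ~24.8, so the function returns
# roughly (20.0, -4.8, 44.8). With only three trials the interval is wide, as expected.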
def calculate_phoneme_frequency_analysis(true_seqs, pred_seqs):
"""Analyze phoneme frequency and error patterns."""
true_phoneme_counts = Counter()
pred_phoneme_counts = Counter()
error_patterns = defaultdict(int) # Maps (true_phoneme, pred_phoneme) -> count
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
true_phoneme_counts.update(true_seq)
pred_phoneme_counts.update(pred_seq)
# For error pattern analysis, align sequences by taking minimum length
min_len = min(len(true_seq), len(pred_seq))
if min_len > 0:
for i in range(min_len):
true_p = true_seq[i]
pred_p = pred_seq[i]
if true_p != pred_p:
error_patterns[(true_p, pred_p)] += 1
return {
'true_phoneme_counts': true_phoneme_counts,
'pred_phoneme_counts': pred_phoneme_counts,
'error_patterns': dict(error_patterns)
}
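# Note: 'error_patterns' maps (true_phoneme, pred_phoneme) pairs to counts at position-aligned
# timesteps, e.g. an entry {('AH', 'EH'): 3} would mean 'AH' was decoded as 'EH' three times.
# Because alignment is by truncation rather than full DP alignment, these counts are approximate.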
def generate_confusion_matrix_plot(true_seqs, pred_seqs, save_path=None):
"""Generate and save confusion matrix plot for phonemes."""
all_true = []
all_pred = []
# Align sequences properly to handle different lengths
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
min_len = min(len(true_seq), len(pred_seq))
if min_len > 0:
# Truncate both to same length for fair comparison
all_true.extend(true_seq[:min_len])
all_pred.extend(pred_seq[:min_len])
if len(all_true) == 0 or len(all_pred) == 0:
print("Warning: No aligned phonemes found for confusion matrix")
return None, []
unique_phonemes = sorted(list(set(all_true + all_pred)))
cm = confusion_matrix(all_true, all_pred, labels=unique_phonemes)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=unique_phonemes, yticklabels=unique_phonemes)
plt.title('Phoneme Confusion Matrix')
plt.xlabel('Predicted Phonemes')
plt.ylabel('True Phonemes')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Confusion matrix saved to: {save_path}")
plt.close()
return cm, unique_phonemes
# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../data/t15_pretrained_lstm_baseline',
help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=1.0,
help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). The default of 1.0 uses the GRU only; set to e.g. 0.5 for an equally weighted ensemble.')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=4,
help='Number of TTA augmentation samples per trial. Note: the effective number is determined by the non-zero entries in --tta_weights.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
help='Maximum number of timesteps to cut from beginning in TTA.')
# TTA sample weight configuration
parser.add_argument('--tta_weights', type=str, default='2.0,1.0,0.0,0.0,0.0',
help='Comma-separated weights for each TTA augmentation type: original,noise,scale,shift,smooth. Set to 0 to disable.')
# Evaluation metrics configuration
parser.add_argument('--skip_confusion_matrix', action='store_true',
help='Skip confusion matrix generation to save time.')
parser.add_argument('--detailed_metrics', action=argparse.BooleanOptionalAction, default=True,
help='Calculate detailed phoneme-level and sequence-level metrics (disable with --no-detailed_metrics).')
parser.add_argument('--save_plots', action=argparse.BooleanOptionalAction, default=True,
help='Save confusion matrix and other plots (disable with --no-save_plots).')
# Memory management options
parser.add_argument('--max_trials_per_session', type=int, default=1000,
help='Maximum number of trials to load per session (for memory management).')
parser.add_argument('--batch_inference', action='store_true',
help='Process trials in batches to reduce memory usage during inference.')
parser.add_argument('--batch_size', type=int, default=10,
help='Batch size for inference when using --batch_inference.')
args = parser.parse_args()
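# Example invocations (illustrative; adjust paths and weights to your local setup):
#   python evaluate_model.py --eval_type val --gru_weight 0.5 --tta_weights 2.0,1.0,1.0,0.0,1.0
#   python evaluate_model.py --eval_type test --gpu_number -1 --batch_inference --batch_size 8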
# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir
# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight
# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max
# Parse TTA weights
tta_weights_str = args.tta_weights.split(',')
if len(tta_weights_str) != 5:
raise ValueError("TTA weights must have exactly 5 values: original,noise,scale,shift,smooth")
tta_weights = {
'original': float(tta_weights_str[0]),
'noise': float(tta_weights_str[1]),
'scale': float(tta_weights_str[2]),
'shift': float(tta_weights_str[3]),
'smooth': float(tta_weights_str[4])
}
# Calculate effective number of TTA samples based on enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
effective_tta_samples = len(enabled_augmentations)
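# For example, '--tta_weights 2.0,1.0,0.0,0.0,0.0' (the default) enables only the original and
# noise augmentations with relative weights 2:1, so effective_tta_samples == 2; a setting such as
# '1.0,1.0,1.0,1.0,1.0' would enable all five augmentation types with equal weight.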
print(f"Improved TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f} (balanced for better ensemble)")
print(f"LSTM weight: {lstm_weight:.2f}")
print(f"TTA augmentation weights:")
for aug_type, weight in tta_weights.items():
status = "✓ enabled" if weight > 0 else "✗ disabled"
print(f" - {aug_type}: {weight:.1f} ({status})")
print(f"Effective TTA samples: {effective_tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max cut: {tta_cut_max} timesteps")
print(f"Key improvements:")
print(f" - Customizable TTA sample weights")
print(f" - Probability-based ensemble instead of logit averaging")
print(f" - Geometric mean for more robust fusion")
print(f" - Direct PER evaluation without language model dependency")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()
# Define evaluation type
eval_type = args.eval_type
# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)
# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))
# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
if gpu_number >= torch.cuda.device_count():
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
device = f'cuda:{gpu_number}'
device = torch.device(device)
print(f'Using {device} for model inference.')
else:
if gpu_number >= 0:
print(f'GPU number {gpu_number} requested but not available.')
print('Using CPU for model inference.')
device = torch.device('cpu')
# Define GRU model
gru_model = GRUDecoder(
neural_dim=gru_model_args['model']['n_input_features'],
n_units=gru_model_args['model']['n_units'],
n_days=len(gru_model_args['dataset']['sessions']),
n_classes=gru_model_args['dataset']['n_classes'],
rnn_dropout=gru_model_args['model']['rnn_dropout'],
input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
n_layers=gru_model_args['model']['n_layers'],
patch_size=gru_model_args['model']['patch_size'],
patch_stride=gru_model_args['model']['patch_stride'],
)
# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from checkpoint keys
for key in list(gru_checkpoint['model_state_dict'].keys()):
new_key = key.replace("module.", "").replace("_orig_mod.", "")
gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
# Define LSTM model
lstm_model = LSTMDecoder(
neural_dim=lstm_model_args['model']['n_input_features'],
n_units=lstm_model_args['model']['n_units'],
n_days=len(lstm_model_args['dataset']['sessions']),
n_classes=lstm_model_args['dataset']['n_classes'],
rnn_dropout=lstm_model_args['model']['rnn_dropout'],
input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
n_layers=lstm_model_args['model']['n_layers'],
patch_size=lstm_model_args['model']['patch_size'],
patch_stride=lstm_model_args['model']['patch_stride'],
)
# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from checkpoint keys
for key in list(lstm_checkpoint['model_state_dict'].keys()):
new_key = key.replace("module.", "").replace("_orig_mod.", "")
lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])
# Add models to device
gru_model.to(device)
lstm_model.to(device)
# Set models to eval mode
gru_model.eval()
lstm_model.eval()
print("Both models loaded successfully!")
print()
# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
device, gru_weight, lstm_weight, tta_weights, tta_noise_std,
tta_smooth_range, tta_scale_range, tta_cut_max):
"""
Run Customizable TTA-E (Test Time Augmentation + Ensemble) inference:
1. Apply selected data augmentations based on weights
2. Run both GRU and LSTM models on each augmented version
3. Use probability-based ensemble with customizable sample weights
4. Average predictions in probability space for better stability
"""
all_gru_probs = []
all_lstm_probs = []
sample_weights = []
# Get default smoothing parameters
default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']
# Generate augmented samples based on enabled augmentations
augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']
for aug_type in augmentation_types:
if tta_weights[aug_type] <= 0:
continue # Skip disabled augmentations
x_augmented = x.clone()
if aug_type == 'original':
# Original data (baseline)
pass
elif aug_type == 'noise':
# Add Gaussian noise with varying intensity
noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
noise = torch.randn_like(x_augmented) * noise_scale
x_augmented = x_augmented + noise
elif aug_type == 'scale':
# Amplitude scaling with more variation
scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
x_augmented = x_augmented * scale_factor
elif aug_type == 'shift' and tta_cut_max > 0:
# Time shift (circular shift); keep the upper bound >= 2 so randint stays valid for short trials
shift_amount = np.random.randint(1, max(2, min(tta_cut_max + 1, x_augmented.shape[1] // 8)))
x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
x_augmented[:, :shift_amount, :]], dim=1)
elif aug_type == 'smooth':
# Apply smoothing variation after augmentation
smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)
# Use autocast for efficiency - auto-detect device type
device_type = "cuda" if device.type == "cuda" else "cpu"
use_amp = gru_model_args.get('use_amp', False) and device.type == "cuda"
with torch.autocast(device_type=device_type, enabled=use_amp, dtype=torch.bfloat16 if device.type == "cuda" else torch.float32):
# Apply Gaussian smoothing
if aug_type == 'smooth':
# Use varied smoothing
x_smoothed = gauss_smooth(
inputs=x_augmented,
device=device,
smooth_kernel_std=varied_smooth_std,
smooth_kernel_size=default_smooth_size,
padding='valid',
)
else:
# Use default smoothing
x_smoothed = gauss_smooth(
inputs=x_augmented,
device=device,
smooth_kernel_std=default_smooth_std,
smooth_kernel_size=default_smooth_size,
padding='valid',
)
with torch.no_grad():
# Get GRU logits and convert to probabilities
gru_logits, _ = gru_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
gru_probs = torch.softmax(gru_logits, dim=-1)
# Get LSTM logits and convert to probabilities
lstm_logits, _ = lstm_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
lstm_probs = torch.softmax(lstm_logits, dim=-1)
all_gru_probs.append(gru_probs)
all_lstm_probs.append(lstm_probs)
sample_weights.append(tta_weights[aug_type])
if len(all_gru_probs) == 0:
raise ValueError("No TTA augmentations are enabled. Please set at least one weight > 0.")
# TTA fusion: Handle potentially different tensor shapes
if len(all_gru_probs) > 1:
# Find the minimum sequence length among all TTA samples
min_length = min([probs.shape[1] for probs in all_gru_probs + all_lstm_probs])
# Truncate all tensors to the minimum length
truncated_gru_probs = []
truncated_lstm_probs = []
for gru_probs, lstm_probs in zip(all_gru_probs, all_lstm_probs):
if gru_probs.shape[1] > min_length:
truncated_gru_probs.append(gru_probs[:, :min_length, :])
else:
truncated_gru_probs.append(gru_probs)
if lstm_probs.shape[1] > min_length:
truncated_lstm_probs.append(lstm_probs[:, :min_length, :])
else:
truncated_lstm_probs.append(lstm_probs)
# Weighted average probabilities across TTA samples
sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32, device=device)
sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum() # Normalize weights
weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0])
weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0])
for i, (gru_probs, lstm_probs, weight) in enumerate(zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor)):
weighted_gru_probs += weight * gru_probs
weighted_lstm_probs += weight * lstm_probs
avg_gru_probs = weighted_gru_probs
avg_lstm_probs = weighted_lstm_probs
else:
avg_gru_probs = all_gru_probs[0]
avg_lstm_probs = all_lstm_probs[0]
# Improved ensemble: geometric mean of probabilities (more robust than weighted average)
# Apply epsilon smoothing to avoid log(0)
epsilon = 1e-8
avg_gru_probs = avg_gru_probs + epsilon
avg_lstm_probs = avg_lstm_probs + epsilon
# Weighted geometric mean in log space
log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +
lstm_weight * torch.log(avg_lstm_probs))
ensemble_probs = torch.exp(log_ensemble_probs)
# Normalize to ensure proper probability distribution
ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)
# Convert back to logits for compatibility with downstream code
final_logits = torch.log(ensemble_probs + epsilon)
# Convert from bfloat16 to float32
return final_logits.float().cpu().numpy()
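# Worked example of the weighted geometric-mean fusion (illustrative, single timestep, two classes,
# gru_weight = lstm_weight = 0.5): if the GRU assigns [0.8, 0.2] and the LSTM assigns [0.6, 0.4],
# the unnormalized geometric means are [sqrt(0.8*0.6), sqrt(0.2*0.4)] ~= [0.693, 0.283]; after
# renormalization the ensemble distribution is ~[0.710, 0.290], slightly sharper than the
# arithmetic average [0.7, 0.3], since the geometric mean penalizes classes that either model
# rates as unlikely.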
# Load data for each session with memory management
test_data = {}
total_test_trials = 0
print("Loading data with memory management...")
for session in gru_model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
try:
# Try to load data normally first
print(f'Loading {session}...')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Successfully loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
except MemoryError as e:  # numpy's _ArrayMemoryError is a subclass of MemoryError
print(f'Memory error loading {session}: {e}')
print(f'Attempting to load {session} in smaller chunks...')
# If memory error, try loading in chunks or skip this session
import h5py
try:
with h5py.File(eval_file, 'r') as h5file:
# Get the size first
input_features_shape = h5file['input_features'].shape
print(f'Session {session} has {input_features_shape[0]} trials with {input_features_shape[1]} features')
# If it's too large, skip or implement chunk loading
max_trials_per_session = args.max_trials_per_session # Limit trials per session
if input_features_shape[0] > max_trials_per_session:
print(f'Session {session} has {input_features_shape[0]} trials, limiting to {max_trials_per_session} for memory management')
# Could implement chunked loading here if needed
print(f'Skipping session {session} due to memory constraints')
continue
else:
# Try one more time with garbage collection
import gc
gc.collect()
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session} after GC.')
except Exception as e2:
print(f'Failed to load session {session}: {e2}')
print(f'Skipping session {session}')
continue
print(f'Total number of {eval_type} trials loaded: {total_test_trials}')
if total_test_trials == 0:
print("ERROR: No trials loaded! Check data paths and memory availability.")
sys.exit(1)
print()
# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
import gc
with tqdm(total=total_test_trials, desc=f'Customizable TTA-E inference ({effective_tta_samples} augmentation types)', unit='trial') as pbar:
for session, data in test_data.items():
data['logits'] = []
data['pred_seq'] = []
input_layer = gru_model_args['dataset']['sessions'].index(session)
# Process trials in batches if requested
if args.batch_inference:
batch_size = args.batch_size
num_trials = len(data['neural_features'])
for batch_start in range(0, num_trials, batch_size):
batch_end = min(batch_start + batch_size, num_trials)
batch_logits = []
for trial in range(batch_start, batch_end):
# Get neural input for the trial
neural_input = data['neural_features'][trial]
# Add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# Convert to torch tensor - use float32 to avoid dtype issues with smoothing
neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)
# Run TTA-E decoding step with customizable weights
ensemble_logits = runTTAEnsembleDecodingStep(
neural_input, input_layer, gru_model, lstm_model,
gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
)
batch_logits.append(ensemble_logits)
# Clear GPU memory
del neural_input
if device.type == 'cuda':
torch.cuda.empty_cache()
pbar.update(1)
# Add batch results
data['logits'].extend(batch_logits)
# Clear batch memory
del batch_logits
gc.collect()
else:
# Process all trials individually (original method)
for trial in range(len(data['neural_features'])):
# Get neural input for the trial
neural_input = data['neural_features'][trial]
# Add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# Convert to torch tensor - use float32 to avoid dtype issues with smoothing
neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)
# Run TTA-E decoding step with customizable weights
ensemble_logits = runTTAEnsembleDecodingStep(
neural_input, input_layer, gru_model, lstm_model,
gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
)
data['logits'].append(ensemble_logits)
# Clear memory periodically
if trial % 50 == 0:
del neural_input
if device.type == 'cuda':
torch.cuda.empty_cache()
gc.collect()
pbar.update(1)
# Clear session data memory after processing
gc.collect()
pbar.close()
# Convert logits to phoneme sequences and print them out
results = {
'session': [],
'block': [],
'trial': [],
'true_sentence': [],
'pred_phonemes': [],
}
for session, data in test_data.items():
data['pred_seq'] = []
for trial in range(len(data['logits'])):
logits = data['logits'][trial][0]
pred_seq = np.argmax(logits, axis=-1)
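# Illustrative greedy CTC decode: a frame-wise argmax of [0, 5, 5, 0, 12, 12, 0, 12] collapses to
# [0, 5, 0, 12, 0, 12] and, after removing blanks (0), becomes [5, 12, 12]; the repeated class 12
# survives because its two occurrences are separated by a blank frame.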
# Greedy CTC decoding: collapse consecutive duplicates first, then remove blanks (0),
# so that repeated phonemes separated by a blank are preserved
pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
pred_seq = [p for p in pred_seq if p != 0]
# Convert to phonemes
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
# Add to data
data['pred_seq'].append(pred_seq)
# Store results
results['session'].append(session)
results['block'].append(data['block_num'][trial])
results['trial'].append(data['trial_num'][trial])
if eval_type == 'val':
results['true_sentence'].append(data['sentence_label'][trial])
else:
results['true_sentence'].append(None)
results['pred_phonemes'].append(pred_seq)
# Print out the predicted sequences
block_num = data['block_num'][trial]
trial_num = data['trial_num'][trial]
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
if eval_type == 'val':
sentence_label = data['sentence_label'][trial]
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
print(f'Sentence label: {sentence_label}')
print(f'True sequence: {" ".join(true_seq)}')
print(f'Predicted Sequence: {" ".join(pred_seq)}')
print()
# If using the validation set, calculate comprehensive evaluation metrics
if eval_type == 'val':
print("=" * 80)
print("COMPREHENSIVE EVALUATION METRICS")
print("=" * 80)
# Record performance timing
eval_start_time = time.time()
# Initialize metrics storage
total_true_length = 0
total_edit_distance = 0
all_true_seqs = []
all_pred_seqs = []
per_trial_errors = []
per_session_metrics = defaultdict(list)
# Additional detailed metrics
results['edit_distance'] = []
results['num_phonemes'] = []
results['insertions'] = []
results['deletions'] = []
results['substitutions'] = []
results['per_trial_per'] = []
print(f"\nCalculating detailed metrics for {len(results['pred_phonemes'])} trials...")
for i in range(len(results['pred_phonemes'])):
# Get true phoneme sequence
session = results['session'][i]
trial_idx = None
for trial in range(len(test_data[session]['seq_class_ids'])):
if (test_data[session]['block_num'][trial] == results['block'][i] and
test_data[session]['trial_num'][trial] == results['trial'][i]):
trial_idx = trial
break
if trial_idx is not None:
true_seq = test_data[session]['seq_class_ids'][trial_idx][0:test_data[session]['seq_len'][trial_idx]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
pred_seq = results['pred_phonemes'][i]
all_true_seqs.append(true_seq)
all_pred_seqs.append(pred_seq)
# Calculate basic edit distance
ed = editdistance.eval(true_seq, pred_seq)
trial_per = 100 * ed / len(true_seq) if len(true_seq) > 0 else 0
# Calculate detailed phoneme operations
insertions, deletions, substitutions = calculate_phoneme_operations(true_seq, pred_seq)
# Store metrics
total_true_length += len(true_seq)
total_edit_distance += ed
per_trial_errors.append(trial_per)
per_session_metrics[session].append(trial_per)
results['edit_distance'].append(ed)
results['num_phonemes'].append(len(true_seq))
results['insertions'].append(insertions)
results['deletions'].append(deletions)
results['substitutions'].append(substitutions)
results['per_trial_per'].append(trial_per)
# Print detailed trial information
print(f'{results["session"][i]} - Block {results["block"][i]}, Trial {results["trial"][i]}')
print(f'True phonemes ({len(true_seq)}): {" ".join(true_seq)}')
print(f'Predicted phonemes ({len(pred_seq)}): {" ".join(pred_seq)}')
print(f'Operations: I={insertions}, D={deletions}, S={substitutions}, Total ED={ed}')
print(f'Trial PER: {ed} / {len(true_seq)} = {trial_per:.2f}%')
print()
# ========== BASIC METRICS ==========
aggregate_per = 100 * total_edit_distance / total_true_length if total_true_length > 0 else 0
print("=" * 50)
print("BASIC PHONEME ERROR RATE METRICS")
print("=" * 50)
print(f'Total true phoneme length: {total_true_length}')
print(f'Total edit distance: {total_edit_distance}')
print(f'Aggregate Phoneme Error Rate (PER): {aggregate_per:.2f}%')
# Calculate PER confidence intervals
per_mean, per_ci_lower, per_ci_upper = calculate_confidence_intervals(per_trial_errors)
print(f'Per-trial PER: {per_mean:.2f}% (95% CI: [{per_ci_lower:.2f}%, {per_ci_upper:.2f}%])')
# ========== PHONEME OPERATION ANALYSIS ==========
total_insertions = sum(results['insertions'])
total_deletions = sum(results['deletions'])
total_substitutions = sum(results['substitutions'])
print("\n" + "=" * 50)
print("PHONEME OPERATION BREAKDOWN")
print("=" * 50)
error_denom = total_edit_distance if total_edit_distance > 0 else 1  # avoid division by zero on a perfect run
print(f'Total insertions: {total_insertions} ({100*total_insertions/error_denom:.1f}% of errors)')
print(f'Total deletions: {total_deletions} ({100*total_deletions/error_denom:.1f}% of errors)')
print(f'Total substitutions: {total_substitutions} ({100*total_substitutions/error_denom:.1f}% of errors)')
print(f'Verification: I+D+S = {total_insertions + total_deletions + total_substitutions} (should equal {total_edit_distance})')
# ========== PHONEME-LEVEL CLASSIFICATION METRICS ==========
print("\n" + "=" * 50)
print("PHONEME-LEVEL CLASSIFICATION METRICS")
print("=" * 50)
phoneme_metrics = calculate_phoneme_level_metrics(all_true_seqs, all_pred_seqs)
print(f"Macro-averaged metrics:")
print(f" Precision: {phoneme_metrics['macro_avg']['precision']:.3f}")
print(f" Recall: {phoneme_metrics['macro_avg']['recall']:.3f}")
print(f" F1-Score: {phoneme_metrics['macro_avg']['f1']:.3f}")
print(f"\nMicro-averaged metrics:")
print(f" Precision: {phoneme_metrics['micro_avg']['precision']:.3f}")
print(f" Recall: {phoneme_metrics['micro_avg']['recall']:.3f}")
print(f" F1-Score: {phoneme_metrics['micro_avg']['f1']:.3f}")
# ========== SEQUENCE-LEVEL METRICS ==========
print("\n" + "=" * 50)
print("SEQUENCE-LEVEL METRICS")
print("=" * 50)
seq_metrics = calculate_sequence_level_metrics(all_true_seqs, all_pred_seqs)
print(f"Exact Match Accuracy: {seq_metrics['exact_match_accuracy']:.3f} ({seq_metrics['exact_matches']}/{seq_metrics['total_sequences']})")
# ========== PER-SESSION BREAKDOWN ==========
print("\n" + "=" * 50)
print("PER-SESSION PERFORMANCE BREAKDOWN")
print("=" * 50)
for session in sorted(per_session_metrics.keys()):
session_errors = per_session_metrics[session]
session_mean, session_ci_lower, session_ci_upper = calculate_confidence_intervals(session_errors)
print(f"{session}: {session_mean:.2f}% ± {session_mean - session_ci_lower:.2f}% "
f"(n={len(session_errors)}, range: {min(session_errors):.1f}%-{max(session_errors):.1f}%)")
# ========== STATISTICAL SUMMARY ==========
print("\n" + "=" * 50)
print("STATISTICAL SUMMARY")
print("=" * 50)
print(f"Mean trial PER: {np.mean(per_trial_errors):.2f}%")
print(f"Median trial PER: {np.median(per_trial_errors):.2f}%")
print(f"Std deviation: {np.std(per_trial_errors):.2f}%")
print(f"Min/Max trial PER: {min(per_trial_errors):.1f}% / {max(per_trial_errors):.1f}%")
print(f"Trials with 0% PER: {sum(1 for x in per_trial_errors if x == 0)} ({100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors):.1f}%)")
# ========== PHONEME FREQUENCY ANALYSIS ==========
if args.detailed_metrics:
print("\n" + "=" * 50)
print("PHONEME FREQUENCY AND ERROR PATTERN ANALYSIS")
print("=" * 50)
freq_analysis = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)
print("Top 10 most frequent true phonemes:")
for phoneme, count in freq_analysis['true_phoneme_counts'].most_common(10):
print(f" {phoneme}: {count} ({100*count/sum(freq_analysis['true_phoneme_counts'].values()):.1f}%)")
print("\nTop 10 most frequent predicted phonemes:")
for phoneme, count in freq_analysis['pred_phoneme_counts'].most_common(10):
print(f" {phoneme}: {count} ({100*count/sum(freq_analysis['pred_phoneme_counts'].values()):.1f}%)")
print("\nTop 10 most frequent error patterns (true -> predicted):")
sorted_errors = sorted(freq_analysis['error_patterns'].items(), key=lambda x: x[1], reverse=True)
for (true_p, pred_p), count in sorted_errors[:10]:
print(f" {true_p} -> {pred_p}: {count} errors")
# ========== GENERATE CONFUSION MATRIX ==========
if not args.skip_confusion_matrix and args.save_plots:
timestamp_for_plots = time.strftime("%Y%m%d_%H%M%S")
confusion_matrix_path = f'confusion_matrix_TTA-E_{eval_type}_{timestamp_for_plots}.png'
confusion_matrix_path = os.path.join(os.path.dirname(__file__), confusion_matrix_path)
print(f"\nGenerating confusion matrix...")
cm_result = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, confusion_matrix_path)
if cm_result[0] is None:
confusion_matrix_path = None
else:
print(f"\nSkipping confusion matrix generation (--skip_confusion_matrix or --no-save_plots specified)")
confusion_matrix_path = None
# ========== TIMING METRICS ==========
eval_end_time = time.time()
eval_duration = eval_end_time - eval_start_time
print("\n" + "=" * 50)
print("TIMING METRICS")
print("=" * 50)
print(f"Evaluation time: {eval_duration:.2f} seconds")
print(f"Time per trial: {eval_duration/len(results['pred_phonemes']):.3f} seconds")
print("\n" + "=" * 80)
print("EVALUATION SUMMARY COMPLETE")
print("=" * 80)
# Write comprehensive results to CSV files
timestamp = time.strftime("%Y%m%d_%H%M%S")
enabled_augs = '_'.join([k for k, v in tta_weights.items() if v > 0])
# Basic phoneme predictions CSV (for compatibility)
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{enabled_augs}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)
ids = [i for i in range(len(results['pred_phonemes']))]
phoneme_strings = [" ".join(phonemes) for phonemes in results['pred_phonemes']]
df_out = pd.DataFrame({'id': ids, 'phonemes': phoneme_strings})
df_out.to_csv(output_path, index=False)
# Comprehensive metrics CSV (if validation set)
if eval_type == 'val':
detailed_output_file = f'TTA-E_detailed_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
detailed_output_path = os.path.join(os.path.dirname(__file__), detailed_output_file)
# Create detailed DataFrame with all metrics
detailed_df = pd.DataFrame({
'id': ids,
'session': results['session'],
'block': results['block'],
'trial': results['trial'],
'true_sentence': results['true_sentence'],
'pred_phonemes': phoneme_strings,
'num_phonemes': results['num_phonemes'],
'edit_distance': results['edit_distance'],
'insertions': results['insertions'],
'deletions': results['deletions'],
'substitutions': results['substitutions'],
'trial_per': results['per_trial_per']
})
detailed_df.to_csv(detailed_output_path, index=False)
# Summary metrics CSV
summary_output_file = f'TTA-E_summary_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
summary_output_path = os.path.join(os.path.dirname(__file__), summary_output_file)
summary_data = {
'metric': [
'aggregate_per', 'mean_trial_per', 'median_trial_per', 'std_trial_per',
'min_trial_per', 'max_trial_per', 'zero_per_trials_count', 'zero_per_trials_percent',
'total_phonemes', 'total_edit_distance', 'total_insertions', 'total_deletions',
'total_substitutions', 'exact_match_accuracy', 'exact_matches',
'macro_precision', 'macro_recall', 'macro_f1',
'micro_precision', 'micro_recall', 'micro_f1',
'per_mean_95ci_lower', 'per_mean_95ci_upper', 'evaluation_time_seconds'
],
'value': [
aggregate_per, np.mean(per_trial_errors), np.median(per_trial_errors), np.std(per_trial_errors),
min(per_trial_errors), max(per_trial_errors),
sum(1 for x in per_trial_errors if x == 0),
100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors),
total_true_length, total_edit_distance, total_insertions, total_deletions,
total_substitutions, seq_metrics['exact_match_accuracy'], seq_metrics['exact_matches'],
phoneme_metrics['macro_avg']['precision'], phoneme_metrics['macro_avg']['recall'],
phoneme_metrics['macro_avg']['f1'], phoneme_metrics['micro_avg']['precision'],
phoneme_metrics['micro_avg']['recall'], phoneme_metrics['micro_avg']['f1'],
per_ci_lower, per_ci_upper, eval_duration
]
}
summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(summary_output_path, index=False)
print(f'\nComprehensive results saved to:')
print(f' Basic predictions: {output_path}')
print(f' Detailed metrics: {detailed_output_path}')
print(f' Summary metrics: {summary_output_path}')
if confusion_matrix_path:
print(f' Confusion matrix: {confusion_matrix_path}')
else:
print(f' Confusion matrix: Not generated')
else:
print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
print(f'Enabled augmentations: {", ".join(enabled_augs.split("_"))}')