import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *


# ========== COMPREHENSIVE EVALUATION METRICS FUNCTIONS ==========

def calculate_phoneme_operations(true_seq, pred_seq):
    """Calculate insertions, deletions, and substitutions using dynamic programming."""
    m, n = len(true_seq), len(pred_seq)

    # DP matrix for edit distance calculation
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Initialize base cases
    for i in range(m + 1):
        dp[i][0] = i  # deletions
    for j in range(n + 1):
        dp[0][j] = j  # insertions

    # Fill DP matrix
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if true_seq[i-1] == pred_seq[j-1]:
                dp[i][j] = dp[i-1][j-1]  # match
            else:
                dp[i][j] = 1 + min(
                    dp[i-1][j],    # deletion
                    dp[i][j-1],    # insertion
                    dp[i-1][j-1]   # substitution
                )

    # Backtrack to count operations
    i, j = m, n
    insertions = deletions = substitutions = 0

    while i > 0 or j > 0:
        if i > 0 and j > 0 and true_seq[i-1] == pred_seq[j-1]:
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
            substitutions += 1
            i -= 1
            j -= 1
        elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
            deletions += 1
            i -= 1
        else:
            insertions += 1
            j -= 1

    return insertions, deletions, substitutions
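
# Sanity check (illustrative): a prediction with one extra phoneme counts as a
# single insertion, e.g.
#   calculate_phoneme_operations(['AE', 'T'], ['AE', 'K', 'T']) -> (1, 0, 0)
# and the three counts always sum to editdistance.eval(true_seq, pred_seq).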


def calculate_phoneme_level_metrics(true_seqs, pred_seqs):
    """Calculate phoneme-level precision, recall, and F1-score using position-wise alignment."""
    all_true_phonemes = []
    all_pred_phonemes = []

    # Align sequences position by position. This is an approximation: sequences of
    # different lengths are truncated to the shorter length rather than aligned
    # with a full edit-distance alignment.
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            all_true_phonemes.extend(true_seq[:min_len])
            all_pred_phonemes.extend(pred_seq[:min_len])

    # Defensive check; the truncation above should guarantee equal lengths
    if len(all_true_phonemes) != len(all_pred_phonemes):
        print(f"Warning: After alignment, lengths still don't match: {len(all_true_phonemes)} vs {len(all_pred_phonemes)}")
        # Take minimum length to avoid an sklearn error
        min_len = min(len(all_true_phonemes), len(all_pred_phonemes))
        all_true_phonemes = all_true_phonemes[:min_len]
        all_pred_phonemes = all_pred_phonemes[:min_len]

    if len(all_true_phonemes) == 0:
        # Handle empty case
        return {
            'phoneme_metrics': {
                'phonemes': [],
                'precision': np.array([]),
                'recall': np.array([]),
                'f1': np.array([]),
                'support': np.array([])
            },
            'macro_avg': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0},
            'micro_avg': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
        }

    # Get unique phonemes
    unique_phonemes = sorted(set(all_true_phonemes + all_pred_phonemes))

    # Calculate metrics for each phoneme class
    precision, recall, f1, support = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes,
        labels=unique_phonemes, average=None, zero_division=0
    )

    # Calculate macro averages
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    # Calculate micro averages
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes, average='micro', zero_division=0
    )

    return {
        'phoneme_metrics': {
            'phonemes': unique_phonemes,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support
        },
        'macro_avg': {'precision': macro_precision, 'recall': macro_recall, 'f1': macro_f1},
        'micro_avg': {'precision': micro_precision, 'recall': micro_recall, 'f1': micro_f1}
    }
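
# Note: with one predicted label per aligned position (single-label multiclass),
# the micro-averaged precision, recall, and F1 above are all equal to plain
# position-wise accuracy, while macro averaging weights rare phonemes equally.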


def calculate_sequence_level_metrics(true_seqs, pred_seqs):
    """Calculate sequence-level metrics."""
    exact_matches = 0
    total_sequences = len(true_seqs)

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        if true_seq == pred_seq:
            exact_matches += 1

    exact_match_accuracy = exact_matches / total_sequences if total_sequences > 0 else 0

    return {
        'exact_match_accuracy': exact_match_accuracy,
        'exact_matches': exact_matches,
        'total_sequences': total_sequences
    }


def calculate_confidence_intervals(errors, confidence=0.95):
    """Calculate confidence intervals for error rates."""
    n = len(errors)
    if n == 0:
        return 0, 0, 0

    mean_error = np.mean(errors)
    std_error = np.std(errors, ddof=1) if n > 1 else 0

    # Use the t-distribution for small samples and the normal distribution otherwise
    if n < 30:
        from scipy import stats
        t_val = stats.t.ppf((1 + confidence) / 2, n - 1)
        margin = t_val * std_error / np.sqrt(n)
    else:
        # For large samples, use normal critical values
        # (only 95% and 99% levels are supported here)
        z_val = 1.96 if confidence == 0.95 else 2.576
        margin = z_val * std_error / np.sqrt(n)

    return mean_error, mean_error - margin, mean_error + margin
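
# The interval above is the standard mean ± t * s / sqrt(n) construction, e.g.
# for per-trial PERs [10.0, 20.0, 30.0] it returns a mean of 20.0 with a
# symmetric 95% margin of t_(0.975, df=2) * 10 / sqrt(3) ≈ 24.8.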


def calculate_phoneme_frequency_analysis(true_seqs, pred_seqs):
    """Analyze phoneme frequency and error patterns."""
    true_phoneme_counts = Counter()
    pred_phoneme_counts = Counter()
    error_patterns = defaultdict(int)  # Maps (true_phoneme, pred_phoneme) -> count

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        true_phoneme_counts.update(true_seq)
        pred_phoneme_counts.update(pred_seq)

        # For error pattern analysis, align sequences position-wise up to the minimum length
        min_len = min(len(true_seq), len(pred_seq))
        for i in range(min_len):
            true_p = true_seq[i]
            pred_p = pred_seq[i]
            if true_p != pred_p:
                error_patterns[(true_p, pred_p)] += 1

    return {
        'true_phoneme_counts': true_phoneme_counts,
        'pred_phoneme_counts': pred_phoneme_counts,
        'error_patterns': dict(error_patterns)
    }


def generate_confusion_matrix_plot(true_seqs, pred_seqs, save_path=None):
    """Generate and save a confusion matrix plot for phonemes."""
    all_true = []
    all_pred = []

    # Align sequences position-wise, truncating both to the shorter length
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            all_true.extend(true_seq[:min_len])
            all_pred.extend(pred_seq[:min_len])

    if len(all_true) == 0 or len(all_pred) == 0:
        print("Warning: No aligned phonemes found for confusion matrix")
        return None, []

    unique_phonemes = sorted(set(all_true + all_pred))
    cm = confusion_matrix(all_true, all_pred, labels=unique_phonemes)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=unique_phonemes, yticklabels=unique_phonemes)
    plt.title('Phoneme Confusion Matrix')
    plt.xlabel('Predicted Phonemes')
    plt.ylabel('True Phonemes')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")

    plt.close()

    return cm, unique_phonemes


# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../data/t15_pretrained_lstm_baseline',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=1,
                    help='Weight for the GRU model in the ensemble (LSTM weight = 1 - gru_weight).')

# TTA parameters
parser.add_argument('--tta_samples', type=int, default=4,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to cut from the beginning in TTA.')

# TTA sample weight configuration
parser.add_argument('--tta_weights', type=str, default='2.0,1.0,0.0,0.0,0.0',
                    help='Comma-separated weights for each TTA augmentation type: original,noise,scale,shift,smooth. Set a weight to 0 to disable that augmentation.')

# Evaluation metrics configuration
parser.add_argument('--skip_confusion_matrix', action='store_true',
                    help='Skip confusion matrix generation to save time.')
parser.add_argument('--detailed_metrics', action='store_true', default=True,
                    help='Calculate detailed phoneme-level and sequence-level metrics (enabled by default).')
parser.add_argument('--save_plots', action='store_true', default=True,
                    help='Save confusion matrix and other plots (enabled by default).')

# Memory management options
parser.add_argument('--max_trials_per_session', type=int, default=1000,
                    help='Maximum number of trials to load per session (for memory management).')
parser.add_argument('--batch_inference', action='store_true',
                    help='Process trials in batches to reduce memory usage during inference.')
parser.add_argument('--batch_size', type=int, default=10,
                    help='Batch size for inference when using --batch_inference.')

args = parser.parse_args()
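
# Example invocation (illustrative; the script name stands in for whatever this
# file is saved as):
#   python tta_e_eval.py --eval_type val --gru_weight 0.6 \
#       --tta_weights '1.0,1.0,0.0,0.0,1.0' --gpu_number 0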

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

# Parse TTA weights
tta_weights_str = args.tta_weights.split(',')
if len(tta_weights_str) != 5:
    raise ValueError("TTA weights must have exactly 5 values: original,noise,scale,shift,smooth")
tta_weights = {
    'original': float(tta_weights_str[0]),
    'noise': float(tta_weights_str[1]),
    'scale': float(tta_weights_str[2]),
    'shift': float(tta_weights_str[3]),
    'smooth': float(tta_weights_str[4])
}

# Calculate the effective number of TTA samples from the enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
effective_tta_samples = len(enabled_augmentations)

print("TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f}")
print(f"LSTM weight: {lstm_weight:.2f}")
print("TTA augmentation weights:")
for aug_type, weight in tta_weights.items():
    status = "✓ enabled" if weight > 0 else "✗ disabled"
    print(f"  - {aug_type}: {weight:.1f} ({status})")
print(f"Effective TTA samples: {effective_tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max cut: {tta_cut_max} timesteps")
print("Key features:")
print("  - Customizable TTA sample weights")
print("  - Probability-based ensemble instead of logit averaging")
print("  - Geometric mean for more robust fusion")
print("  - Direct PER evaluation without language model dependency")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()

# Define evaluation type
eval_type = args.eval_type

# Load the metadata CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define the GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from
# state dict keys; each key is popped once and re-inserted under its clean name
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
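
# Example (illustrative): a key saved as "module._orig_mod.gru.weight_ih_l0"
# (DataParallel plus torch.compile wrappers) is restored as "gru.weight_ih_l0".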

# Define the LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." and "_orig_mod." prefixes, as above
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Add models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print("Both models loaded successfully!")
print()

# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_weights, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run customizable TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply the selected data augmentations based on their weights
    2. Run both the GRU and LSTM models on each augmented version
    3. Combine TTA samples with a weighted average in probability space
    4. Fuse the two models with a weighted geometric mean of probabilities
    """
    all_gru_probs = []
    all_lstm_probs = []
    sample_weights = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    # Generate augmented samples based on enabled augmentations
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    for aug_type in augmentation_types:
        if tta_weights[aug_type] <= 0:
            continue  # Skip disabled augmentations

        x_augmented = x.clone()

        if aug_type == 'original':
            # Original data (baseline)
            pass
        elif aug_type == 'noise':
            # Add Gaussian noise with randomly varying intensity
            noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
            noise = torch.randn_like(x_augmented) * noise_scale
            x_augmented = x_augmented + noise
        elif aug_type == 'scale':
            # Amplitude scaling with a random factor in [1 - range, 1 + range]
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and tta_cut_max > 0:
            # Time shift (circular shift along the time axis)
            shift_amount = np.random.randint(1, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
        elif aug_type == 'smooth':
            # Vary the smoothing kernel std; the smoothing itself is applied below
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)

        # Use autocast for efficiency - auto-detect device type
        device_type = "cuda" if device.type == "cuda" else "cpu"
        use_amp = gru_model_args.get('use_amp', False) and device.type == "cuda"
        with torch.autocast(device_type=device_type, enabled=use_amp,
                            dtype=torch.bfloat16 if device.type == "cuda" else torch.float32):

            # Apply Gaussian smoothing
            if aug_type == 'smooth':
                # Use the varied smoothing kernel
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=varied_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )
            else:
                # Use the default smoothing kernel
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=default_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )

            with torch.no_grad():
                # Get GRU logits and convert to probabilities
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                gru_probs = torch.softmax(gru_logits, dim=-1)

                # Get LSTM logits and convert to probabilities
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                lstm_probs = torch.softmax(lstm_logits, dim=-1)

        all_gru_probs.append(gru_probs)
        all_lstm_probs.append(lstm_probs)
        sample_weights.append(tta_weights[aug_type])

    if len(all_gru_probs) == 0:
        raise ValueError("No TTA augmentations are enabled. Please set at least one weight > 0.")

    # TTA fusion: handle potentially different sequence lengths across augmentations
    if len(all_gru_probs) > 1:
        # Find the minimum sequence length among all TTA samples and truncate to it
        min_length = min(probs.shape[1] for probs in all_gru_probs + all_lstm_probs)
        truncated_gru_probs = [probs[:, :min_length, :] for probs in all_gru_probs]
        truncated_lstm_probs = [probs[:, :min_length, :] for probs in all_lstm_probs]

        # Weighted average of probabilities across TTA samples
        sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32, device=device)
        sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum()  # Normalize weights

        weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0])
        weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0])

        for gru_probs, lstm_probs, weight in zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor):
            weighted_gru_probs += weight * gru_probs
            weighted_lstm_probs += weight * lstm_probs

        avg_gru_probs = weighted_gru_probs
        avg_lstm_probs = weighted_lstm_probs
    else:
        avg_gru_probs = all_gru_probs[0]
        avg_lstm_probs = all_lstm_probs[0]

    # Ensemble: weighted geometric mean of probabilities (more robust than a
    # weighted arithmetic average); apply epsilon smoothing to avoid log(0)
    epsilon = 1e-8
    avg_gru_probs = avg_gru_probs + epsilon
    avg_lstm_probs = avg_lstm_probs + epsilon

    # Weighted geometric mean in log space
    log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +
                          lstm_weight * torch.log(avg_lstm_probs))
    ensemble_probs = torch.exp(log_ensemble_probs)

    # Normalize to ensure a proper probability distribution
    ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

    # Convert back to logits for compatibility with downstream code
    final_logits = torch.log(ensemble_probs + epsilon)

    # Cast to float32 (in case autocast produced bfloat16) and move to CPU
    return final_logits.float().cpu().numpy()
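
# Ensemble math (illustrative): with gru_weight=0.6 and lstm_weight=0.4, a class
# with p_gru=0.9 and p_lstm=0.5 gets exp(0.6*ln 0.9 + 0.4*ln 0.5) ≈ 0.71 before
# renormalization, so a low probability from either model pulls the fused
# probability down more sharply than the weighted arithmetic mean (0.74) would.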


# Load data for each session with memory management
test_data = {}
total_test_trials = 0

print("Loading data with memory management...")
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        try:
            # Try to load the session normally first
            print(f'Loading {session}...')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_test_trials += len(test_data[session]["neural_features"])
            print(f'Successfully loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
        except MemoryError as e:
            # numpy's _ArrayMemoryError subclasses MemoryError, so this catches both
            print(f'Memory error loading {session}: {e}')
            print(f'Attempting to load {session} in smaller chunks...')

            # On a memory error, inspect the file size and either retry or skip the session
            import h5py
            try:
                with h5py.File(eval_file, 'r') as h5file:
                    # Get the size first
                    input_features_shape = h5file['input_features'].shape
                    print(f'Session {session} has {input_features_shape[0]} trials with {input_features_shape[1]} features')

                # If the session is too large, skip it (chunked loading could be implemented here)
                max_trials_per_session = args.max_trials_per_session
                if input_features_shape[0] > max_trials_per_session:
                    print(f'Session {session} has {input_features_shape[0]} trials, limiting to {max_trials_per_session} for memory management')
                    print(f'Skipping session {session} due to memory constraints')
                    continue
                else:
                    # Try one more time after garbage collection
                    import gc
                    gc.collect()
                    data = load_h5py_file(eval_file, b2txt_csv_df)
                    test_data[session] = data
                    total_test_trials += len(test_data[session]["neural_features"])
                    print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session} after GC.')
            except Exception as e2:
                print(f'Failed to load session {session}: {e2}')
                print(f'Skipping session {session}')
                continue

print(f'Total number of {eval_type} trials loaded: {total_test_trials}')
if total_test_trials == 0:
    print("ERROR: No trials loaded! Check data paths and memory availability.")
    sys.exit(1)
print()

# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
import gc

with tqdm(total=total_test_trials, desc=f'Customizable TTA-E inference ({effective_tta_samples} augmentation types)', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        data['pred_seq'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)

        # Process trials in batches if requested
        if args.batch_inference:
            batch_size = args.batch_size
            num_trials = len(data['neural_features'])

            for batch_start in range(0, num_trials, batch_size):
                batch_end = min(batch_start + batch_size, num_trials)
                batch_logits = []

                for trial in range(batch_start, batch_end):
                    # Get neural input for the trial and add a batch dimension
                    neural_input = data['neural_features'][trial]
                    neural_input = np.expand_dims(neural_input, axis=0)

                    # Convert to a torch tensor - use float32 to avoid dtype issues with smoothing
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                    # Run the TTA-E decoding step with customizable weights
                    ensemble_logits = runTTAEnsembleDecodingStep(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                        tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                    )
                    batch_logits.append(ensemble_logits)

                    # Clear GPU memory
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

                    pbar.update(1)

                # Add batch results
                data['logits'].extend(batch_logits)

                # Clear batch memory
                del batch_logits
                gc.collect()
        else:
            # Process all trials individually (original method)
            for trial in range(len(data['neural_features'])):
                # Get neural input for the trial and add a batch dimension
                neural_input = data['neural_features'][trial]
                neural_input = np.expand_dims(neural_input, axis=0)

                # Convert to a torch tensor - use float32 to avoid dtype issues with smoothing
                neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                # Run the TTA-E decoding step with customizable weights
                ensemble_logits = runTTAEnsembleDecodingStep(
                    neural_input, input_layer, gru_model, lstm_model,
                    gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                    tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                )
                data['logits'].append(ensemble_logits)

                # Clear memory periodically
                if trial % 50 == 0:
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    gc.collect()

                pbar.update(1)

        # Clear session data memory after processing
        gc.collect()

# Convert logits to phoneme sequences and print them out
results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_phonemes': [],
}

for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
        # Greedy CTC decoding: collapse consecutive duplicates first, then remove
        # blanks (0), so repeated phonemes separated by a blank are preserved
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
        pred_seq = [p for p in pred_seq if p != 0]
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        # Add to data
        data['pred_seq'].append(pred_seq)

        # Store results
        results['session'].append(session)
        results['block'].append(data['block_num'][trial])
        results['trial'].append(data['trial_num'][trial])
        if eval_type == 'val':
            results['true_sentence'].append(data['sentence_label'][trial])
        else:
            results['true_sentence'].append(None)
        results['pred_phonemes'].append(pred_seq)

        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label: {sentence_label}')
            print(f'True sequence: {" ".join(true_seq)}')
        print(f'Predicted Sequence: {" ".join(pred_seq)}')
        print()
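
# Decoding example (illustrative): argmax frame labels [0, 5, 5, 0, 5, 7, 7]
# collapse to [0, 5, 0, 5, 7] and become [5, 5, 7] after blank removal, i.e. the
# repeated phoneme 5 survives because a blank separates its two occurrences.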

# If using the validation set, calculate comprehensive evaluation metrics
if eval_type == 'val':
    print("=" * 80)
    print("COMPREHENSIVE EVALUATION METRICS")
    print("=" * 80)

    # Record performance timing
    eval_start_time = time.time()

    # Initialize metrics storage
    total_true_length = 0
    total_edit_distance = 0
    all_true_seqs = []
    all_pred_seqs = []
    per_trial_errors = []
    per_session_metrics = defaultdict(list)

    # Additional detailed metrics
    results['edit_distance'] = []
    results['num_phonemes'] = []
    results['insertions'] = []
    results['deletions'] = []
    results['substitutions'] = []
    results['per_trial_per'] = []

    print(f"\nCalculating detailed metrics for {len(results['pred_phonemes'])} trials...")

    for i in range(len(results['pred_phonemes'])):
        # Look up the matching trial to get the true phoneme sequence
        session = results['session'][i]
        trial_idx = None
        for trial in range(len(test_data[session]['seq_class_ids'])):
            if (test_data[session]['block_num'][trial] == results['block'][i] and
                    test_data[session]['trial_num'][trial] == results['trial'][i]):
                trial_idx = trial
                break

        if trial_idx is not None:
            true_seq = test_data[session]['seq_class_ids'][trial_idx][0:test_data[session]['seq_len'][trial_idx]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            pred_seq = results['pred_phonemes'][i]

            all_true_seqs.append(true_seq)
            all_pred_seqs.append(pred_seq)

            # Calculate basic edit distance
            ed = editdistance.eval(true_seq, pred_seq)
            trial_per = 100 * ed / len(true_seq) if len(true_seq) > 0 else 0

            # Calculate detailed phoneme operations
            insertions, deletions, substitutions = calculate_phoneme_operations(true_seq, pred_seq)

            # Store metrics
            total_true_length += len(true_seq)
            total_edit_distance += ed
            per_trial_errors.append(trial_per)
            per_session_metrics[session].append(trial_per)

            results['edit_distance'].append(ed)
            results['num_phonemes'].append(len(true_seq))
            results['insertions'].append(insertions)
            results['deletions'].append(deletions)
            results['substitutions'].append(substitutions)
            results['per_trial_per'].append(trial_per)

            # Print detailed trial information
            print(f'{results["session"][i]} - Block {results["block"][i]}, Trial {results["trial"][i]}')
            print(f'True phonemes ({len(true_seq)}): {" ".join(true_seq)}')
            print(f'Predicted phonemes ({len(pred_seq)}): {" ".join(pred_seq)}')
            print(f'Operations: I={insertions}, D={deletions}, S={substitutions}, Total ED={ed}')
            print(f'Trial PER: {ed} / {len(true_seq)} = {trial_per:.2f}%')
            print()

    # ========== BASIC METRICS ==========
    aggregate_per = 100 * total_edit_distance / total_true_length if total_true_length > 0 else 0

    print("=" * 50)
    print("BASIC PHONEME ERROR RATE METRICS")
    print("=" * 50)
    print(f'Total true phoneme length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Phoneme Error Rate (PER): {aggregate_per:.2f}%')

    # Calculate PER confidence intervals
    per_mean, per_ci_lower, per_ci_upper = calculate_confidence_intervals(per_trial_errors)
    print(f'Per-trial PER: {per_mean:.2f}% (95% CI: [{per_ci_lower:.2f}%, {per_ci_upper:.2f}%])')
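
    # Note: the aggregate PER weights trials by their number of phonemes
    # (sum of edit distances / sum of lengths), while the per-trial mean above
    # weights every trial equally, so the two can differ when lengths vary.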

    # ========== PHONEME OPERATION ANALYSIS ==========
    total_insertions = sum(results['insertions'])
    total_deletions = sum(results['deletions'])
    total_substitutions = sum(results['substitutions'])
    error_denom = total_edit_distance if total_edit_distance > 0 else 1  # guard against division by zero

    print("\n" + "=" * 50)
    print("PHONEME OPERATION BREAKDOWN")
    print("=" * 50)
    print(f'Total insertions: {total_insertions} ({100*total_insertions/error_denom:.1f}% of errors)')
    print(f'Total deletions: {total_deletions} ({100*total_deletions/error_denom:.1f}% of errors)')
    print(f'Total substitutions: {total_substitutions} ({100*total_substitutions/error_denom:.1f}% of errors)')
    print(f'Verification: I+D+S = {total_insertions + total_deletions + total_substitutions} (should equal {total_edit_distance})')

    # ========== PHONEME-LEVEL CLASSIFICATION METRICS ==========
    print("\n" + "=" * 50)
    print("PHONEME-LEVEL CLASSIFICATION METRICS")
    print("=" * 50)

    phoneme_metrics = calculate_phoneme_level_metrics(all_true_seqs, all_pred_seqs)

    print("Macro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['macro_avg']['precision']:.3f}")
    print(f"  Recall:    {phoneme_metrics['macro_avg']['recall']:.3f}")
    print(f"  F1-Score:  {phoneme_metrics['macro_avg']['f1']:.3f}")

    print("\nMicro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['micro_avg']['precision']:.3f}")
    print(f"  Recall:    {phoneme_metrics['micro_avg']['recall']:.3f}")
    print(f"  F1-Score:  {phoneme_metrics['micro_avg']['f1']:.3f}")

    # ========== SEQUENCE-LEVEL METRICS ==========
    print("\n" + "=" * 50)
    print("SEQUENCE-LEVEL METRICS")
    print("=" * 50)

    seq_metrics = calculate_sequence_level_metrics(all_true_seqs, all_pred_seqs)
    print(f"Exact Match Accuracy: {seq_metrics['exact_match_accuracy']:.3f} ({seq_metrics['exact_matches']}/{seq_metrics['total_sequences']})")

    # ========== PER-SESSION BREAKDOWN ==========
    print("\n" + "=" * 50)
    print("PER-SESSION PERFORMANCE BREAKDOWN")
    print("=" * 50)

    for session in sorted(per_session_metrics.keys()):
        session_errors = per_session_metrics[session]
        session_mean, session_ci_lower, session_ci_upper = calculate_confidence_intervals(session_errors)
        print(f"{session}: {session_mean:.2f}% ± {session_mean - session_ci_lower:.2f}% "
              f"(n={len(session_errors)}, range: {min(session_errors):.1f}%-{max(session_errors):.1f}%)")

    # ========== STATISTICAL SUMMARY ==========
    print("\n" + "=" * 50)
    print("STATISTICAL SUMMARY")
    print("=" * 50)
    print(f"Mean trial PER: {np.mean(per_trial_errors):.2f}%")
    print(f"Median trial PER: {np.median(per_trial_errors):.2f}%")
    print(f"Std deviation: {np.std(per_trial_errors):.2f}%")
    print(f"Min/Max trial PER: {min(per_trial_errors):.1f}% / {max(per_trial_errors):.1f}%")
    print(f"Trials with 0% PER: {sum(1 for x in per_trial_errors if x == 0)} ({100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors):.1f}%)")

    # ========== PHONEME FREQUENCY ANALYSIS ==========
    if args.detailed_metrics:
        print("\n" + "=" * 50)
        print("PHONEME FREQUENCY AND ERROR PATTERN ANALYSIS")
        print("=" * 50)

        freq_analysis = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)

        print("Top 10 most frequent true phonemes:")
        for phoneme, count in freq_analysis['true_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['true_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent predicted phonemes:")
        for phoneme, count in freq_analysis['pred_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['pred_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent error patterns (true -> predicted):")
        sorted_errors = sorted(freq_analysis['error_patterns'].items(), key=lambda x: x[1], reverse=True)
        for (true_p, pred_p), count in sorted_errors[:10]:
            print(f"  {true_p} -> {pred_p}: {count} errors")

    # ========== GENERATE CONFUSION MATRIX ==========
    if not args.skip_confusion_matrix and args.save_plots:
        timestamp_for_plots = time.strftime("%Y%m%d_%H%M%S")
        confusion_matrix_path = f'confusion_matrix_TTA-E_{eval_type}_{timestamp_for_plots}.png'
        confusion_matrix_path = os.path.join(os.path.dirname(__file__), confusion_matrix_path)

        print("\nGenerating confusion matrix...")
        cm_result = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, confusion_matrix_path)
        if cm_result[0] is None:
            confusion_matrix_path = None
    else:
        print("\nSkipping confusion matrix generation (--skip_confusion_matrix specified or plot saving disabled)")
        confusion_matrix_path = None

    # ========== TIMING METRICS ==========
    eval_end_time = time.time()
    eval_duration = eval_end_time - eval_start_time

    print("\n" + "=" * 50)
    print("TIMING METRICS")
    print("=" * 50)
    print(f"Evaluation time: {eval_duration:.2f} seconds")
    print(f"Time per trial: {eval_duration/len(results['pred_phonemes']):.3f} seconds")

    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY COMPLETE")
    print("=" * 80)

# Write comprehensive results to CSV files
timestamp = time.strftime("%Y%m%d_%H%M%S")
enabled_augs = '_'.join([k for k, v in tta_weights.items() if v > 0])

# Basic phoneme predictions CSV (for compatibility)
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{enabled_augs}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

ids = list(range(len(results['pred_phonemes'])))
phoneme_strings = [" ".join(phonemes) for phonemes in results['pred_phonemes']]
df_out = pd.DataFrame({'id': ids, 'phonemes': phoneme_strings})
df_out.to_csv(output_path, index=False)
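
# The basic predictions CSV has two columns; illustrative contents (phoneme
# symbols here are made-up examples, not real output):
#   id,phonemes
#   0,DH IH S IH Z AH T EH S T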

# Comprehensive metrics CSVs (validation set only)
if eval_type == 'val':
    detailed_output_file = f'TTA-E_detailed_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    detailed_output_path = os.path.join(os.path.dirname(__file__), detailed_output_file)

    # Create a detailed DataFrame with all per-trial metrics
    detailed_df = pd.DataFrame({
        'id': ids,
        'session': results['session'],
        'block': results['block'],
        'trial': results['trial'],
        'true_sentence': results['true_sentence'],
        'pred_phonemes': phoneme_strings,
        'num_phonemes': results['num_phonemes'],
        'edit_distance': results['edit_distance'],
        'insertions': results['insertions'],
        'deletions': results['deletions'],
        'substitutions': results['substitutions'],
        'trial_per': results['per_trial_per']
    })
    detailed_df.to_csv(detailed_output_path, index=False)

    # Summary metrics CSV
    summary_output_file = f'TTA-E_summary_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    summary_output_path = os.path.join(os.path.dirname(__file__), summary_output_file)

    summary_data = {
        'metric': [
            'aggregate_per', 'mean_trial_per', 'median_trial_per', 'std_trial_per',
            'min_trial_per', 'max_trial_per', 'zero_per_trials_count', 'zero_per_trials_percent',
            'total_phonemes', 'total_edit_distance', 'total_insertions', 'total_deletions',
            'total_substitutions', 'exact_match_accuracy', 'exact_matches',
            'macro_precision', 'macro_recall', 'macro_f1',
            'micro_precision', 'micro_recall', 'micro_f1',
            'per_mean_95ci_lower', 'per_mean_95ci_upper', 'evaluation_time_seconds'
        ],
        'value': [
            aggregate_per, np.mean(per_trial_errors), np.median(per_trial_errors), np.std(per_trial_errors),
            min(per_trial_errors), max(per_trial_errors),
            sum(1 for x in per_trial_errors if x == 0),
            100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors),
            total_true_length, total_edit_distance, total_insertions, total_deletions,
            total_substitutions, seq_metrics['exact_match_accuracy'], seq_metrics['exact_matches'],
            phoneme_metrics['macro_avg']['precision'], phoneme_metrics['macro_avg']['recall'],
            phoneme_metrics['macro_avg']['f1'], phoneme_metrics['micro_avg']['precision'],
            phoneme_metrics['micro_avg']['recall'], phoneme_metrics['micro_avg']['f1'],
            per_ci_lower, per_ci_upper, eval_duration
        ]
    }

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(summary_output_path, index=False)

    print('\nComprehensive results saved to:')
    print(f'  Basic predictions: {output_path}')
    print(f'  Detailed metrics: {detailed_output_path}')
    print(f'  Summary metrics: {summary_output_path}')
    if confusion_matrix_path:
        print(f'  Confusion matrix: {confusion_matrix_path}')
    else:
        print('  Confusion matrix: Not generated')
else:
    print(f'\nResults saved to: {output_path}')

print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
print(f'Enabled augmentations: {", ".join(enabled_augs.split("_"))}')