import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *


# ========== COMPREHENSIVE EVALUATION METRICS FUNCTIONS ==========

def calculate_phoneme_operations(true_seq, pred_seq):
    """Calculate insertions, deletions, and substitutions using dynamic programming."""
    m, n = len(true_seq), len(pred_seq)

    # DP matrix for edit distance calculation
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Initialize base cases
    for i in range(m + 1):
        dp[i][0] = i  # deletions
    for j in range(n + 1):
        dp[0][j] = j  # insertions

    # Fill DP matrix
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if true_seq[i-1] == pred_seq[j-1]:
                dp[i][j] = dp[i-1][j-1]  # match
            else:
                dp[i][j] = 1 + min(
                    dp[i-1][j],    # deletion
                    dp[i][j-1],    # insertion
                    dp[i-1][j-1]   # substitution
                )

    # Backtrack to count operations
    i, j = m, n
    insertions = deletions = substitutions = 0
    while i > 0 or j > 0:
        if i > 0 and j > 0 and true_seq[i-1] == pred_seq[j-1]:
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
            substitutions += 1
            i -= 1
            j -= 1
        elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
            deletions += 1
            i -= 1
        else:
            insertions += 1
            j -= 1

    return insertions, deletions, substitutions
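
# Illustrative sanity check for calculate_phoneme_operations (not executed here;
# the phoneme labels are arbitrary examples, not taken from the dataset):
#   calculate_phoneme_operations(['HH', 'AH', 'L', 'OW'], ['HH', 'L', 'OW'])
#     -> (0, 1, 0)   one deletion, edit distance 1, trial PER = 1/4 = 25%
#   calculate_phoneme_operations(['K', 'AE', 'T'], ['K', 'IH', 'T', 'S'])
#     -> (1, 0, 1)   one insertion plus one substitution, edit distance 2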

def calculate_phoneme_level_metrics(true_seqs, pred_seqs):
    """Calculate phoneme-level precision, recall, and F1-score using alignment."""
    all_true_phonemes = []
    all_pred_phonemes = []

    # Align sequences position-by-position; when lengths differ, truncate to the
    # shorter sequence so that sklearn receives label lists of equal length.
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))

        # Take only the overlapping part for fair comparison
        if min_len > 0:
            if len(true_seq) == len(pred_seq):
                # Same length - direct alignment
                all_true_phonemes.extend(true_seq)
                all_pred_phonemes.extend(pred_seq)
            else:
                # Different lengths - align by truncating to the shorter length
                all_true_phonemes.extend(true_seq[:min_len])
                all_pred_phonemes.extend(pred_seq[:min_len])

    # Verify lengths match
    if len(all_true_phonemes) != len(all_pred_phonemes):
        print(f"Warning: After alignment, lengths still don't match: "
              f"{len(all_true_phonemes)} vs {len(all_pred_phonemes)}")
        # Take minimum length to avoid sklearn error
        min_len = min(len(all_true_phonemes), len(all_pred_phonemes))
        all_true_phonemes = all_true_phonemes[:min_len]
        all_pred_phonemes = all_pred_phonemes[:min_len]

    if len(all_true_phonemes) == 0:
        # Handle empty case
        return {
            'phoneme_metrics': {
                'phonemes': [],
                'precision': np.array([]),
                'recall': np.array([]),
                'f1': np.array([]),
                'support': np.array([])
            },
            'macro_avg': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0},
            'micro_avg': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
        }

    # Get unique phonemes
    unique_phonemes = sorted(list(set(all_true_phonemes + all_pred_phonemes)))

    # Calculate metrics for each phoneme class
    precision, recall, f1, support = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes,
        labels=unique_phonemes, average=None, zero_division=0
    )

    # Calculate macro averages
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    # Calculate micro averages
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes, average='micro', zero_division=0
    )

    return {
        'phoneme_metrics': {
            'phonemes': unique_phonemes,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support
        },
        'macro_avg': {
            'precision': macro_precision,
            'recall': macro_recall,
            'f1': macro_f1
        },
        'micro_avg': {
            'precision': micro_precision,
            'recall': micro_recall,
            'f1': micro_f1
        }
    }


def calculate_sequence_level_metrics(true_seqs, pred_seqs):
    """Calculate sequence-level metrics."""
    exact_matches = 0
    total_sequences = len(true_seqs)

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        if true_seq == pred_seq:
            exact_matches += 1

    exact_match_accuracy = exact_matches / total_sequences if total_sequences > 0 else 0

    return {
        'exact_match_accuracy': exact_match_accuracy,
        'exact_matches': exact_matches,
        'total_sequences': total_sequences
    }


def calculate_confidence_intervals(errors, confidence=0.95):
    """Calculate confidence intervals for error rates."""
    n = len(errors)
    if n == 0:
        return 0, 0, 0

    mean_error = np.mean(errors)
    if n < 2:
        # A single observation has no measurable spread; return a degenerate interval
        return mean_error, mean_error, mean_error

    std_error = np.std(errors, ddof=1)

    # Use t-distribution for small samples, normal approximation for large samples
    if n < 30:
        from scipy import stats
        t_val = stats.t.ppf((1 + confidence) / 2, n - 1)
        margin = t_val * std_error / np.sqrt(n)
    else:
        # For large samples, use the normal distribution
        z_val = 1.96 if confidence == 0.95 else 2.576  # 95% or 99%
        margin = z_val * std_error / np.sqrt(n)

    return mean_error, mean_error - margin, mean_error + margin
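
# Worked example for calculate_confidence_intervals (illustrative numbers only):
# for per-trial PERs [10.0, 20.0, 30.0], mean = 20.0, sample std s = 10.0, n = 3,
# t_(0.975, df=2) ~= 4.30, so margin ~= 4.30 * 10 / sqrt(3) ~= 24.8 and the 95% CI
# is roughly [-4.8, 44.8] -- wide, as expected for only three trials.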

def calculate_phoneme_frequency_analysis(true_seqs, pred_seqs):
    """Analyze phoneme frequency and error patterns."""
    true_phoneme_counts = Counter()
    pred_phoneme_counts = Counter()
    error_patterns = defaultdict(int)  # Maps (true_phoneme, pred_phoneme) -> count

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        true_phoneme_counts.update(true_seq)
        pred_phoneme_counts.update(pred_seq)

        # For error pattern analysis, align sequences by taking the minimum length
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            for i in range(min_len):
                true_p = true_seq[i]
                pred_p = pred_seq[i]
                if true_p != pred_p:
                    error_patterns[(true_p, pred_p)] += 1

    return {
        'true_phoneme_counts': true_phoneme_counts,
        'pred_phoneme_counts': pred_phoneme_counts,
        'error_patterns': dict(error_patterns)
    }


def generate_confusion_matrix_plot(true_seqs, pred_seqs, save_path=None):
    """Generate and save a confusion matrix plot for phonemes."""
    all_true = []
    all_pred = []

    # Align sequences properly to handle different lengths
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            # Truncate both to the same length for fair comparison
            all_true.extend(true_seq[:min_len])
            all_pred.extend(pred_seq[:min_len])

    if len(all_true) == 0 or len(all_pred) == 0:
        print("Warning: No aligned phonemes found for confusion matrix")
        return None, []

    unique_phonemes = sorted(list(set(all_true + all_pred)))
    cm = confusion_matrix(all_true, all_pred, labels=unique_phonemes)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=unique_phonemes, yticklabels=unique_phonemes)
    plt.title('Phoneme Confusion Matrix')
    plt.xlabel('Predicted Phonemes')
    plt.ylabel('True Phonemes')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")

    plt.close()
    return cm, unique_phonemes


# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../data/t15_pretrained_lstm_baseline',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=1,
                    help='Weight for the GRU model in the ensemble (LSTM weight = 1 - gru_weight).')

# TTA parameters
parser.add_argument('--tta_samples', type=int, default=4,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to cut from the beginning in TTA.')

# TTA sample weight configuration
parser.add_argument('--tta_weights', type=str, default='2.0,1.0,0.0,0.0,0.0',
                    help='Comma-separated weights for each TTA augmentation type: '
                         'original,noise,scale,shift,smooth. Set a weight to 0 to disable that augmentation.')

# Evaluation metrics configuration
parser.add_argument('--skip_confusion_matrix', action='store_true',
                    help='Skip confusion matrix generation to save time.')
parser.add_argument('--detailed_metrics', action='store_true', default=True,
                    help='Calculate detailed phoneme-level and sequence-level metrics.')
parser.add_argument('--save_plots', action='store_true', default=True,
                    help='Save confusion matrix and other plots.')

# Memory management options
parser.add_argument('--max_trials_per_session', type=int, default=1000,
                    help='Maximum number of trials to load per session (for memory management).')
parser.add_argument('--batch_inference', action='store_true',
                    help='Process trials in batches to reduce memory usage during inference.')
parser.add_argument('--batch_size', type=int, default=10,
                    help='Batch size for inference when using --batch_inference.')

args = parser.parse_args()

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

# Parse TTA weights
tta_weights_str = args.tta_weights.split(',')
if len(tta_weights_str) != 5:
    raise ValueError("TTA weights must have exactly 5 values: original,noise,scale,shift,smooth")

tta_weights = {
    'original': float(tta_weights_str[0]),
    'noise': float(tta_weights_str[1]),
    'scale': float(tta_weights_str[2]),
    'shift': float(tta_weights_str[3]),
    'smooth': float(tta_weights_str[4])
}

# Calculate effective number of TTA samples based on enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
effective_tta_samples = len(enabled_augmentations)
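
# For example, the default '--tta_weights 2.0,1.0,0.0,0.0,0.0' enables only the
# 'original' and 'noise' augmentations, weighted 2:1 (weights are normalized inside
# runTTAEnsembleDecodingStep), so effective_tta_samples == 2.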

print("Improved TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f}")
print(f"LSTM weight: {lstm_weight:.2f}")
print("TTA augmentation weights:")
for aug_type, weight in tta_weights.items():
    status = "✓ enabled" if weight > 0 else "✗ disabled"
    print(f"  - {aug_type}: {weight:.1f} ({status})")
print(f"Effective TTA samples: {effective_tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max cut: {tta_cut_max} timesteps")
print("Key improvements:")
print("  - Customizable TTA sample weights")
print("  - Probability-based ensemble instead of logit averaging")
print("  - Geometric mean for more robust fusion")
print("  - Direct PER evaluation without language model dependency")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()

# Define evaluation type
eval_type = args.eval_type

# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = f'cuda:{gpu_number}'
    device = torch.device(device)
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from checkpoint keys
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from checkpoint keys
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Move models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print("Both models loaded successfully!")
print()


# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_weights, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run customizable TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply the selected data augmentations based on their weights
    2. Run both the GRU and LSTM models on each augmented version
    3. Use a probability-based ensemble with customizable sample weights
    4. Average predictions in probability space for better stability
    """
    all_gru_probs = []
    all_lstm_probs = []
    sample_weights = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    # Generate augmented samples based on enabled augmentations
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    for aug_type in augmentation_types:
        if tta_weights[aug_type] <= 0:
            continue  # Skip disabled augmentations

        x_augmented = x.clone()

        if aug_type == 'original':
            # Original data (baseline)
            pass
        elif aug_type == 'noise':
            # Add Gaussian noise with varying intensity
            noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
            noise = torch.randn_like(x_augmented) * noise_scale
            x_augmented = x_augmented + noise
        elif aug_type == 'scale':
            # Amplitude scaling with more variation
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and tta_cut_max > 0:
            # Time shift (circular shift)
            shift_amount = np.random.randint(1, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
        elif aug_type == 'smooth':
            # Apply smoothing variation after augmentation
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)

        # Use autocast for efficiency - auto-detect device type
        device_type = "cuda" if device.type == "cuda" else "cpu"
        use_amp = gru_model_args.get('use_amp', False) and device.type == "cuda"

        with torch.autocast(device_type=device_type, enabled=use_amp,
                            dtype=torch.bfloat16 if device.type == "cuda" else torch.float32):
            # Apply Gaussian smoothing
            if aug_type == 'smooth':
                # Use varied smoothing
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=varied_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )
            else:
                # Use default smoothing
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=default_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )

            with torch.no_grad():
                # Get GRU logits and convert to probabilities
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                gru_probs = torch.softmax(gru_logits, dim=-1)

                # Get LSTM logits and convert to probabilities
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                lstm_probs = torch.softmax(lstm_logits, dim=-1)

        all_gru_probs.append(gru_probs)
        all_lstm_probs.append(lstm_probs)
        sample_weights.append(tta_weights[aug_type])

    if len(all_gru_probs) == 0:
        raise ValueError("No TTA augmentations are enabled. Please set at least one weight > 0.")

    # TTA fusion: handle potentially different tensor shapes
    if len(all_gru_probs) > 1:
        # Find the minimum sequence length among all TTA samples
        min_length = min([probs.shape[1] for probs in all_gru_probs + all_lstm_probs])

        # Truncate all tensors to the minimum length
        truncated_gru_probs = []
        truncated_lstm_probs = []
        for gru_probs, lstm_probs in zip(all_gru_probs, all_lstm_probs):
            if gru_probs.shape[1] > min_length:
                truncated_gru_probs.append(gru_probs[:, :min_length, :])
            else:
                truncated_gru_probs.append(gru_probs)
            if lstm_probs.shape[1] > min_length:
                truncated_lstm_probs.append(lstm_probs[:, :min_length, :])
            else:
                truncated_lstm_probs.append(lstm_probs)

        # Weighted average of probabilities across TTA samples
        sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32, device=device)
        sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum()  # Normalize weights

        weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0])
        weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0])
        for gru_probs, lstm_probs, weight in zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor):
            weighted_gru_probs += weight * gru_probs
            weighted_lstm_probs += weight * lstm_probs

        avg_gru_probs = weighted_gru_probs
        avg_lstm_probs = weighted_lstm_probs
    else:
        avg_gru_probs = all_gru_probs[0]
        avg_lstm_probs = all_lstm_probs[0]

    # Improved ensemble: geometric mean of probabilities (more robust than a weighted average)
    # Apply epsilon smoothing to avoid log(0)
    epsilon = 1e-8
    avg_gru_probs = avg_gru_probs + epsilon
    avg_lstm_probs = avg_lstm_probs + epsilon

    # Weighted geometric mean in log space
    log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +
                          lstm_weight * torch.log(avg_lstm_probs))
    ensemble_probs = torch.exp(log_ensemble_probs)

    # Normalize to ensure a proper probability distribution
    ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)

    # Convert back to logits for compatibility with downstream code
    final_logits = torch.log(ensemble_probs + epsilon)

    # Convert from bfloat16 to float32
    return final_logits.float().cpu().numpy()
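
# Toy illustration of the weighted geometric-mean fusion above (numbers are made up):
# with gru_weight = 0.6 and lstm_weight = 0.4, a class with p_gru = 0.8 but p_lstm = 0.2
# gets exp(0.6*ln 0.8 + 0.4*ln 0.2) ~= 0.46 before renormalization, whereas a plain
# weighted average would give 0.56 -- the geometric mean is pulled down whenever either
# model is unconfident, which is the intended "more robust" fusion behavior.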


# Load data for each session with memory management
test_data = {}
total_test_trials = 0

print("Loading data with memory management...")
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        try:
            # Try to load data normally first
            print(f'Loading {session}...')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_test_trials += len(test_data[session]["neural_features"])
            print(f'Successfully loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')

        except (MemoryError, np._core._exceptions._ArrayMemoryError) as e:
            print(f'Memory error loading {session}: {e}')
            print(f'Attempting to load {session} in smaller chunks...')

            # If memory error, try loading in chunks or skip this session
            import h5py
            try:
                with h5py.File(eval_file, 'r') as h5file:
                    # Get the size first
                    input_features_shape = h5file['input_features'].shape
                    print(f'Session {session} has {input_features_shape[0]} trials with {input_features_shape[1]} features')

                    # If it's too large, skip or implement chunk loading
                    max_trials_per_session = args.max_trials_per_session  # Limit trials per session
                    if input_features_shape[0] > max_trials_per_session:
                        print(f'Session {session} has {input_features_shape[0]} trials, '
                              f'limiting to {max_trials_per_session} for memory management')
                        # Could implement chunked loading here if needed
                        print(f'Skipping session {session} due to memory constraints')
                        continue
                    else:
                        # Try one more time with garbage collection
                        import gc
                        gc.collect()
                        data = load_h5py_file(eval_file, b2txt_csv_df)
                        test_data[session] = data
                        total_test_trials += len(test_data[session]["neural_features"])
                        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session} after GC.')
            except Exception as e2:
                print(f'Failed to load session {session}: {e2}')
                print(f'Skipping session {session}')
                continue

print(f'Total number of {eval_type} trials loaded: {total_test_trials}')
if total_test_trials == 0:
    print("ERROR: No trials loaded! Check data paths and memory availability.")
    sys.exit(1)
print()

# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
import gc

with tqdm(total=total_test_trials, desc=f'Customizable TTA-E inference ({effective_tta_samples} augmentation types)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)

        # Process trials in batches if requested
        if args.batch_inference:
            batch_size = args.batch_size
            num_trials = len(data['neural_features'])

            for batch_start in range(0, num_trials, batch_size):
                batch_end = min(batch_start + batch_size, num_trials)
                batch_logits = []

                for trial in range(batch_start, batch_end):
                    # Get neural input for the trial
                    neural_input = data['neural_features'][trial]

                    # Add batch dimension
                    neural_input = np.expand_dims(neural_input, axis=0)

                    # Convert to torch tensor - use float32 to avoid dtype issues with smoothing
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                    # Run TTA-E decoding step with customizable weights
                    ensemble_logits = runTTAEnsembleDecodingStep(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device,
                        gru_weight, lstm_weight, tta_weights,
                        tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                    )
                    batch_logits.append(ensemble_logits)

                    # Clear GPU memory
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

                    pbar.update(1)

                # Add batch results
                data['logits'].extend(batch_logits)

                # Clear batch memory
                del batch_logits
                gc.collect()
        else:
            # Process all trials individually (original method)
            for trial in range(len(data['neural_features'])):
                # Get neural input for the trial
                neural_input = data['neural_features'][trial]

                # Add batch dimension
                neural_input = np.expand_dims(neural_input, axis=0)

                # Convert to torch tensor - use float32 to avoid dtype issues with smoothing
                neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                # Run TTA-E decoding step with customizable weights
                ensemble_logits = runTTAEnsembleDecodingStep(
                    neural_input, input_layer, gru_model, lstm_model,
                    gru_model_args, lstm_model_args, device,
                    gru_weight, lstm_weight, tta_weights,
                    tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                )
                data['logits'].append(ensemble_logits)

                # Clear memory periodically
                if trial % 50 == 0:
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    gc.collect()

                pbar.update(1)

        # Clear session data memory after processing
        gc.collect()

pbar.close()

# Convert logits to phoneme sequences and print them out
results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_phonemes': [],
}
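
# The loop below performs a greedy best-path decode of the ensemble logits: take the
# argmax class per frame, drop CTC blanks (class 0), then collapse consecutive repeats.
# For example, a frame-wise argmax of [0, 5, 5, 0, 3, 3] becomes [5, 3].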

for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)

        # Remove blanks (0)
        pred_seq = [int(p) for p in pred_seq if p != 0]

        # Remove consecutive duplicates
        pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]

        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]

        # Add to data
        data['pred_seq'].append(pred_seq)

        # Store results
        results['session'].append(session)
        results['block'].append(data['block_num'][trial])
        results['trial'].append(data['trial_num'][trial])
        if eval_type == 'val':
            results['true_sentence'].append(data['sentence_label'][trial])
        else:
            results['true_sentence'].append(None)
        results['pred_phonemes'].append(pred_seq)

        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label: {sentence_label}')
            print(f'True sequence: {" ".join(true_seq)}')
        print(f'Predicted Sequence: {" ".join(pred_seq)}')
        print()


# If using the validation set, calculate comprehensive evaluation metrics
if eval_type == 'val':
    print("=" * 80)
    print("COMPREHENSIVE EVALUATION METRICS")
    print("=" * 80)

    # Record performance timing
    eval_start_time = time.time()

    # Initialize metrics storage
    total_true_length = 0
    total_edit_distance = 0
    all_true_seqs = []
    all_pred_seqs = []
    per_trial_errors = []
    per_session_metrics = defaultdict(list)

    # Additional detailed metrics
    results['edit_distance'] = []
    results['num_phonemes'] = []
    results['insertions'] = []
    results['deletions'] = []
    results['substitutions'] = []
    results['per_trial_per'] = []

    print(f"\nCalculating detailed metrics for {len(results['pred_phonemes'])} trials...")

    for i in range(len(results['pred_phonemes'])):
        # Get the true phoneme sequence for this trial
        session = results['session'][i]
        trial_idx = None
        for trial in range(len(test_data[session]['seq_class_ids'])):
            if (test_data[session]['block_num'][trial] == results['block'][i] and
                    test_data[session]['trial_num'][trial] == results['trial'][i]):
                trial_idx = trial
                break

        if trial_idx is not None:
            true_seq = test_data[session]['seq_class_ids'][trial_idx][0:test_data[session]['seq_len'][trial_idx]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            pred_seq = results['pred_phonemes'][i]

            all_true_seqs.append(true_seq)
            all_pred_seqs.append(pred_seq)

            # Calculate basic edit distance
            ed = editdistance.eval(true_seq, pred_seq)
            trial_per = 100 * ed / len(true_seq) if len(true_seq) > 0 else 0

            # Calculate detailed phoneme operations
            insertions, deletions, substitutions = calculate_phoneme_operations(true_seq, pred_seq)

            # Store metrics
            total_true_length += len(true_seq)
            total_edit_distance += ed
            per_trial_errors.append(trial_per)
            per_session_metrics[session].append(trial_per)

            results['edit_distance'].append(ed)
            results['num_phonemes'].append(len(true_seq))
            results['insertions'].append(insertions)
            results['deletions'].append(deletions)
            results['substitutions'].append(substitutions)
            results['per_trial_per'].append(trial_per)

            # Print detailed trial information
            print(f'{results["session"][i]} - Block {results["block"][i]}, Trial {results["trial"][i]}')
            print(f'True phonemes ({len(true_seq)}): {" ".join(true_seq)}')
            print(f'Predicted phonemes ({len(pred_seq)}): {" ".join(pred_seq)}')
            print(f'Operations: I={insertions}, D={deletions}, S={substitutions}, Total ED={ed}')
            print(f'Trial PER: {ed} / {len(true_seq)} = {trial_per:.2f}%')
            print()

    # ========== BASIC METRICS ==========
    aggregate_per = 100 * total_edit_distance / total_true_length if total_true_length > 0 else 0

    print("=" * 50)
    print("BASIC PHONEME ERROR RATE METRICS")
    print("=" * 50)
    print(f'Total true phoneme length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Phoneme Error Rate (PER): {aggregate_per:.2f}%')

    # Calculate PER confidence intervals
    per_mean, per_ci_lower, per_ci_upper = calculate_confidence_intervals(per_trial_errors)
    print(f'Per-trial PER: {per_mean:.2f}% (95% CI: [{per_ci_lower:.2f}%, {per_ci_upper:.2f}%])')

    # ========== PHONEME OPERATION ANALYSIS ==========
    total_insertions = sum(results['insertions'])
    total_deletions = sum(results['deletions'])
    total_substitutions = sum(results['substitutions'])

    print("\n" + "=" * 50)
    print("PHONEME OPERATION BREAKDOWN")
    print("=" * 50)
    # Guard against division by zero when there are no errors at all
    error_denom = total_edit_distance if total_edit_distance > 0 else 1
    print(f'Total insertions: {total_insertions} ({100*total_insertions/error_denom:.1f}% of errors)')
    print(f'Total deletions: {total_deletions} ({100*total_deletions/error_denom:.1f}% of errors)')
    print(f'Total substitutions: {total_substitutions} ({100*total_substitutions/error_denom:.1f}% of errors)')
    print(f'Verification: I+D+S = {total_insertions + total_deletions + total_substitutions} (should equal {total_edit_distance})')

    # ========== PHONEME-LEVEL CLASSIFICATION METRICS ==========
    print("\n" + "=" * 50)
    print("PHONEME-LEVEL CLASSIFICATION METRICS")
    print("=" * 50)

    phoneme_metrics = calculate_phoneme_level_metrics(all_true_seqs, all_pred_seqs)

    print("Macro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['macro_avg']['precision']:.3f}")
    print(f"  Recall: {phoneme_metrics['macro_avg']['recall']:.3f}")
    print(f"  F1-Score: {phoneme_metrics['macro_avg']['f1']:.3f}")
    print("\nMicro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['micro_avg']['precision']:.3f}")
    print(f"  Recall: {phoneme_metrics['micro_avg']['recall']:.3f}")
    print(f"  F1-Score: {phoneme_metrics['micro_avg']['f1']:.3f}")

    # ========== SEQUENCE-LEVEL METRICS ==========
    print("\n" + "=" * 50)
    print("SEQUENCE-LEVEL METRICS")
    print("=" * 50)

    seq_metrics = calculate_sequence_level_metrics(all_true_seqs, all_pred_seqs)
    print(f"Exact Match Accuracy: {seq_metrics['exact_match_accuracy']:.3f} "
          f"({seq_metrics['exact_matches']}/{seq_metrics['total_sequences']})")

    # ========== PER-SESSION BREAKDOWN ==========
    print("\n" + "=" * 50)
    print("PER-SESSION PERFORMANCE BREAKDOWN")
    print("=" * 50)

    for session in sorted(per_session_metrics.keys()):
        session_errors = per_session_metrics[session]
        session_mean, session_ci_lower, session_ci_upper = calculate_confidence_intervals(session_errors)
        print(f"{session}: {session_mean:.2f}% ± {session_mean - session_ci_lower:.2f}% "
              f"(n={len(session_errors)}, range: {min(session_errors):.1f}%-{max(session_errors):.1f}%)")

    # ========== STATISTICAL SUMMARY ==========
    print("\n" + "=" * 50)
    print("STATISTICAL SUMMARY")
    print("=" * 50)
    print(f"Mean trial PER: {np.mean(per_trial_errors):.2f}%")
    print(f"Median trial PER: {np.median(per_trial_errors):.2f}%")
    print(f"Std deviation: {np.std(per_trial_errors):.2f}%")
    print(f"Min/Max trial PER: {min(per_trial_errors):.1f}% / {max(per_trial_errors):.1f}%")
    print(f"Trials with 0% PER: {sum(1 for x in per_trial_errors if x == 0)} "
          f"({100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors):.1f}%)")

    # ========== PHONEME FREQUENCY ANALYSIS ==========
    if args.detailed_metrics:
        print("\n" + "=" * 50)
        print("PHONEME FREQUENCY AND ERROR PATTERN ANALYSIS")
        print("=" * 50)

        freq_analysis = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)

        print("Top 10 most frequent true phonemes:")
        for phoneme, count in freq_analysis['true_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['true_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent predicted phonemes:")
        for phoneme, count in freq_analysis['pred_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['pred_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent error patterns (true -> predicted):")
        sorted_errors = sorted(freq_analysis['error_patterns'].items(), key=lambda x: x[1], reverse=True)
        for (true_p, pred_p), count in sorted_errors[:10]:
            print(f"  {true_p} -> {pred_p}: {count} errors")

    # ========== GENERATE CONFUSION MATRIX ==========
    if not args.skip_confusion_matrix and args.save_plots:
        timestamp_for_plots = time.strftime("%Y%m%d_%H%M%S")
        confusion_matrix_path = f'confusion_matrix_TTA-E_{eval_type}_{timestamp_for_plots}.png'
        confusion_matrix_path = os.path.join(os.path.dirname(__file__), confusion_matrix_path)

        print("\nGenerating confusion matrix...")
        cm_result = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, confusion_matrix_path)
        if cm_result[0] is None:
            confusion_matrix_path = None
    else:
        print("\nSkipping confusion matrix generation (--skip_confusion_matrix specified or plot saving disabled)")
        confusion_matrix_path = None

    # ========== TIMING METRICS ==========
    eval_end_time = time.time()
    eval_duration = eval_end_time - eval_start_time

    print("\n" + "=" * 50)
    print("TIMING METRICS")
    print("=" * 50)
    print(f"Evaluation time: {eval_duration:.2f} seconds")
    print(f"Time per trial: {eval_duration/len(results['pred_phonemes']):.3f} seconds")

    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY COMPLETE")
    print("=" * 80)


# Write comprehensive results to CSV files
timestamp = time.strftime("%Y%m%d_%H%M%S")
enabled_augs = '_'.join([k for k, v in tta_weights.items() if v > 0])

# Basic phoneme predictions CSV (for compatibility)
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{enabled_augs}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

ids = [i for i in range(len(results['pred_phonemes']))]
phoneme_strings = [" ".join(phonemes) for phonemes in results['pred_phonemes']]
df_out = pd.DataFrame({'id': ids, 'phonemes': phoneme_strings})
df_out.to_csv(output_path, index=False)

# Comprehensive metrics CSV (if validation set)
if eval_type == 'val':
    detailed_output_file = f'TTA-E_detailed_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    detailed_output_path = os.path.join(os.path.dirname(__file__), detailed_output_file)

    # Create detailed DataFrame with all metrics
    detailed_df = pd.DataFrame({
        'id': ids,
        'session': results['session'],
        'block': results['block'],
        'trial': results['trial'],
        'true_sentence': results['true_sentence'],
        'pred_phonemes': phoneme_strings,
        'num_phonemes': results['num_phonemes'],
        'edit_distance': results['edit_distance'],
        'insertions': results['insertions'],
        'deletions': results['deletions'],
        'substitutions': results['substitutions'],
        'trial_per': results['per_trial_per']
    })
    detailed_df.to_csv(detailed_output_path, index=False)

    # Summary metrics CSV
    summary_output_file = f'TTA-E_summary_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    summary_output_path = os.path.join(os.path.dirname(__file__), summary_output_file)

    summary_data = {
        'metric': [
            'aggregate_per', 'mean_trial_per', 'median_trial_per', 'std_trial_per',
            'min_trial_per', 'max_trial_per', 'zero_per_trials_count', 'zero_per_trials_percent',
            'total_phonemes', 'total_edit_distance', 'total_insertions', 'total_deletions',
            'total_substitutions', 'exact_match_accuracy', 'exact_matches',
            'macro_precision', 'macro_recall', 'macro_f1',
            'micro_precision', 'micro_recall', 'micro_f1',
            'per_mean_95ci_lower', 'per_mean_95ci_upper', 'evaluation_time_seconds'
        ],
        'value': [
            aggregate_per, np.mean(per_trial_errors), np.median(per_trial_errors), np.std(per_trial_errors),
            min(per_trial_errors), max(per_trial_errors),
            sum(1 for x in per_trial_errors if x == 0),
            100 * sum(1 for x in per_trial_errors if x == 0) / len(per_trial_errors),
            total_true_length, total_edit_distance, total_insertions, total_deletions,
            total_substitutions, seq_metrics['exact_match_accuracy'], seq_metrics['exact_matches'],
            phoneme_metrics['macro_avg']['precision'], phoneme_metrics['macro_avg']['recall'], phoneme_metrics['macro_avg']['f1'],
            phoneme_metrics['micro_avg']['precision'], phoneme_metrics['micro_avg']['recall'], phoneme_metrics['micro_avg']['f1'],
            per_ci_lower, per_ci_upper, eval_duration
        ]
    }
    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(summary_output_path, index=False)

    print('\nComprehensive results saved to:')
    print(f'  Basic predictions: {output_path}')
    print(f'  Detailed metrics: {detailed_output_path}')
    print(f'  Summary metrics: {summary_output_path}')
    if confusion_matrix_path:
        print(f'  Confusion matrix: {confusion_matrix_path}')
    else:
        print('  Confusion matrix: Not generated')
else:
    print(f'\nResults saved to: {output_path}')

print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
print(f'Enabled augmentations: {", ".join(enabled_augs.split("_"))}')
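
# Example invocation (illustrative; the script name is a placeholder and the default
# data/model paths above are assumed to exist relative to this script):
#   python evaluate_tta_ensemble.py --eval_type val --gpu_number 0 \
#       --gru_weight 0.6 --tta_weights 2.0,1.0,1.0,0.0,1.0 --batch_inference --batch_size 8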