import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from collections import defaultdict, Counter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *

# ========== COMPREHENSIVE EVALUATION METRICS FUNCTIONS ==========

def calculate_phoneme_operations(true_seq, pred_seq):
    """Calculate insertions, deletions, and substitutions using dynamic programming."""
    m, n = len(true_seq), len(pred_seq)

    # DP matrix for edit distance calculation
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Initialize base cases
    for i in range(m + 1):
        dp[i][0] = i  # deletions
    for j in range(n + 1):
        dp[0][j] = j  # insertions

    # Fill DP matrix
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if true_seq[i-1] == pred_seq[j-1]:
                dp[i][j] = dp[i-1][j-1]  # match
            else:
                dp[i][j] = 1 + min(
                    dp[i-1][j],    # deletion
                    dp[i][j-1],    # insertion
                    dp[i-1][j-1]   # substitution
                )

    # Backtrack to count operations
    i, j = m, n
    insertions = deletions = substitutions = 0

    while i > 0 or j > 0:
        if i > 0 and j > 0 and true_seq[i-1] == pred_seq[j-1]:
            i -= 1
            j -= 1
        elif i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + 1:
            substitutions += 1
            i -= 1
            j -= 1
        elif i > 0 and dp[i][j] == dp[i-1][j] + 1:
            deletions += 1
            i -= 1
        else:
            insertions += 1
            j -= 1

    return insertions, deletions, substitutions

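# A minimal sanity check of the operation counter (hypothetical phoneme strings,
# not taken from the dataset): deleting 'B' from the reference costs exactly one
# deletion, so the (insertions, deletions, substitutions) breakdown should sum
# to the edit distance of 1.
assert calculate_phoneme_operations(['AA', 'B', 'K'], ['AA', 'K']) == (0, 1, 0)
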
def calculate_phoneme_level_metrics(true_seqs, pred_seqs):
    """Calculate phoneme-level precision, recall, and F1-score using alignment."""
    all_true_phonemes = []
    all_pred_phonemes = []

    # Align sequences to handle different lengths. Instead of a full dynamic
    # programming alignment, unequal-length sequences are truncated to the
    # shorter length so positions can be compared one-to-one.
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))

        # Take only the overlapping part for a fair comparison
        if min_len > 0:
            if len(true_seq) == len(pred_seq):
                # Same length - direct alignment
                all_true_phonemes.extend(true_seq)
                all_pred_phonemes.extend(pred_seq)
            else:
                # Different lengths - align by truncating to the shorter length
                truncated_true = true_seq[:min_len]
                truncated_pred = pred_seq[:min_len]
                all_true_phonemes.extend(truncated_true)
                all_pred_phonemes.extend(truncated_pred)

    # Verify lengths match
    if len(all_true_phonemes) != len(all_pred_phonemes):
        print(f"Warning: After alignment, lengths still don't match: {len(all_true_phonemes)} vs {len(all_pred_phonemes)}")
        # Take minimum length to avoid sklearn error
        min_len = min(len(all_true_phonemes), len(all_pred_phonemes))
        all_true_phonemes = all_true_phonemes[:min_len]
        all_pred_phonemes = all_pred_phonemes[:min_len]

    if len(all_true_phonemes) == 0:
        # Handle empty case
        return {
            'phoneme_metrics': {
                'phonemes': [],
                'precision': np.array([]),
                'recall': np.array([]),
                'f1': np.array([]),
                'support': np.array([])
            },
            'macro_avg': {
                'precision': 0.0,
                'recall': 0.0,
                'f1': 0.0
            },
            'micro_avg': {
                'precision': 0.0,
                'recall': 0.0,
                'f1': 0.0
            }
        }

    # Get unique phonemes
    unique_phonemes = sorted(list(set(all_true_phonemes + all_pred_phonemes)))

    # Calculate metrics for each phoneme class
    precision, recall, f1, support = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes,
        labels=unique_phonemes, average=None, zero_division=0
    )

    # Calculate macro averages
    macro_precision = np.mean(precision)
    macro_recall = np.mean(recall)
    macro_f1 = np.mean(f1)

    # Calculate micro averages
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        all_true_phonemes, all_pred_phonemes, average='micro', zero_division=0
    )

    return {
        'phoneme_metrics': {
            'phonemes': unique_phonemes,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': support
        },
        'macro_avg': {
            'precision': macro_precision,
            'recall': macro_recall,
            'f1': macro_f1
        },
        'micro_avg': {
            'precision': micro_precision,
            'recall': micro_recall,
            'f1': micro_f1
        }
    }

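# Note on the truncation alignment above (hypothetical example): comparing
# ['AA', 'B', 'K'] against ['AA', 'K'] only pairs the first two positions,
# ('AA', 'AA') and ('B', 'K'), so length mismatches do not contribute here;
# they are captured separately by the edit-distance / PER metrics.
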
def calculate_sequence_level_metrics(true_seqs, pred_seqs):
    """Calculate sequence-level metrics."""
    exact_matches = 0
    total_sequences = len(true_seqs)

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        if true_seq == pred_seq:
            exact_matches += 1

    exact_match_accuracy = exact_matches / total_sequences if total_sequences > 0 else 0

    return {
        'exact_match_accuracy': exact_match_accuracy,
        'exact_matches': exact_matches,
        'total_sequences': total_sequences
    }

def calculate_confidence_intervals(errors, confidence=0.95):
    """Calculate confidence intervals for error rates."""
    n = len(errors)
    if n == 0:
        return 0, 0, 0

    mean_error = np.mean(errors)
    std_error = np.std(errors, ddof=1) if n > 1 else 0

    # Use t-distribution for small samples, normal for large samples
    if n < 30:
        from scipy import stats
        t_val = stats.t.ppf((1 + confidence) / 2, n - 1)
        margin = t_val * std_error / np.sqrt(n)
    else:
        # For large samples, use normal distribution
        z_val = 1.96 if confidence == 0.95 else 2.576  # 95% or 99%
        margin = z_val * std_error / np.sqrt(n)

    return mean_error, mean_error - margin, mean_error + margin

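# Worked example with hypothetical per-trial error rates: errors = [10.0, 20.0, 30.0]
# gives mean 20.0 and sample std 10.0; with n = 3 the critical value is
# t(0.975, df=2) ≈ 4.303, so the margin is 4.303 * 10 / sqrt(3) ≈ 24.8 and the
# 95% CI is roughly [-4.8, 44.8] (wide, as expected for so few trials).
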
def calculate_phoneme_frequency_analysis(true_seqs, pred_seqs):
    """Analyze phoneme frequency and error patterns."""
    true_phoneme_counts = Counter()
    pred_phoneme_counts = Counter()
    error_patterns = defaultdict(int)  # Maps (true_phoneme, pred_phoneme) -> count

    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        true_phoneme_counts.update(true_seq)
        pred_phoneme_counts.update(pred_seq)

        # For error pattern analysis, align sequences by taking minimum length
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            for i in range(min_len):
                true_p = true_seq[i]
                pred_p = pred_seq[i]
                if true_p != pred_p:
                    error_patterns[(true_p, pred_p)] += 1

    return {
        'true_phoneme_counts': true_phoneme_counts,
        'pred_phoneme_counts': pred_phoneme_counts,
        'error_patterns': dict(error_patterns)
    }

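# Example of the positional error patterns (hypothetical phonemes): comparing
# ['AA', 'B'] against ['AA', 'K'] yields error_patterns == {('B', 'K'): 1}, i.e.
# 'B' was mistaken for 'K' once. Because the comparison is positional, a single
# inserted or deleted phoneme early in a trial can shift everything after it.
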
def generate_confusion_matrix_plot(true_seqs, pred_seqs, save_path=None):
    """Generate and save confusion matrix plot for phonemes."""
    all_true = []
    all_pred = []

    # Align sequences properly to handle different lengths
    for true_seq, pred_seq in zip(true_seqs, pred_seqs):
        min_len = min(len(true_seq), len(pred_seq))
        if min_len > 0:
            # Truncate both to same length for fair comparison
            all_true.extend(true_seq[:min_len])
            all_pred.extend(pred_seq[:min_len])

    if len(all_true) == 0 or len(all_pred) == 0:
        print("Warning: No aligned phonemes found for confusion matrix")
        return None, []

    unique_phonemes = sorted(list(set(all_true + all_pred)))
    cm = confusion_matrix(all_true, all_pred, labels=unique_phonemes)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=unique_phonemes, yticklabels=unique_phonemes)
    plt.title('Phoneme Confusion Matrix')
    plt.xlabel('Predicted Phonemes')
    plt.ylabel('True Phonemes')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")

    plt.close()

    return cm, unique_phonemes

# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../data/t15_pretrained_lstm_baseline',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=1,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=4,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to cut from beginning in TTA.')
# TTA sample weight configuration
parser.add_argument('--tta_weights', type=str, default='2.0,1.0,0.0,0.0,0.0',
                    help='Comma-separated weights for each TTA augmentation type: original,noise,scale,shift,smooth. Set to 0 to disable.')

# Evaluation metrics configuration
parser.add_argument('--skip_confusion_matrix', action='store_true',
                    help='Skip confusion matrix generation to save time.')
parser.add_argument('--detailed_metrics', action=argparse.BooleanOptionalAction, default=True,
                    help='Calculate detailed phoneme-level and sequence-level metrics.')
parser.add_argument('--save_plots', action=argparse.BooleanOptionalAction, default=True,
                    help='Save confusion matrix and other plots.')

# Memory management options
parser.add_argument('--max_trials_per_session', type=int, default=1000,
                    help='Maximum number of trials to load per session (for memory management).')
parser.add_argument('--batch_inference', action='store_true',
                    help='Process trials in batches to reduce memory usage during inference.')
parser.add_argument('--batch_size', type=int, default=10,
                    help='Batch size for inference when using --batch_inference.')

args = parser.parse_args()

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

# Parse TTA weights
tta_weights_str = args.tta_weights.split(',')
if len(tta_weights_str) != 5:
    raise ValueError("TTA weights must have exactly 5 values: original,noise,scale,shift,smooth")
tta_weights = {
    'original': float(tta_weights_str[0]),
    'noise': float(tta_weights_str[1]),
    'scale': float(tta_weights_str[2]),
    'shift': float(tta_weights_str[3]),
    'smooth': float(tta_weights_str[4])
}
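
# For example, the default '--tta_weights 2.0,1.0,0.0,0.0,0.0' enables only the
# original and noise augmentations; their weights are later normalized to
# 2/3 and 1/3 when the per-augmentation probabilities are fused.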

# Calculate effective number of TTA samples based on enabled augmentations
enabled_augmentations = [k for k, v in tta_weights.items() if v > 0]
effective_tta_samples = len(enabled_augmentations)

print("Improved TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f}")
print(f"LSTM weight: {lstm_weight:.2f}")
print("TTA augmentation weights:")
for aug_type, weight in tta_weights.items():
    status = "✓ enabled" if weight > 0 else "✗ disabled"
    print(f"  - {aug_type}: {weight:.1f} ({status})")
print(f"Effective TTA samples: {effective_tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max cut: {tta_cut_max} timesteps")
print("Key improvements:")
print("  - Customizable TTA sample weights")
print("  - Probability-based ensemble instead of logit averaging")
print("  - Geometric mean for more robust fusion")
print("  - Direct PER evaluation without language model dependency")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()

# Define evaluation type
eval_type = args.eval_type

# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

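# Each args.yaml is expected to expose at least the fields accessed below; a rough
# sketch of the layout (the values shown are placeholders, not the actual configs):
#
#   model:
#     n_input_features: 512
#     n_units: 768
#     rnn_dropout: 0.4
#     n_layers: 5
#     patch_size: 14
#     patch_stride: 4
#     input_network:
#       input_layer_dropout: 0.2
#   dataset:
#     sessions: [...]
#     n_classes: 41
#     data_transforms:
#       smooth_kernel_std: 2
#       smooth_kernel_size: 100
#   use_amp: true
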
# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = f'cuda:{gpu_number}'
    device = torch.device(device)
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state dict keys
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state dict keys
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Move models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print("Both models loaded successfully!")
print()

# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_weights, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run Customizable TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply selected data augmentations based on weights
    2. Run both GRU and LSTM models on each augmented version
    3. Use probability-based ensemble with customizable sample weights
    4. Average predictions in probability space for better stability
    """
    all_gru_probs = []
    all_lstm_probs = []
    sample_weights = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    # Generate augmented samples based on enabled augmentations
    augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    for aug_type in augmentation_types:
        if tta_weights[aug_type] <= 0:
            continue  # Skip disabled augmentations

        x_augmented = x.clone()

        if aug_type == 'original':
            # Original data (baseline)
            pass
        elif aug_type == 'noise':
            # Add Gaussian noise with varying intensity
            noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
            noise = torch.randn_like(x_augmented) * noise_scale
            x_augmented = x_augmented + noise
        elif aug_type == 'scale':
            # Amplitude scaling with more variation
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and tta_cut_max > 0:
            # Time shift (circular shift)
            shift_amount = np.random.randint(1, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
        elif aug_type == 'smooth':
            # Apply smoothing variation after augmentation
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)

        # Use autocast for efficiency - auto-detect device type
        device_type = "cuda" if device.type == "cuda" else "cpu"
        use_amp = gru_model_args.get('use_amp', False) and device.type == "cuda"
        with torch.autocast(device_type=device_type, enabled=use_amp, dtype=torch.bfloat16 if device.type == "cuda" else torch.float32):

            # Apply Gaussian smoothing
            if aug_type == 'smooth':
                # Use varied smoothing
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=varied_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )
            else:
                # Use default smoothing
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=default_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )

            with torch.no_grad():
                # Get GRU logits and convert to probabilities
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                gru_probs = torch.softmax(gru_logits, dim=-1)

                # Get LSTM logits and convert to probabilities
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                lstm_probs = torch.softmax(lstm_logits, dim=-1)

                all_gru_probs.append(gru_probs)
                all_lstm_probs.append(lstm_probs)
                sample_weights.append(tta_weights[aug_type])

    if len(all_gru_probs) == 0:
        raise ValueError("No TTA augmentations are enabled. Please set at least one weight > 0.")

    # TTA fusion: Handle potentially different tensor shapes
    if len(all_gru_probs) > 1:
        # Find the minimum sequence length among all TTA samples
        min_length = min([probs.shape[1] for probs in all_gru_probs + all_lstm_probs])

        # Truncate all tensors to the minimum length
        truncated_gru_probs = []
        truncated_lstm_probs = []
        for gru_probs, lstm_probs in zip(all_gru_probs, all_lstm_probs):
            if gru_probs.shape[1] > min_length:
                truncated_gru_probs.append(gru_probs[:, :min_length, :])
            else:
                truncated_gru_probs.append(gru_probs)

            if lstm_probs.shape[1] > min_length:
                truncated_lstm_probs.append(lstm_probs[:, :min_length, :])
            else:
                truncated_lstm_probs.append(lstm_probs)

        # Weighted average probabilities across TTA samples
        sample_weights_tensor = torch.tensor(sample_weights, dtype=torch.float32, device=device)
        sample_weights_tensor = sample_weights_tensor / sample_weights_tensor.sum()  # Normalize weights

        weighted_gru_probs = torch.zeros_like(truncated_gru_probs[0])
        weighted_lstm_probs = torch.zeros_like(truncated_lstm_probs[0])

        for i, (gru_probs, lstm_probs, weight) in enumerate(zip(truncated_gru_probs, truncated_lstm_probs, sample_weights_tensor)):
            weighted_gru_probs += weight * gru_probs
            weighted_lstm_probs += weight * lstm_probs

        avg_gru_probs = weighted_gru_probs
        avg_lstm_probs = weighted_lstm_probs
    else:
        avg_gru_probs = all_gru_probs[0]
        avg_lstm_probs = all_lstm_probs[0]

    # Improved ensemble: geometric mean of probabilities (more robust than weighted average)
    # Apply epsilon smoothing to avoid log(0)
    epsilon = 1e-8
    avg_gru_probs = avg_gru_probs + epsilon
    avg_lstm_probs = avg_lstm_probs + epsilon

    # Weighted geometric mean in log space
    log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs) +
                          lstm_weight * torch.log(avg_lstm_probs))
    ensemble_probs = torch.exp(log_ensemble_probs)

    # Normalize to ensure proper probability distribution
    ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)
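
    # In effect this computes p_ensemble ∝ p_GRU**gru_weight * p_LSTM**lstm_weight per
    # class and timestep, then renormalizes. With equal weights of 0.5 and hypothetical
    # class probabilities 0.9 (GRU) and 0.4 (LSTM), the fused value before
    # normalization is sqrt(0.9 * 0.4) ≈ 0.6, so a class must score reasonably well
    # under both models to remain probable.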

    # Convert back to logits for compatibility with downstream code
    final_logits = torch.log(ensemble_probs + epsilon)

    # Convert from bfloat16 to float32
    return final_logits.float().cpu().numpy()

# Load data for each session with memory management
test_data = {}
total_test_trials = 0

print("Loading data with memory management...")
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        try:
            # Try to load data normally first
            print(f'Loading {session}...')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            test_data[session] = data
            total_test_trials += len(test_data[session]["neural_features"])
            print(f'Successfully loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
        except MemoryError as e:
            # numpy's _ArrayMemoryError is a MemoryError subclass, so this also catches
            # allocation failures raised while reading large arrays
            print(f'Memory error loading {session}: {e}')
            print(f'Attempting to load {session} in smaller chunks...')

            # If memory error, try loading in chunks or skip this session
            import h5py
            try:
                with h5py.File(eval_file, 'r') as h5file:
                    # Get the size first
                    input_features_shape = h5file['input_features'].shape
                    print(f'Session {session} has {input_features_shape[0]} trials with {input_features_shape[1]} features')

                    # If it's too large, skip or implement chunk loading
                    max_trials_per_session = args.max_trials_per_session  # Limit trials per session
                    if input_features_shape[0] > max_trials_per_session:
                        print(f'Session {session} has {input_features_shape[0]} trials, limiting to {max_trials_per_session} for memory management')
                        # Could implement chunked loading here if needed
                        print(f'Skipping session {session} due to memory constraints')
                        continue
                    else:
                        # Try one more time with garbage collection
                        import gc
                        gc.collect()
                        data = load_h5py_file(eval_file, b2txt_csv_df)
                        test_data[session] = data
                        total_test_trials += len(test_data[session]["neural_features"])
                        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session} after GC.')
            except Exception as e2:
                print(f'Failed to load session {session}: {e2}')
                print(f'Skipping session {session}')
                continue

print(f'Total number of {eval_type} trials loaded: {total_test_trials}')
if total_test_trials == 0:
    print("ERROR: No trials loaded! Check data paths and memory availability.")
    sys.exit(1)
print()

# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
import gc
with tqdm(total=total_test_trials, desc=f'Customizable TTA-E inference ({effective_tta_samples} augmentation types)', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        data['pred_seq'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)

        # Process trials in batches if requested
        if args.batch_inference:
            batch_size = args.batch_size
            num_trials = len(data['neural_features'])

            for batch_start in range(0, num_trials, batch_size):
                batch_end = min(batch_start + batch_size, num_trials)
                batch_logits = []

                for trial in range(batch_start, batch_end):
                    # Get neural input for the trial
                    neural_input = data['neural_features'][trial]

                    # Add batch dimension
                    neural_input = np.expand_dims(neural_input, axis=0)

                    # Convert to torch tensor - use float32 to avoid dtype issues with smoothing
                    neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                    # Run TTA-E decoding step with customizable weights
                    ensemble_logits = runTTAEnsembleDecodingStep(
                        neural_input, input_layer, gru_model, lstm_model,
                        gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                        tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                    )
                    batch_logits.append(ensemble_logits)

                    # Clear GPU memory
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

                    pbar.update(1)

                # Add batch results
                data['logits'].extend(batch_logits)

                # Clear batch memory
                del batch_logits
                gc.collect()
        else:
            # Process all trials individually (original method)
            for trial in range(len(data['neural_features'])):
                # Get neural input for the trial
                neural_input = data['neural_features'][trial]

                # Add batch dimension
                neural_input = np.expand_dims(neural_input, axis=0)

                # Convert to torch tensor - use float32 to avoid dtype issues with smoothing
                neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)

                # Run TTA-E decoding step with customizable weights
                ensemble_logits = runTTAEnsembleDecodingStep(
                    neural_input, input_layer, gru_model, lstm_model,
                    gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                    tta_weights, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
                )
                data['logits'].append(ensemble_logits)

                # Clear memory periodically
                if trial % 50 == 0:
                    del neural_input
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()
                    gc.collect()

                pbar.update(1)

        # Clear session data memory after processing
        gc.collect()

pbar.close()

# Convert logits to phoneme sequences and print them out
results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_phonemes': [],
}

for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
        # Remove blanks (0)
        pred_seq = [int(p) for p in pred_seq if p != 0]
        # Remove consecutive duplicates
        pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
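        # Greedy collapse example with hypothetical frame argmaxes:
        # [0, 5, 5, 0, 12, 12] -> drop blanks -> [5, 5, 12, 12] -> collapse repeats
        # -> [5, 12] -> two phonemes via LOGIT_TO_PHONEME. Note that because blanks
        # are dropped before collapsing, repeats separated only by a blank merge into
        # one phoneme, which differs from the canonical CTC collapse-then-drop order.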
        # Add to data
        data['pred_seq'].append(pred_seq)

        # Store results
        results['session'].append(session)
        results['block'].append(data['block_num'][trial])
        results['trial'].append(data['trial_num'][trial])
        if eval_type == 'val':
            results['true_sentence'].append(data['sentence_label'][trial])
        else:
            results['true_sentence'].append(None)
        results['pred_phonemes'].append(pred_seq)

        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
        print(f'Predicted Sequence:  {" ".join(pred_seq)}')
        print()

# If using the validation set, calculate comprehensive evaluation metrics
if eval_type == 'val':
    print("=" * 80)
    print("COMPREHENSIVE EVALUATION METRICS")
    print("=" * 80)

    # Record performance timing
    eval_start_time = time.time()

    # Initialize metrics storage
    total_true_length = 0
    total_edit_distance = 0
    all_true_seqs = []
    all_pred_seqs = []
    per_trial_errors = []
    per_session_metrics = defaultdict(list)

    # Additional detailed metrics
    results['edit_distance'] = []
    results['num_phonemes'] = []
    results['insertions'] = []
    results['deletions'] = []
    results['substitutions'] = []
    results['per_trial_per'] = []

    print(f"\nCalculating detailed metrics for {len(results['pred_phonemes'])} trials...")

    for i in range(len(results['pred_phonemes'])):
        # Get true phoneme sequence
        session = results['session'][i]
        trial_idx = None
        for trial in range(len(test_data[session]['seq_class_ids'])):
            if (test_data[session]['block_num'][trial] == results['block'][i] and
                test_data[session]['trial_num'][trial] == results['trial'][i]):
                trial_idx = trial
                break

        if trial_idx is not None:
            true_seq = test_data[session]['seq_class_ids'][trial_idx][0:test_data[session]['seq_len'][trial_idx]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            pred_seq = results['pred_phonemes'][i]

            all_true_seqs.append(true_seq)
            all_pred_seqs.append(pred_seq)

            # Calculate basic edit distance
            ed = editdistance.eval(true_seq, pred_seq)
            trial_per = 100 * ed / len(true_seq) if len(true_seq) > 0 else 0

            # Calculate detailed phoneme operations
            insertions, deletions, substitutions = calculate_phoneme_operations(true_seq, pred_seq)

            # Store metrics
            total_true_length += len(true_seq)
            total_edit_distance += ed
            per_trial_errors.append(trial_per)
            per_session_metrics[session].append(trial_per)

            results['edit_distance'].append(ed)
            results['num_phonemes'].append(len(true_seq))
            results['insertions'].append(insertions)
            results['deletions'].append(deletions)
            results['substitutions'].append(substitutions)
            results['per_trial_per'].append(trial_per)

            # Print detailed trial information
            print(f'{results["session"][i]} - Block {results["block"][i]}, Trial {results["trial"][i]}')
            print(f'True phonemes ({len(true_seq)}):       {" ".join(true_seq)}')
            print(f'Predicted phonemes ({len(pred_seq)}):  {" ".join(pred_seq)}')
            print(f'Operations: I={insertions}, D={deletions}, S={substitutions}, Total ED={ed}')
            print(f'Trial PER: {ed} / {len(true_seq)} = {trial_per:.2f}%')
            print()

    # ========== BASIC METRICS ==========
    aggregate_per = 100 * total_edit_distance / total_true_length if total_true_length > 0 else 0
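    # Aggregate PER pools edit distance over all trials before dividing, e.g.
    # (hypothetically) 100 total errors over 1000 reference phonemes -> 10.00% PER,
    # which weights long sentences more than a simple mean of per-trial PERs would.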

    print("=" * 50)
    print("BASIC PHONEME ERROR RATE METRICS")
    print("=" * 50)
    print(f'Total true phoneme length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Phoneme Error Rate (PER): {aggregate_per:.2f}%')

    # Calculate PER confidence intervals
    per_mean, per_ci_lower, per_ci_upper = calculate_confidence_intervals(per_trial_errors)
    print(f'Per-trial PER: {per_mean:.2f}% (95% CI: [{per_ci_lower:.2f}%, {per_ci_upper:.2f}%])')

    # ========== PHONEME OPERATION ANALYSIS ==========
    total_insertions = sum(results['insertions'])
    total_deletions = sum(results['deletions'])
    total_substitutions = sum(results['substitutions'])

    print("\n" + "=" * 50)
    print("PHONEME OPERATION BREAKDOWN")
    print("=" * 50)
    # Guard against division by zero when there are no errors at all
    error_denominator = total_edit_distance if total_edit_distance > 0 else 1
    print(f'Total insertions: {total_insertions} ({100*total_insertions/error_denominator:.1f}% of errors)')
    print(f'Total deletions: {total_deletions} ({100*total_deletions/error_denominator:.1f}% of errors)')
    print(f'Total substitutions: {total_substitutions} ({100*total_substitutions/error_denominator:.1f}% of errors)')
    print(f'Verification: I+D+S = {total_insertions + total_deletions + total_substitutions} (should equal {total_edit_distance})')

    # ========== PHONEME-LEVEL CLASSIFICATION METRICS ==========
    print("\n" + "=" * 50)
    print("PHONEME-LEVEL CLASSIFICATION METRICS")
    print("=" * 50)

    phoneme_metrics = calculate_phoneme_level_metrics(all_true_seqs, all_pred_seqs)

    print(f"Macro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['macro_avg']['precision']:.3f}")
    print(f"  Recall: {phoneme_metrics['macro_avg']['recall']:.3f}")
    print(f"  F1-Score: {phoneme_metrics['macro_avg']['f1']:.3f}")

    print(f"\nMicro-averaged metrics:")
    print(f"  Precision: {phoneme_metrics['micro_avg']['precision']:.3f}")
    print(f"  Recall: {phoneme_metrics['micro_avg']['recall']:.3f}")
    print(f"  F1-Score: {phoneme_metrics['micro_avg']['f1']:.3f}")

    # ========== SEQUENCE-LEVEL METRICS ==========
    print("\n" + "=" * 50)
    print("SEQUENCE-LEVEL METRICS")
    print("=" * 50)

    seq_metrics = calculate_sequence_level_metrics(all_true_seqs, all_pred_seqs)
    print(f"Exact Match Accuracy: {seq_metrics['exact_match_accuracy']:.3f} ({seq_metrics['exact_matches']}/{seq_metrics['total_sequences']})")

    # ========== PER-SESSION BREAKDOWN ==========
    print("\n" + "=" * 50)
    print("PER-SESSION PERFORMANCE BREAKDOWN")
    print("=" * 50)

    for session in sorted(per_session_metrics.keys()):
        session_errors = per_session_metrics[session]
        session_mean, session_ci_lower, session_ci_upper = calculate_confidence_intervals(session_errors)
        print(f"{session}: {session_mean:.2f}% ± {session_mean - session_ci_lower:.2f}% "
              f"(n={len(session_errors)}, range: {min(session_errors):.1f}%-{max(session_errors):.1f}%)")

    # ========== STATISTICAL SUMMARY ==========
    print("\n" + "=" * 50)
    print("STATISTICAL SUMMARY")
    print("=" * 50)
    print(f"Mean trial PER: {np.mean(per_trial_errors):.2f}%")
    print(f"Median trial PER: {np.median(per_trial_errors):.2f}%")
    print(f"Std deviation: {np.std(per_trial_errors):.2f}%")
    print(f"Min/Max trial PER: {min(per_trial_errors):.1f}% / {max(per_trial_errors):.1f}%")
    print(f"Trials with 0% PER: {sum(1 for x in per_trial_errors if x == 0)} ({100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors):.1f}%)")

    # ========== PHONEME FREQUENCY ANALYSIS ==========
    if args.detailed_metrics:
        print("\n" + "=" * 50)
        print("PHONEME FREQUENCY AND ERROR PATTERN ANALYSIS")
        print("=" * 50)

        freq_analysis = calculate_phoneme_frequency_analysis(all_true_seqs, all_pred_seqs)

        print("Top 10 most frequent true phonemes:")
        for phoneme, count in freq_analysis['true_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['true_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent predicted phonemes:")
        for phoneme, count in freq_analysis['pred_phoneme_counts'].most_common(10):
            print(f"  {phoneme}: {count} ({100*count/sum(freq_analysis['pred_phoneme_counts'].values()):.1f}%)")

        print("\nTop 10 most frequent error patterns (true -> predicted):")
        sorted_errors = sorted(freq_analysis['error_patterns'].items(), key=lambda x: x[1], reverse=True)
        for (true_p, pred_p), count in sorted_errors[:10]:
            print(f"  {true_p} -> {pred_p}: {count} errors")

    # ========== GENERATE CONFUSION MATRIX ==========
    if not args.skip_confusion_matrix and args.save_plots:
        timestamp_for_plots = time.strftime("%Y%m%d_%H%M%S")
        confusion_matrix_path = f'confusion_matrix_TTA-E_{eval_type}_{timestamp_for_plots}.png'
        confusion_matrix_path = os.path.join(os.path.dirname(__file__), confusion_matrix_path)

        print(f"\nGenerating confusion matrix...")
        cm_result = generate_confusion_matrix_plot(all_true_seqs, all_pred_seqs, confusion_matrix_path)
        if cm_result[0] is None:
            confusion_matrix_path = None
    else:
        print(f"\nSkipping confusion matrix generation (--skip_confusion_matrix or --no-save_plots specified)")
        confusion_matrix_path = None

    # ========== TIMING METRICS ==========
    eval_end_time = time.time()
    eval_duration = eval_end_time - eval_start_time

    print("\n" + "=" * 50)
    print("TIMING METRICS")
    print("=" * 50)
    print(f"Evaluation time: {eval_duration:.2f} seconds")
    print(f"Time per trial: {eval_duration/len(results['pred_phonemes']):.3f} seconds")

    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY COMPLETE")
    print("=" * 80)

# Write comprehensive results to CSV files
timestamp = time.strftime("%Y%m%d_%H%M%S")
enabled_augs = '_'.join([k for k, v in tta_weights.items() if v > 0])

# Basic phoneme predictions CSV (for compatibility)
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{enabled_augs}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

ids = [i for i in range(len(results['pred_phonemes']))]
phoneme_strings = [" ".join(phonemes) for phonemes in results['pred_phonemes']]
df_out = pd.DataFrame({'id': ids, 'phonemes': phoneme_strings})
df_out.to_csv(output_path, index=False)

# Comprehensive metrics CSV (if validation set)
if eval_type == 'val':
    detailed_output_file = f'TTA-E_detailed_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    detailed_output_path = os.path.join(os.path.dirname(__file__), detailed_output_file)

    # Create detailed DataFrame with all metrics
    detailed_df = pd.DataFrame({
        'id': ids,
        'session': results['session'],
        'block': results['block'],
        'trial': results['trial'],
        'true_sentence': results['true_sentence'],
        'pred_phonemes': phoneme_strings,
        'num_phonemes': results['num_phonemes'],
        'edit_distance': results['edit_distance'],
        'insertions': results['insertions'],
        'deletions': results['deletions'],
        'substitutions': results['substitutions'],
        'trial_per': results['per_trial_per']
    })
    detailed_df.to_csv(detailed_output_path, index=False)

    # Summary metrics CSV
    summary_output_file = f'TTA-E_summary_metrics_{enabled_augs}_{eval_type}_{timestamp}.csv'
    summary_output_path = os.path.join(os.path.dirname(__file__), summary_output_file)

    summary_data = {
        'metric': [
            'aggregate_per', 'mean_trial_per', 'median_trial_per', 'std_trial_per',
            'min_trial_per', 'max_trial_per', 'zero_per_trials_count', 'zero_per_trials_percent',
            'total_phonemes', 'total_edit_distance', 'total_insertions', 'total_deletions',
            'total_substitutions', 'exact_match_accuracy', 'exact_matches',
            'macro_precision', 'macro_recall', 'macro_f1',
            'micro_precision', 'micro_recall', 'micro_f1',
            'per_mean_95ci_lower', 'per_mean_95ci_upper', 'evaluation_time_seconds'
        ],
        'value': [
            aggregate_per, np.mean(per_trial_errors), np.median(per_trial_errors), np.std(per_trial_errors),
            min(per_trial_errors), max(per_trial_errors),
            sum(1 for x in per_trial_errors if x == 0),
            100*sum(1 for x in per_trial_errors if x == 0)/len(per_trial_errors),
            total_true_length, total_edit_distance, total_insertions, total_deletions,
            total_substitutions, seq_metrics['exact_match_accuracy'], seq_metrics['exact_matches'],
            phoneme_metrics['macro_avg']['precision'], phoneme_metrics['macro_avg']['recall'],
            phoneme_metrics['macro_avg']['f1'], phoneme_metrics['micro_avg']['precision'],
            phoneme_metrics['micro_avg']['recall'], phoneme_metrics['micro_avg']['f1'],
            per_ci_lower, per_ci_upper, eval_duration
        ]
    }

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(summary_output_path, index=False)

    print(f'\nComprehensive results saved to:')
    print(f'  Basic predictions: {output_path}')
    print(f'  Detailed metrics: {detailed_output_path}')
    print(f'  Summary metrics: {summary_output_path}')
    if confusion_matrix_path:
        print(f'  Confusion matrix: {confusion_matrix_path}')
    else:
        print(f'  Confusion matrix: Not generated')
else:
    print(f'\nResults saved to: {output_path}')

print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
print(f'Enabled augmentations: {", ".join(enabled_augs.split("_"))}')