"""Evaluate a weighted GRU + LSTM ensemble (no test-time augmentation) on the copy-task data.

For every trial the neural features are Gaussian-smoothed and run through both
decoders. The two logit streams are scale-normalized per timestep and combined
with user-chosen weights, then greedily CTC-decoded into phonemes. On the
validation split the phoneme error rate (PER) is computed. Results are written
to a timestamped CSV next to this script.
"""
import argparse
import os
import sys
import time
from itertools import groupby

import editdistance
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

# Add parent directories to path to import models. The package directories are
# kept as in the original script; the parent directory is appended last so the
# package-qualified imports below (``model_training.rnn_model``) can resolve
# when this script lives in a sibling directory.
_SCRIPT_DIR = os.path.dirname(__file__)
sys.path.append(os.path.join(_SCRIPT_DIR, '..', 'model_training'))
sys.path.append(os.path.join(_SCRIPT_DIR, '..', 'model_training_lstm'))
sys.path.append(os.path.join(_SCRIPT_DIR, '..'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
# Star import presumably supplies gauss_smooth, load_h5py_file and
# LOGIT_TO_PHONEME, which are used below but not defined in this file.
from model_training.evaluate_model_helpers import *  # noqa: F401,F403


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of ``state_dict`` with ``module.`` / ``_orig_mod.`` removed from keys.

    ``module.`` appears when a model was saved while wrapped in DataParallel and
    ``_orig_mod.`` when it was saved from a torch.compile wrapper. The new key
    is computed once per entry, so keys carrying either (or both) prefixes are
    renamed without popping the same original key twice.
    """
    return {
        key.replace('module.', '').replace('_orig_mod.', ''): value
        for key, value in state_dict.items()
    }


def _load_decoder(decoder_cls, model_args, model_path, device):
    """Build ``decoder_cls`` from its training config and load its best checkpoint.

    Args:
        decoder_cls: ``GRUDecoder`` or ``LSTMDecoder``; both are constructed with
            the same keyword arguments here.
        model_args: Config loaded from ``<model_path>/checkpoint/args.yaml``.
        model_path: Model directory containing ``checkpoint/best_checkpoint``.
        device: torch device used for ``map_location`` and as the model's device.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = decoder_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(
        os.path.join(model_path, 'checkpoint/best_checkpoint'),
        weights_only=False,
        map_location=device,
    )
    model.load_state_dict(_strip_wrapper_prefixes(checkpoint['model_state_dict']))
    model.to(device)
    model.eval()
    return model


def _day_index(model_args, session, model_name):
    """Return the position of ``session`` in the model's training sessions (its ``day_idx``).

    Raises:
        ValueError: if the model was not trained on ``session``.
    """
    sessions = list(model_args['dataset']['sessions'])
    try:
        return sessions.index(session)
    except ValueError:
        raise ValueError(
            f'Session {session} is not among the {model_name} model training sessions; '
            f'cannot select its day-specific input layer.'
        ) from None


def _ctc_greedy_decode(logits):
    """Greedy CTC decode: per-timestep argmax, merge repeated labels, then drop blanks (0).

    Repeats are merged BEFORE blanks are removed, so a label repeated across a
    blank (e.g. ``A _ A``) correctly yields two labels.
    """
    best_path = np.argmax(logits, axis=-1)
    return [int(label) for label, _ in groupby(best_path) if label != 0]


def runEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                            device, gru_weight, lstm_weight, lstm_input_layer=None):
    """Run one ensemble (GRU + LSTM) forward pass without test-time augmentation.

    Steps:
      1. Gaussian-smooth the input using the kernel settings from the GRU config.
      2. Run both decoders on the smoothed input in float32 under ``no_grad``.
      3. Divide each model's logits by their per-timestep standard deviation
         across the last axis, then take the weighted sum.

    Args:
        x: Neural-feature tensor; assumed (1, time, features). The caller in this
            script adds a batch dimension of 1, and only batch element 0 is ensembled.
        input_layer: GRU ``day_idx`` (index of the session in the GRU's sessions).
        gru_model, lstm_model: Decoders on ``device`` in eval mode.
        gru_model_args: GRU config; supplies the smoothing kernel parameters.
        lstm_model_args: LSTM config (currently unused; kept for interface compatibility).
        device: torch device for smoothing and inference.
        gru_weight, lstm_weight: Weights applied to the normalized logits.
        lstm_input_layer: LSTM ``day_idx``. Defaults to ``input_layer``, which is
            only correct when both models share the same session list.

    Returns:
        float32 numpy array of ensembled logits with a leading batch dimension of 1
        (presumably (1, time, n_classes)).

    Raises:
        ValueError: if the two models emit logits of different shapes (e.g.
            different patch_size / patch_stride).
    """
    if lstm_input_layer is None:
        lstm_input_layer = input_layer

    # NOTE(review): smoothing uses the GRU config only; if the LSTM was trained
    # with a different kernel it sees slightly different inputs — confirm.
    transforms = gru_model_args['dataset']['data_transforms']

    with torch.no_grad():
        # Everything stays in float32: autocast/bfloat16 is disabled to avoid
        # dtype mismatches in the smoothing and model einsums.
        x_smoothed = gauss_smooth(
            inputs=x.float(),
            device=device,
            smooth_kernel_std=transforms['smooth_kernel_std'],
            smooth_kernel_size=transforms['smooth_kernel_size'],
            padding='valid',
        ).float()

        gru_logits, _ = gru_model(
            x=x_smoothed,
            day_idx=torch.tensor([input_layer], device=device),
            states=None,
            return_state=True,
        )
        lstm_logits, _ = lstm_model(
            x=x_smoothed,
            day_idx=torch.tensor([lstm_input_layer], device=device),
            states=None,
            return_state=True,
        )

    gru_logits_np = gru_logits.float().cpu().numpy()[0]
    lstm_logits_np = lstm_logits.float().cpu().numpy()[0]
    if gru_logits_np.shape != lstm_logits_np.shape:
        raise ValueError(
            f'GRU logits {gru_logits_np.shape} and LSTM logits {lstm_logits_np.shape} '
            f'have different shapes and cannot be ensembled.'
        )

    # Scale-normalized averaging. The two models' logits have different spreads
    # (the original author measured variance ~7.97 for the GRU vs ~5.73 for the
    # LSTM), so a plain weighted average would be biased toward the GRU. Dividing
    # each timestep by its std puts both models on a comparable scale first.
    gru_normalized = gru_logits_np / np.sqrt(np.var(gru_logits_np, axis=-1, keepdims=True) + 1e-8)
    lstm_normalized = lstm_logits_np / np.sqrt(np.var(lstm_logits_np, axis=-1, keepdims=True) + 1e-8)
    ensemble_logits_np = gru_weight * gru_normalized + lstm_weight * lstm_normalized

    # Re-add the batch dimension; the result is already a CPU numpy array.
    return ensemble_logits_np[np.newaxis].astype(np.float32)


# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models (without TTA) on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')
args = parser.parse_args()

# The LSTM weight is derived as 1 - gru_weight, so only [0, 1] gives a convex mix.
if not 0.0 <= args.gru_weight <= 1.0:
    parser.error(f'--gru_weight must be between 0 and 1, got {args.gru_weight}.')

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir
eval_type = args.eval_type
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model args
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

print(f'GRU model path: {gru_model_path}')
print(f'LSTM model path: {lstm_model_path}')
print(f'Data directory: {data_dir}')
print(f'Evaluation type: {eval_type}')
print(f'GRU weight: {gru_weight:.2f}, LSTM weight: {lstm_weight:.2f}')
print()

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using GPU device: {device}')
else:
    device = torch.device('cpu')
    print('Using CPU device')

# Build both decoders, load their weights, move them to the device, set eval mode
gru_model = _load_decoder(GRUDecoder, gru_model_args, gru_model_path, device)
lstm_model = _load_decoder(LSTMDecoder, lstm_model_args, lstm_model_path, device)

print(f'Loaded GRU model from: {gru_model_path}')
print(f'Loaded LSTM model from: {lstm_model_path}')
print()

# Load data for each session that has a file for the requested split
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data
        total_test_trials += len(data["neural_features"])
        print(f'Loaded {len(data["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Put neural data through the ensemble model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Ensemble inference (GRU+LSTM)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        # Each model indexes its day-specific input layer by its OWN session list.
        gru_input_layer = _day_index(gru_model_args, session, 'GRU')
        lstm_input_layer = _day_index(lstm_model_args, session, 'LSTM')

        for neural_features in data['neural_features']:
            # Add batch dimension and convert to a float32 tensor on the device
            neural_input = torch.tensor(
                np.expand_dims(neural_features, axis=0), device=device, dtype=torch.float32
            )
            ensemble_logits = runEnsembleDecodingStep(
                neural_input, gru_input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                lstm_input_layer=lstm_input_layer,
            )
            data['logits'].append(ensemble_logits)
            pbar.update(1)

# Convert logits to phoneme sequences and calculate PER
total_phonemes = 0
total_phoneme_errors = 0
per_results = []

for session, data in test_data.items():
    data['pred_seq'] = []
    for trial, trial_logits in enumerate(data['logits']):
        pred_phonemes = [LOGIT_TO_PHONEME[p] for p in _ctc_greedy_decode(trial_logits[0])]
        data['pred_seq'].append(pred_phonemes)

        # Print out the predicted sequences and calculate PER
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')

        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

            # Calculate phoneme error rate (PER)
            phoneme_errors = editdistance.eval(true_phonemes, pred_phonemes)
            num_phonemes = len(true_phonemes)
            per = phoneme_errors / num_phonemes if num_phonemes > 0 else 0

            total_phonemes += num_phonemes
            total_phoneme_errors += phoneme_errors

            per_results.append({
                'session': session,
                'block': block_num,
                'trial': trial_num,
                'sentence_label': sentence_label,
                'true_phonemes': true_phonemes,
                'pred_phonemes': pred_phonemes,
                'phoneme_errors': phoneme_errors,
                'num_phonemes': num_phonemes,
                'per': per,
            })

            print(f'Sentence label: {sentence_label}')
            print(f'True phonemes: {" ".join(true_phonemes)}')
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
            print(f'PER: {phoneme_errors} / {num_phonemes} = {100 * per:.2f}%')
        else:
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
        print()

# Calculate and print aggregate PER if using validation set
if eval_type == 'val' and total_phonemes > 0:
    aggregate_per = total_phoneme_errors / total_phonemes
    print(f'Total phonemes: {total_phonemes}')
    print(f'Total phoneme errors: {total_phoneme_errors}')
    print(f'Aggregate Phoneme Error Rate (PER): {100 * aggregate_per:.2f}%')
    print()

# Save results to CSV
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'ensemble_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

if eval_type == 'val':
    # Save detailed results for validation
    df_out = pd.DataFrame(per_results)
    df_out.to_csv(output_path, index=False)
    print(f'Detailed results saved to: {output_path}')
else:
    # Save only predictions for test set, with ids numbered consecutively
    # across sessions in iteration order
    pred_phonemes_str = [
        ' '.join(pred) for data in test_data.values() for pred in data['pred_seq']
    ]
    df_out = pd.DataFrame({'id': list(range(len(pred_phonemes_str))), 'phonemes': pred_phonemes_str})
    df_out.to_csv(output_path, index=False)
    print(f'Predictions saved to: {output_path}')

print(f'Ensemble configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')