335 lines
14 KiB
Python
335 lines
14 KiB
Python
![]() |
import os
|
|||
|
import sys
|
|||
|
import torch
|
|||
|
import numpy as np
|
|||
|
import pandas as pd
|
|||
|
from omegaconf import OmegaConf
|
|||
|
import time
|
|||
|
from tqdm import tqdm
|
|||
|
import editdistance
|
|||
|
import argparse
|
|||
|
|
|||
|
# Make the training code importable when this script is run directly.
# The imports below are package-qualified (model_training.*,
# model_training_lstm.*), so the directory that *contains* those packages must
# be on sys.path; appending only the package directories themselves does not
# make `model_training.<module>` resolvable. The package directories are still
# added too, presumably for modules inside them that import their siblings
# without the package prefix.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.join(_SCRIPT_DIR, '..')
for _extra_path in (os.path.join(_PARENT_DIR, 'model_training'),
                    os.path.join(_PARENT_DIR, 'model_training_lstm'),
                    _PARENT_DIR):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
# Star import: provides helpers used later in this script but not defined in it
# (gauss_smooth, load_h5py_file, LOGIT_TO_PHONEME).
from model_training.evaluate_model_helpers import *
|
|||
|
|
|||
|
# Command-line interface. Every option has a default, so the script also runs
# with no arguments. Options are declared as (flag, keyword-arguments) pairs
# and registered with the parser in a single loop.
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models (without TTA) on the copy task dataset.')

_cli_options = (
    ('--gru_model_path', dict(
        type=str, default='../data/t15_pretrained_rnn_baseline',
        help='Path to the pretrained GRU model directory.')),
    ('--lstm_model_path', dict(
        type=str, default='../model_training_lstm/trained_models/baseline_rnn',
        help='Path to the pretrained LSTM model directory.')),
    ('--data_dir', dict(
        type=str, default='../data/hdf5_data_final',
        help='Path to the dataset directory (relative to the current working directory).')),
    ('--eval_type', dict(
        type=str, default='val', choices=['val', 'test'],
        help='Evaluation type: "val" for validation set, "test" for test set.')),
    ('--csv_path', dict(
        type=str, default='../data/t15_copyTaskData_description.csv',
        help='Path to the CSV file with metadata about the dataset.')),
    ('--gpu_number', dict(
        type=int, default=0,
        help='GPU number to use for model inference. Set to -1 to use CPU.')),
    ('--gru_weight', dict(
        type=float, default=0.5,
        help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')),
)
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)
del _cli_options, _flag, _kwargs

args = parser.parse_args()
|
|||
|
|
|||
|
# Unpack the CLI options used throughout the script into module-level names.
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir
eval_type = args.eval_type

# The ensemble is a convex combination: the LSTM receives whatever weight the
# GRU does not. Values outside [0, 1] (and NaN, which fails the comparison)
# would give one model a negative weight, so reject them up front.
gru_weight = args.gru_weight
if not 0.0 <= gru_weight <= 1.0:
    parser.error(f'--gru_weight must be between 0 and 1, got {gru_weight}')
lstm_weight = 1.0 - gru_weight
|
|||
|
|
|||
|
# Dataset metadata table; it is passed to load_h5py_file for every session
# file that is read later in the script.
b2txt_csv_df = pd.read_csv(args.csv_path)

# Each model directory stores the training configuration it was built with in
# checkpoint/args.yaml; architecture and dataset settings are read from it.
gru_model_args, lstm_model_args = (
    OmegaConf.load(os.path.join(model_dir, 'checkpoint/args.yaml'))
    for model_dir in (gru_model_path, lstm_model_path)
)

# Echo the run configuration; the trailing '' produces the blank separator line.
print('\n'.join([
    f'GRU model path: {gru_model_path}',
    f'LSTM model path: {lstm_model_path}',
    f'Data directory: {data_dir}',
    f'Evaluation type: {eval_type}',
    f'GRU weight: {gru_weight:.2f}, LSTM weight: {lstm_weight:.2f}',
    '',
]))
|
|||
|
|
|||
|
# Pick the inference device: the requested CUDA GPU when CUDA is available and
# a non-negative GPU number was given, otherwise the CPU.
gpu_number = args.gpu_number
if not (torch.cuda.is_available() and gpu_number >= 0):
    device = torch.device('cpu')
    print('Using CPU device')
else:
    n_visible_gpus = torch.cuda.device_count()
    if gpu_number >= n_visible_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {n_visible_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using GPU device: {device}')
|
|||
|
|
|||
|
# Build the GRU decoder from the hyper-parameters saved alongside its checkpoint.
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only
# point --gru_model_path at trusted checkpoints.
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)

# Normalize parameter names before loading: "module." is inserted when a model
# is saved while wrapped in DataParallel, "_orig_mod." when saved while wrapped
# by torch.compile. Both are stripped in ONE rename per key; doing it as two
# separate pop() steps pops the original key twice and raises KeyError for any
# key that actually contains "module.".
_state_dict = gru_checkpoint['model_state_dict']
for _key in list(_state_dict):
    _state_dict[_key.replace("module.", "").replace("_orig_mod.", "")] = _state_dict.pop(_key)
gru_model.load_state_dict(_state_dict)
|
|||
|
|
|||
|
# Build the LSTM decoder from the hyper-parameters saved alongside its checkpoint.
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only
# point --lstm_model_path at trusted checkpoints.
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)

# Normalize parameter names before loading: "module." is inserted when a model
# is saved while wrapped in DataParallel, "_orig_mod." when saved while wrapped
# by torch.compile. Both are stripped in ONE rename per key; doing it as two
# separate pop() steps pops the original key twice and raises KeyError for any
# key that actually contains "module.".
_state_dict = lstm_checkpoint['model_state_dict']
for _key in list(_state_dict):
    _state_dict[_key.replace("module.", "").replace("_orig_mod.", "")] = _state_dict.pop(_key)
lstm_model.load_state_dict(_state_dict)
|
|||
|
|
|||
|
# Move both decoders to the inference device and switch them to evaluation
# mode (turns off training-only behaviour such as dropout).
for _net in (gru_model, lstm_model):
    _net.to(device)
    _net.eval()
del _net

for _label, _model_dir in (('GRU', gru_model_path), ('LSTM', lstm_model_path)):
    print(f'Loaded {_label} model from: {_model_dir}')
print()
del _label, _model_dir
|
|||
|
|
|||
|
def runEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                            device, gru_weight, lstm_weight):
    """
    Run ensemble inference (no test-time augmentation) for one trial.

    1. Gaussian-smooth the input with each model's training-time kernel
       (smoothing is done once when both configs use the same kernel).
    2. Run the GRU and the LSTM decoders.
    3. Rescale each model's logits, per time step, to unit standard deviation
       across classes, then combine them with the given weights.

    Args:
        x: Neural features with a leading batch axis of size 1 (the caller
            adds it with np.expand_dims); only batch element 0 is ensembled.
            Assumed to be (1, time, features) -- TODO confirm.
        input_layer: Day/session index passed to both models as ``day_idx``.
        gru_model, lstm_model: Decoders called as
            ``model(x=..., day_idx=..., states=None, return_state=True)``.
        gru_model_args, lstm_model_args: Training configs. The smoothing kernel
            is read from ``['dataset']['data_transforms']``; if the LSTM config
            does not define one, the GRU's kernel is used for both models.
        device: Torch device for the input and the day index.
        gru_weight, lstm_weight: Weights applied to the normalized logits.

    Returns:
        np.ndarray (float32) with a leading axis of size 1 followed by the
        per-trial logits shape; callers take element 0 and argmax over the
        last (class) axis.

    Raises:
        ValueError: If the two models' logits for the trial differ in shape.
    """
    gru_transforms = gru_model_args['dataset']['data_transforms']
    gru_kernel = (gru_transforms['smooth_kernel_std'], gru_transforms['smooth_kernel_size'])
    # Each model should see inputs smoothed as during its own training.
    lstm_transforms = lstm_model_args['dataset'].get('data_transforms') or {}
    lstm_kernel = (lstm_transforms.get('smooth_kernel_std', gru_kernel[0]),
                   lstm_transforms.get('smooth_kernel_size', gru_kernel[1]))

    # Autocast is intentionally not used: everything stays float32 so the
    # smoothing and the models' einsum ops see matching dtypes.
    x_float = x.float()
    day_idx = torch.tensor([input_layer], device=device)

    def _smooth(kernel_std, kernel_size):
        """Gaussian-smooth ``x_float`` with 'valid' padding; return float32."""
        return gauss_smooth(
            inputs=x_float,
            device=device,
            smooth_kernel_std=kernel_std,
            smooth_kernel_size=kernel_size,
            padding='valid',
        ).float()

    with torch.no_grad():
        gru_input = _smooth(*gru_kernel)
        lstm_input = gru_input if lstm_kernel == gru_kernel else _smooth(*lstm_kernel)

        gru_logits, _ = gru_model(x=gru_input, day_idx=day_idx, states=None, return_state=True)
        lstm_logits, _ = lstm_model(x=lstm_input, day_idx=day_idx, states=None, return_state=True)

    # Ensemble batch element 0 as float32 NumPy arrays.
    gru_np = gru_logits.float().cpu().numpy()[0]
    lstm_np = lstm_logits.float().cpu().numpy()[0]
    if gru_np.shape != lstm_np.shape:
        # Without this check NumPy could silently broadcast a length-1 output.
        raise ValueError(
            f'GRU and LSTM logits have different shapes {gru_np.shape} vs {lstm_np.shape}; '
            'check that both models use the same smoothing and patch settings.'
        )

    # Scale-normalized averaging: the two models' logits have noticeably
    # different spreads (variance measured at ~7.97 for the GRU vs ~5.73 for
    # the LSTM), so a plain weighted average would be biased toward the GRU.
    # Dividing each time step by its standard deviation across classes puts
    # both models on the same scale before weighting.
    gru_normalized = gru_np / np.sqrt(np.var(gru_np, axis=-1, keepdims=True) + 1e-8)
    lstm_normalized = lstm_np / np.sqrt(np.var(lstm_np, axis=-1, keepdims=True) + 1e-8)
    ensemble = gru_weight * gru_normalized + lstm_weight * lstm_normalized

    # Restore the leading batch axis expected by the caller.
    return ensemble[np.newaxis, ...].astype(np.float32, copy=False)
|
|||
|
|
|||
|
# Load the requested split for every session listed in the GRU training config.
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
    # Not every session has every split. Skip sessions without the file; a
    # missing session directory is skipped too, rather than crashing os.listdir.
    if not os.path.isfile(eval_file):
        continue

    data = load_h5py_file(eval_file, b2txt_csv_df)
    test_data[session] = data

    n_session_trials = len(data["neural_features"])
    total_test_trials += n_session_trials
    print(f'Loaded {n_session_trials} {eval_type} trials for session {session}.')

# Stop early on a misconfigured --data_dir instead of writing an empty results file.
if total_test_trials == 0:
    raise RuntimeError(f'No {eval_type} trials found under {data_dir} (expected <session>/data_{eval_type}.hdf5).')

print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
|
|||
|
|
|||
|
# Put neural data through the ensemble model to get phoneme predictions (logits).
with tqdm(total=total_test_trials, desc='Ensemble inference (GRU+LSTM)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []

        # runEnsembleDecodingStep feeds ONE day index to both models, so the
        # session must sit at the same position in both training configs;
        # otherwise the LSTM would silently use another day's input layer.
        input_layer = gru_model_args['dataset']['sessions'].index(session)
        lstm_sessions = list(lstm_model_args['dataset']['sessions'])
        if session not in lstm_sessions or lstm_sessions.index(session) != input_layer:
            raise ValueError(
                f'Session {session} has day index {input_layer} in the GRU config but not in '
                'the LSTM config; both models must be trained on the same session list/order.'
            )

        for trial_features in data['neural_features']:
            # Add a batch axis and convert to a float32 tensor on the inference
            # device (float32 avoids dtype mismatches inside the models).
            neural_input = torch.tensor(np.expand_dims(trial_features, axis=0),
                                        device=device, dtype=torch.float32)

            ensemble_logits = runEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
            )
            data['logits'].append(ensemble_logits)
            pbar.update(1)
|
|||
|
|
|||
|
def _ctc_greedy_decode(frame_logits):
    """Greedy (best-path) CTC decoding of one trial's logits.

    Args:
        frame_logits: Array whose last axis holds per-frame class scores.

    Returns:
        list[int]: Decoded class ids with the blank id 0 removed.

    Repeats are collapsed *before* blanks are dropped. Dropping blanks first
    would merge a phoneme the model deliberately emitted twice with a blank in
    between (``A _ A`` must decode to ``A A``, not ``A``).
    """
    best_path = np.argmax(frame_logits, axis=-1)
    if best_path.size == 0:
        return []
    # True on the first frame of every run of identical ids.
    run_starts = np.concatenate(([True], best_path[1:] != best_path[:-1]))
    return [int(p) for p in best_path[run_starts] if p != 0]


# Convert logits to phoneme sequences and calculate PER.
total_phonemes = 0
total_phoneme_errors = 0
per_results = []

for session, data in test_data.items():
    data['pred_seq'] = []
    for trial, trial_logits in enumerate(data['logits']):
        # trial_logits carries a leading batch axis of size 1.
        pred_phonemes = [LOGIT_TO_PHONEME[p] for p in _ctc_greedy_decode(trial_logits[0])]
        data['pred_seq'].append(pred_phonemes)

        # Print out the predicted sequences and calculate PER.
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')

        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

            # Phoneme error rate: edit distance over the reference length.
            phoneme_errors = editdistance.eval(true_phonemes, pred_phonemes)
            num_phonemes = len(true_phonemes)
            per = phoneme_errors / num_phonemes if num_phonemes > 0 else 0

            total_phonemes += num_phonemes
            total_phoneme_errors += phoneme_errors

            per_results.append({
                'session': session,
                'block': block_num,
                'trial': trial_num,
                'sentence_label': sentence_label,
                'true_phonemes': true_phonemes,
                'pred_phonemes': pred_phonemes,
                'phoneme_errors': phoneme_errors,
                'num_phonemes': num_phonemes,
                'per': per
            })

            print(f'Sentence label: {sentence_label}')
            print(f'True phonemes: {" ".join(true_phonemes)}')
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
            print(f'PER: {phoneme_errors} / {num_phonemes} = {100 * per:.2f}%')
        else:
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
        print()
|
|||
|
|
|||
|
# Corpus-level PER: total edit distance divided by total reference phonemes.
# Only the validation split has labels, and the ratio needs at least one phoneme.
if eval_type == 'val' and total_phonemes > 0:
    aggregate_per = total_phoneme_errors / total_phonemes
    print('\n'.join([
        f'Total phonemes: {total_phonemes}',
        f'Total phoneme errors: {total_phoneme_errors}',
        f'Aggregate Phoneme Error Rate (PER): {100 * aggregate_per:.2f}%',
        '',
    ]))
|
|||
|
|
|||
|
# Write results next to this script. The filename records the ensemble
# weights, the split and a timestamp, so repeated runs do not overwrite each other.
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'ensemble_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

if eval_type != 'val':
    # Test split: only the predicted phoneme strings, numbered 0..N-1 in
    # session order.
    pred_phonemes_str = [
        ' '.join(phonemes)
        for data in test_data.values()
        for phonemes in data['pred_seq']
    ]
    ids = list(range(len(pred_phonemes_str)))
    df_out = pd.DataFrame({'id': ids, 'phonemes': pred_phonemes_str})
    df_out.to_csv(output_path, index=False)
    print(f'Predictions saved to: {output_path}')
else:
    # Validation split: one row per trial with labels and error counts.
    df_out = pd.DataFrame(per_results)
    df_out.to_csv(output_path, index=False)
    print(f'Detailed results saved to: {output_path}')

print(f'Ensemble configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')
|