Files
b2txt25/TTA-E/temp.py

335 lines
14 KiB
Python
Raw Normal View History

2025-10-06 15:17:44 +08:00
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
# Command-line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models (without TTA) on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')
args = parser.parse_args()
# The LSTM weight is derived as 1 - gru_weight, so a GRU weight outside [0, 1]
# would silently give the LSTM a negative weight (or one above 1). The
# chained comparison also rejects NaN.
if not 0.0 <= args.gru_weight <= 1.0:
    parser.error(f'--gru_weight must be in [0, 1], got {args.gru_weight}')
# Unpack the command-line arguments used throughout the script.
gru_model_path, lstm_model_path = args.gru_model_path, args.lstm_model_path
data_dir = args.data_dir
eval_type = args.eval_type
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight  # the two ensemble weights always sum to 1

# Per-trial metadata table for the dataset
b2txt_csv_df = pd.read_csv(args.csv_path)

# Training-time configuration stored alongside each model's checkpoint
gru_model_args, lstm_model_args = (
    OmegaConf.load(os.path.join(model_dir, 'checkpoint/args.yaml'))
    for model_dir in (gru_model_path, lstm_model_path)
)

for summary_line in (
    f'GRU model path: {gru_model_path}',
    f'LSTM model path: {lstm_model_path}',
    f'Data directory: {data_dir}',
    f'Evaluation type: {eval_type}',
    f'GRU weight: {gru_weight:.2f}, LSTM weight: {lstm_weight:.2f}',
):
    print(summary_line)
print()
# Choose the inference device: the requested GPU when CUDA is usable and a
# non-negative GPU number was given, otherwise the CPU.
gpu_number = args.gpu_number
use_gpu = torch.cuda.is_available() and gpu_number >= 0
if not use_gpu:
    device = torch.device('cpu')
    print('Using CPU device')
else:
    n_visible_gpus = torch.cuda.device_count()
    if gpu_number >= n_visible_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {n_visible_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using GPU device: {device}')
# Define the GRU decoder from the hyper-parameters stored with its checkpoint.
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)
# Load GRU model weights.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only load
# checkpoints that come from a trusted source.
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip wrapper prefixes from parameter names: "module." (added when a model is
# saved while wrapped in DataParallel) and "_orig_mod." (added by torch.compile).
# Both prefixes are removed in a single pop/insert per key. Popping the original
# key twice would raise KeyError whenever the first replacement had renamed it.
# The dict is mutated in place, so any attributes on it (e.g. _metadata) survive.
gru_state_dict = gru_checkpoint['model_state_dict']
for key in list(gru_state_dict.keys()):
    gru_state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = gru_state_dict.pop(key)
gru_model.load_state_dict(gru_state_dict)
# Define the LSTM decoder from the hyper-parameters stored with its checkpoint.
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)
# Load LSTM model weights.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only load
# checkpoints that come from a trusted source.
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip wrapper prefixes from parameter names: "module." (added when a model is
# saved while wrapped in DataParallel) and "_orig_mod." (added by torch.compile).
# Both prefixes are removed in a single pop/insert per key. Popping the original
# key twice would raise KeyError whenever the first replacement had renamed it.
# The dict is mutated in place, so any attributes on it (e.g. _metadata) survive.
lstm_state_dict = lstm_checkpoint['model_state_dict']
for key in list(lstm_state_dict.keys()):
    lstm_state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = lstm_state_dict.pop(key)
lstm_model.load_state_dict(lstm_state_dict)
# Place both decoders on the inference device and put them in eval mode
# (turns off dropout and other training-only behaviour).
for _net in (gru_model, lstm_model):
    _net.to(device)
    _net.eval()
for _label, _path in (('GRU', gru_model_path), ('LSTM', lstm_model_path)):
    print(f'Loaded {_label} model from: {_path}')
print()
def runEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                            device, gru_weight, lstm_weight):
    """
    Run ensemble inference for one trial, without test-time augmentation.

    1. Apply Gaussian smoothing to the input.
    2. Run both the GRU and the LSTM decoder on the smoothed input.
    3. Rescale each model's logits to unit standard deviation per time step
       (over the last, class axis), then combine them as a weighted sum.

    Args:
        x: neural input tensor with a leading batch dimension. Only batch
            element 0 of each model's output is used, so a batch size of 1
            is assumed.
        input_layer: index of the day/session-specific input layer; passed
            to both models as ``day_idx``.
        gru_model, lstm_model: decoders called as
            ``model(x=..., day_idx=..., states=None, return_state=True)``
            and returning ``(logits, states)``.
        gru_model_args: training config; its ``dataset.data_transforms``
            entry supplies the smoothing kernel std and size.
        lstm_model_args: accepted for interface symmetry; currently unused.
        device: torch device used for smoothing and for ``day_idx``.
        gru_weight, lstm_weight: weights applied to the normalized GRU and
            LSTM logits.

    Returns:
        float32 numpy array of ensembled logits with a leading batch axis of
        size 1.
    """
    transforms = gru_model_args['dataset']['data_transforms']
    # NOTE(review): the LSTM input is smoothed with the GRU's smoothing settings
    # (lstm_model_args is unused). This is only correct if both models were
    # trained with the same smooth_kernel_std / smooth_kernel_size. Confirm.

    with torch.no_grad():
        # Work in float32 throughout; autocast/bfloat16 is deliberately not used
        # to avoid dtype mismatches between the smoothing kernel and model weights.
        x_smoothed = gauss_smooth(
            inputs=x.float(),
            device=device,
            smooth_kernel_std=transforms['smooth_kernel_std'],
            smooth_kernel_size=transforms['smooth_kernel_size'],
            padding='valid',
        ).float()
        day_idx = torch.tensor([input_layer], device=device)
        gru_logits, _ = gru_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)
        lstm_logits, _ = lstm_model(x=x_smoothed, day_idx=day_idx, states=None, return_state=True)

    # Scale-normalized averaging. The two models emit logits on different scales
    # (observed variances of roughly 7.97 for the GRU vs 5.73 for the LSTM), so a
    # plain weighted average would be biased toward the GRU. Dividing each model's
    # logits by their per-timestep standard deviation puts both on a common scale
    # before weighting.
    gru_np = gru_logits.float().cpu().numpy()[0]
    lstm_np = lstm_logits.float().cpu().numpy()[0]
    gru_normalized = gru_np / np.sqrt(np.var(gru_np, axis=-1, keepdims=True) + 1e-8)
    lstm_normalized = lstm_np / np.sqrt(np.var(lstm_np, axis=-1, keepdims=True) + 1e-8)
    ensemble = gru_weight * gru_normalized + lstm_weight * lstm_normalized

    # Add back the leading batch axis; callers index the result with [0].
    return ensemble.astype(np.float32, copy=False)[np.newaxis]
# Load the evaluation split of every session in the GRU model's training config.
# Sessions whose directory has no data_<eval_type>.hdf5 file are skipped.
# os.listdir raises if a session directory is missing, so an incomplete data_dir
# fails loudly instead of silently dropping trials.
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    session_dir = os.path.join(data_dir, session)
    hdf5_names = {name for name in os.listdir(session_dir) if name.endswith('.hdf5')}
    target_name = f'data_{eval_type}.hdf5'
    if target_name not in hdf5_names:
        continue
    data = load_h5py_file(os.path.join(session_dir, target_name), b2txt_csv_df)
    test_data[session] = data
    n_loaded = len(data["neural_features"])
    total_test_trials += n_loaded
    print(f'Loaded {n_loaded} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# Put neural data through the ensemble model to get phoneme logits for every trial.
# The same day index is fed to both models, so a session must sit at the same
# position in both training session lists; otherwise the LSTM would silently use
# the wrong day-specific input layer.
lstm_session_list = list(lstm_model_args['dataset']['sessions'])
with tqdm(total=total_test_trials, desc='Ensemble inference (GRU+LSTM)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)
        if session not in lstm_session_list or lstm_session_list.index(session) != input_layer:
            raise ValueError(
                f'Session {session} is at index {input_layer} in the GRU session list but not '
                f'in the LSTM session list; the shared day index would select the wrong '
                f'LSTM input layer.'
            )
        for neural_features in data['neural_features']:
            # Add a batch dimension and move to the inference device as float32.
            neural_input = torch.tensor(np.expand_dims(neural_features, axis=0),
                                        device=device, dtype=torch.float32)
            ensemble_logits = runEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
            )
            data['logits'].append(ensemble_logits)
            pbar.update(1)
# Convert logits to phoneme sequences and calculate PER
def _ctc_greedy_decode(trial_logits):
    """Greedy CTC decode of one trial's logits into a list of class ids.

    Takes the argmax class at each time step, collapses runs of the same class,
    and then drops the blank class (0). The order matters: a blank between two
    identical classes marks two genuine repeats ([AA, blank, AA] -> [AA, AA]),
    whereas dropping blanks first would wrongly merge them into a single AA.
    """
    frame_ids = np.argmax(trial_logits, axis=-1)
    return [
        int(class_id)
        for t, class_id in enumerate(frame_ids)
        if class_id != 0 and (t == 0 or class_id != frame_ids[t - 1])
    ]


total_phonemes = 0
total_phoneme_errors = 0
per_results = []
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        # Drop the leading batch axis before decoding.
        pred_seq = _ctc_greedy_decode(data['logits'][trial][0])
        pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        data['pred_seq'].append(pred_phonemes)

        # Print out the predicted sequences and calculate PER
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
            # Phoneme error rate = edit distance / reference length
            phoneme_errors = editdistance.eval(true_phonemes, pred_phonemes)
            num_phonemes = len(true_phonemes)
            per = phoneme_errors / num_phonemes if num_phonemes > 0 else 0
            total_phonemes += num_phonemes
            total_phoneme_errors += phoneme_errors
            per_results.append({
                'session': session,
                'block': block_num,
                'trial': trial_num,
                'sentence_label': sentence_label,
                'true_phonemes': true_phonemes,
                'pred_phonemes': pred_phonemes,
                'phoneme_errors': phoneme_errors,
                'num_phonemes': num_phonemes,
                'per': per,
            })
            print(f'Sentence label: {sentence_label}')
            print(f'True phonemes: {" ".join(true_phonemes)}')
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
            print(f'PER: {phoneme_errors} / {num_phonemes} = {100 * per:.2f}%')
        else:
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
        print()
# Report the corpus-level phoneme error rate (validation split only, and only
# when at least one reference phoneme was scored).
report_aggregate = eval_type == 'val' and total_phonemes > 0
if report_aggregate:
    aggregate_per = total_phoneme_errors / total_phonemes
    summary_lines = [
        f'Total phonemes: {total_phonemes}',
        f'Total phoneme errors: {total_phoneme_errors}',
        f'Aggregate Phoneme Error Rate (PER): {100 * aggregate_per:.2f}%',
        '',
    ]
    print('\n'.join(summary_lines))
# Write a CSV next to this script whose name records the ensemble weights,
# the split and a timestamp.
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'ensemble_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)
if eval_type != 'val':
    # Test split: sequential ids plus the space-joined predicted phonemes.
    pred_phonemes_str = [
        ' '.join(phonemes)
        for session_data in test_data.values()
        for phonemes in session_data['pred_seq']
    ]
    ids = list(range(len(pred_phonemes_str)))
    df_out = pd.DataFrame({'id': ids, 'phonemes': pred_phonemes_str})
    df_out.to_csv(output_path, index=False)
    print(f'Predictions saved to: {output_path}')
else:
    # Validation split: one row of per-trial PER details.
    df_out = pd.DataFrame(per_results)
    df_out.to_csv(output_path, index=False)
    print(f'Detailed results saved to: {output_path}')
print(f'Ensemble configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')