Files
b2txt25/TTA-E/.ipynb_checkpoints/evaluate_model-checkpoint.py
2025-10-06 15:17:44 +08:00

452 lines
20 KiB
Python

import os
import sys
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
# Command-line interface. Options are declared in a table and registered in
# order, so `--help` lists them in the same sequence as the declarations below.
parser = argparse.ArgumentParser(
    description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.',
)
_CLI_OPTIONS = (
    # Model / data locations
    ('--gru_model_path', dict(
        type=str,
        default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        help='Path to the pretrained GRU model directory.')),
    ('--lstm_model_path', dict(
        type=str,
        default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        help='Path to the pretrained LSTM model directory.')),
    ('--data_dir', dict(
        type=str,
        default='../data/hdf5_data_final',
        help='Path to the dataset directory (relative to the current working directory).')),
    ('--eval_type', dict(
        type=str,
        default='test',
        choices=['val', 'test'],
        help='Evaluation type: "val" for validation set, "test" for test set.')),
    ('--csv_path', dict(
        type=str,
        default='../data/t15_copyTaskData_description.csv',
        help='Path to the CSV file with metadata about the dataset.')),
    ('--gpu_number', dict(
        type=int,
        default=0,
        help='GPU number to use for model inference. Set to -1 to use CPU.')),
    # Ensemble weighting
    ('--gru_weight', dict(
        type=float,
        default=0.5,
        help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')),
    # TTA parameters
    ('--tta_samples', dict(
        type=int,
        default=5,
        help='Number of TTA augmentation samples per trial.')),
    ('--tta_noise_std', dict(
        type=float,
        default=0.01,
        help='Standard deviation for TTA noise augmentation.')),
    ('--tta_smooth_range', dict(
        type=float,
        default=0.5,
        help='Range for TTA smoothing kernel variation (±range from default).')),
    ('--tta_scale_range', dict(
        type=float,
        default=0.05,
        help='Range for TTA amplitude scaling (±range from 1.0).')),
    ('--tta_cut_max', dict(
        type=int,
        default=3,
        help='Maximum number of timesteps to cut from beginning in TTA.')),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()
# Copy the parsed options into the module-level names used by the rest of the script.
gru_model_path, lstm_model_path, data_dir = (
    args.gru_model_path,
    args.lstm_model_path,
    args.data_dir,
)
# Ensemble weights: the LSTM receives whatever share the GRU does not.
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight
# Test-time-augmentation settings
tta_samples, tta_noise_std = args.tta_samples, args.tta_noise_std
tta_smooth_range, tta_scale_range = args.tta_smooth_range, args.tta_scale_range
tta_cut_max = args.tta_cut_max
# Echo the effective configuration, one setting per line, followed by a blank line.
print('\n'.join([
    "TTA-E Configuration:",
    f"GRU weight: {gru_weight:.2f}",
    f"LSTM weight: {lstm_weight:.2f}",
    f"TTA samples per trial: {tta_samples}",
    f"TTA noise std: {tta_noise_std}",
    f"TTA smooth range: ±{tta_smooth_range}",
    f"TTA scale range: ±{tta_scale_range}",
    f"TTA max cut: {tta_cut_max} timesteps",
    f"GRU model path: {gru_model_path}",
    f"LSTM model path: {lstm_model_path}",
]))
print()
# Dataset split to evaluate, taken from the --eval_type option
eval_type = args.eval_type
# Metadata table read from the --csv_path option
b2txt_csv_df = pd.read_csv(args.csv_path)
# Each model directory carries the arguments it was trained with in checkpoint/args.yaml
gru_model_args, lstm_model_args = (
    OmegaConf.load(os.path.join(model_dir, 'checkpoint/args.yaml'))
    for model_dir in (gru_model_path, lstm_model_path)
)
# Pick the inference device: the requested CUDA device when usable, otherwise CPU.
gpu_number = args.gpu_number
gpu_requested = gpu_number >= 0
if not (torch.cuda.is_available() and gpu_requested):
    if gpu_requested:
        # A GPU was asked for but CUDA is unavailable, so say so before falling back.
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')
elif gpu_number >= torch.cuda.device_count():
    raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
else:
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
def _build_decoder(decoder_cls, model_args):
    """Instantiate `decoder_cls` from the hyperparameters stored in `model_args`."""
    return decoder_cls(
        neural_dim=model_args['model']['n_input_features'],
        n_units=model_args['model']['n_units'],
        n_days=len(model_args['dataset']['sessions']),
        n_classes=model_args['dataset']['n_classes'],
        rnn_dropout=model_args['model']['rnn_dropout'],
        input_dropout=model_args['model']['input_network']['input_layer_dropout'],
        n_layers=model_args['model']['n_layers'],
        patch_size=model_args['model']['patch_size'],
        patch_stride=model_args['model']['patch_stride'],
    )


def _strip_wrapper_prefixes(state_dict):
    """Remove "module." and "_orig_mod." from every key of `state_dict`, in place.

    "module." appears when the model was saved while wrapped in DataParallel;
    "_orig_mod." is presumably left by torch.compile (TODO confirm for these
    checkpoints). Both substrings are stripped in a single rename per key, so
    a key carrying either one (or both) is handled. Popping the same key twice,
    once per prefix, raises KeyError as soon as the first rename actually
    changes the key.

    Returns the same (mutated) mapping for convenience.
    """
    for key in list(state_dict.keys()):
        clean_key = key.replace("module.", "").replace("_orig_mod.", "")
        state_dict[clean_key] = state_dict.pop(key)
    return state_dict


# Define GRU model
gru_model = _build_decoder(GRUDecoder, gru_model_args)
# Load GRU model weights
# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoints from a trusted source.
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
gru_model.load_state_dict(_strip_wrapper_prefixes(gru_checkpoint['model_state_dict']))

# Define LSTM model
lstm_model = _build_decoder(LSTMDecoder, lstm_model_args)
# Load LSTM model weights (same trust caveat as above)
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
lstm_model.load_state_dict(_strip_wrapper_prefixes(lstm_checkpoint['model_state_dict']))

# Add models to device
gru_model.to(device)
lstm_model.to(device)
# Set models to eval mode
gru_model.eval()
lstm_model.eval()
print("Both models loaded successfully!")
print()
# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_samples, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run TTA-E (Test Time Augmentation + Ensemble) inference on one input.

    For each of `tta_samples` passes the input is augmented, smoothed with
    `gauss_smooth`, fed to both models, and the two logit tensors are combined
    as `gru_weight * gru_logits + lstm_weight * lstm_logits`. The per-pass
    results are then averaged.

    Augmentation schedule by pass index:
        0   unmodified input
        1   additive Gaussian noise scaled by `tta_noise_std`
        2   amplitude scaling by a factor drawn from 1 ± `tta_scale_range`
        3   circular shift along axis 1 by 1..`tta_cut_max` steps, when
            `tta_cut_max > 0`; skipped (input left as is) if the sequence is
            too short to allow a shift of at least one step
        else (including pass 3 when `tta_cut_max <= 0`): the smoothing std is
            varied by ± `tta_smooth_range` around the configured default,
            floored at 0.5

    Args:
        x: Input tensor; indexed here as 3-D with the time axis at dim 1.
        input_layer: Day/session index passed to both models as `day_idx`.
        gru_model, lstm_model: The two decoders to ensemble.
        gru_model_args: Config providing the default smoothing kernel
            (`dataset.data_transforms`) and the `use_amp` flag.
        lstm_model_args: Unused; kept for interface compatibility.
        device: Device on which the `day_idx` tensor is created and which is
            forwarded to `gauss_smooth`.
        gru_weight, lstm_weight: Ensemble weights for the two models' logits.
        tta_samples: Number of augmentation passes (must be >= 1).
        tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max:
            Augmentation strengths, see the schedule above.

    Returns:
        The averaged ensemble logits as a float32 NumPy array.

    Raises:
        ValueError: If `tta_samples` is smaller than 1.
    """
    if tta_samples < 1:
        raise ValueError(f'tta_samples must be at least 1, got {tta_samples}.')

    # Default smoothing parameters come from the GRU model's configuration
    transforms = gru_model_args['dataset']['data_transforms']
    default_smooth_std = transforms['smooth_kernel_std']
    default_smooth_size = transforms['smooth_kernel_size']
    day_idx = torch.tensor([input_layer], device=device)

    all_ensemble_logits = []
    for tta_iter in range(tta_samples):
        x_augmented = x.clone()
        # Smoothing std for this pass; only the smoothing-variation branch changes it
        smooth_std = default_smooth_std

        if tta_iter == 1:
            # Add Gaussian noise
            x_augmented = x_augmented + torch.randn_like(x_augmented) * tta_noise_std
        elif tta_iter == 2:
            # Amplitude scaling
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif tta_iter == 3 and tta_cut_max > 0:
            # Time shift (circular, so the sequence length is preserved).
            # randint needs low < high; short sequences make the exclusive
            # upper bound collapse to <= 1, in which case no shift is applied.
            shift_high = min(tta_cut_max + 1, x_augmented.shape[1] // 8)
            if shift_high > 1:
                shift_amount = np.random.randint(1, shift_high)
                # Circular shift: move beginning to end
                x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                         x_augmented[:, :shift_amount, :]], dim=1)
        elif tta_iter != 0:
            # Smoothing variation
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            smooth_std = max(0.5, default_smooth_std + smooth_variation)
        # tta_iter == 0 falls through untouched: original data (baseline)

        # Use autocast for efficiency
        # NOTE(review): device_type is hard-coded to "cuda"; confirm the intended
        # behaviour when `device` is the CPU.
        with torch.autocast(device_type="cuda", enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
            x_smoothed = gauss_smooth(
                inputs=x_augmented,
                device=device,
                smooth_kernel_std=smooth_std,
                smooth_kernel_size=default_smooth_size,
                padding='valid',
            )
            with torch.no_grad():
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=day_idx,
                    states=None,
                    return_state=True,
                )
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=day_idx,
                    states=None,
                    return_state=True,
                )
                # Ensemble using weighted averaging
                all_ensemble_logits.append(gru_weight * gru_logits + lstm_weight * lstm_logits)

    # TTA fusion: trim every pass to the shortest sequence length so the
    # tensors can be stacked, then average across passes.
    min_length = min(logits.shape[1] for logits in all_ensemble_logits)
    final_logits = torch.stack(
        [logits[:, :min_length, :] for logits in all_ensemble_logits]
    ).mean(dim=0)

    # Convert logits from bfloat16 to float32
    return final_logits.float().cpu().numpy()
# Gather the evaluation split of every session listed in the GRU model's
# configuration (used as the reference session list). Sessions whose directory
# has no matching data file are skipped.
test_data = {}
total_test_trials = 0
eval_filename = f'data_{eval_type}.hdf5'
for session in gru_model_args['dataset']['sessions']:
    session_dir = os.path.join(data_dir, session)
    hdf5_names = [name for name in os.listdir(session_dir) if name.endswith('.hdf5')]
    if eval_filename not in hdf5_names:
        continue
    session_data = load_h5py_file(os.path.join(session_dir, eval_filename), b2txt_csv_df)
    test_data[session] = session_data
    n_trials = len(session_data["neural_features"])
    total_test_trials += n_trials
    print(f'Loaded {n_trials} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# Run every loaded trial through the TTA-E ensemble and store its logits
# (phoneme predictions) on the session's data dict.
with tqdm(total=total_test_trials, desc=f'TTA-E inference ({tta_samples} samples/trial)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        # Position of this session in the GRU model's session list
        input_layer = gru_model_args['dataset']['sessions'].index(session)
        for trial_features in data['neural_features']:
            # Prepend a batch axis, then move to the inference device as bfloat16
            batched_features = np.expand_dims(trial_features, axis=0)
            neural_input = torch.tensor(batched_features, device=device, dtype=torch.bfloat16)
            trial_logits = runTTAEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                tta_samples, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max,
            )
            data['logits'].append(trial_logits)
            pbar.update(1)
pbar.close()
# Convert logits to phoneme sequences (greedy CTC decoding) and print them out
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial, trial_logits in enumerate(data['logits']):
        # Most likely class per frame for the single batch element
        frame_ids = np.argmax(trial_logits[0], axis=-1)
        # CTC collapse: merge runs of identical frames FIRST, then drop blanks (0).
        # Doing it in the other order would fuse genuinely repeated phonemes that
        # are separated only by a blank (e.g. "a _ a" must decode to "a a").
        pred_ids = [
            int(current)
            for idx, current in enumerate(frame_ids)
            if current != 0 and (idx == 0 or current != frame_ids[idx - 1])
        ]
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_ids]
        # Add to data
        data['pred_seq'].append(pred_seq)
        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            # Ground truth is only printed for the validation split
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label: {sentence_label}')
            print(f'True sequence: {" ".join(true_seq)}')
        print(f'Predicted Sequence: {" ".join(pred_seq)}')
        print()
# Language model inference via redis
# Make sure that the standalone language model is running on the localhost redis ip
# See README.md for instructions on how to run the language model
r = redis.Redis(host='localhost', port=6379, db=0)
# Start from a clean slate: FLUSHALL removes everything currently stored in redis
r.flushall()
# Stream names used to talk to the remote language model
remote_lm_input_stream, remote_lm_output_partial_stream, remote_lm_output_final_stream = (
    'remote_lm_input',
    'remote_lm_output_partial',
    'remote_lm_output_final',
)
# One "last entry seen" timestamp per stream, each initialised with its own
# query of the current redis time
(
    remote_lm_output_partial_lastEntrySeen,
    remote_lm_output_final_lastEntrySeen,
    remote_lm_done_resetting_lastEntrySeen,
    remote_lm_done_finalizing_lastEntrySeen,
    remote_lm_done_updating_lastEntrySeen,
) = [get_current_redis_time_ms(r) for _ in range(5)]
# Column-oriented accumulator for the per-trial language model results
lm_results = {
    column: []
    for column in ('session', 'block', 'trial', 'true_sentence', 'pred_sentence')
}
# Send each trial's logits through the remote language model and record the
# top candidate sentence alongside the trial's identifiers.
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session, session_data in test_data.items():
        for trial, trial_logits in enumerate(session_data['logits']):
            # Rearrange the logits for the LM and take the single batch element
            logits = rearrange_speech_logits_pt(trial_logits)[0]
            # Reset the LM before this trial
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(
                r, remote_lm_done_resetting_lastEntrySeen,
            )
            # Stream the logits in
            remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )
            # Ask the LM for its final output
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )
            # The first candidate sentence is taken as the best one; the true
            # sentence is only recorded for the validation split.
            record = {
                'session': session,
                'block': session_data['block_num'][trial],
                'trial': session_data['trial_num'][trial],
                'true_sentence': session_data['sentence_label'][trial] if eval_type == 'val' else None,
                'pred_sentence': lm_out['candidate_sentences'][0],
            }
            for column, value in record.items():
                lm_results[column].append(value)
            # Update progress bar
            pbar.update(1)
pbar.close()
# If using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0
    lm_results['edit_distance'] = []
    lm_results['num_words'] = []
    for i, (true_raw, pred_raw) in enumerate(zip(lm_results['true_sentence'], lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(true_raw).strip()
        pred_sentence = remove_punctuation(pred_raw).strip()
        # Word-level edit distance between reference and prediction
        true_words = true_sentence.split()
        num_words = len(true_words)
        ed = editdistance.eval(true_words, pred_sentence.split())
        total_true_length += num_words
        total_edit_distance += ed
        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(num_words)
        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence: {true_sentence}')
        print(f'Predicted sentence: {pred_sentence}')
        # A reference with no words has no defined WER; report that instead of
        # dividing by zero.
        if num_words:
            print(f'WER: {ed} / {num_words} = {100 * ed / num_words:.2f}%')
        else:
            print(f'WER: {ed} / {num_words} = n/a (empty reference sentence)')
        print()
    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    # Same guard for the aggregate: no reference words means no WER to report.
    if total_true_length:
        print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
    else:
        print('Aggregate Word Error Rate (WER): n/a (no reference words)')
# Save the predicted sentences next to this script; the file name encodes the
# ensemble weights, TTA sample count, split and a timestamp.
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_samples{tta_samples}_{eval_type}_{timestamp}.csv'
script_dir = os.path.dirname(__file__)
output_path = os.path.join(script_dir, output_file)
# One row per prediction, with a running integer id starting at 0
predicted_sentences = lm_results['pred_sentence']
ids = list(range(len(predicted_sentences)))
df_out = pd.DataFrame({'id': ids, 'text': predicted_sentences})
df_out.to_csv(output_path, index=False)
print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}, TTA samples = {tta_samples}')