import argparse
import os
import sys
import time

import editdistance
import numpy as np
import pandas as pd
import redis
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

# Add the repository root to the path so the model packages can be imported
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
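# The wildcard import above supplies the helpers used below: gauss_smooth,
# load_h5py_file, rearrange_speech_logits_pt, remove_punctuation, LOGIT_TO_PHONEME,
# get_current_redis_time_ms, and the remote language model redis helpers
# (reset_remote_language_model, send_logits_to_remote_lm, finalize_remote_lm).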


# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=5,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps for the circular time shift augmentation.')
args = parser.parse_args()
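
# Example invocation (the script name here is illustrative; adjust paths to your setup):
#   python evaluate_model_tta_ensemble.py --eval_type val --gpu_number 0 \
#       --gru_weight 0.5 --tta_samples 5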

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

print("TTA-E Configuration:")
print(f"GRU weight: {gru_weight:.2f}")
print(f"LSTM weight: {lstm_weight:.2f}")
print(f"TTA samples per trial: {tta_samples}")
print(f"TTA noise std: {tta_noise_std}")
print(f"TTA smooth range: ±{tta_smooth_range}")
print(f"TTA scale range: ±{tta_scale_range}")
print(f"TTA max shift: {tta_cut_max} timesteps")
print(f"GRU model path: {gru_model_path}")
print(f"LSTM model path: {lstm_model_path}")
print()

# Define evaluation type
eval_type = args.eval_type

# Load CSV file with metadata about the dataset
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))
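
# NOTE: the rest of the script assumes both checkpoints were trained on the same
# session list; the GRU session order is used for data loading and day indexing.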

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from the
# checkpoint keys in a single pass so they match the plain model's state dict
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from the
# checkpoint keys in a single pass so they match the plain model's state dict
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Move models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print("Both models loaded successfully!")
print()


# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_samples, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply multiple data augmentations to each input.
    2. Run both the GRU and LSTM models on each augmented version.
    3. Ensemble the model outputs with the given weights.
    4. Average across all TTA samples.
    """
    all_ensemble_logits = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    for tta_iter in range(tta_samples):
        # Apply a different augmentation strategy on each iteration
        x_augmented = x.clone()

        if tta_iter == 0:
            # Original data (baseline)
            augmentation_type = "original"
        elif tta_iter == 1:
            # Add Gaussian noise
            noise = torch.randn_like(x_augmented) * tta_noise_std
            x_augmented = x_augmented + noise
            augmentation_type = f"noise_std_{tta_noise_std}"
        elif tta_iter == 2:
            # Amplitude scaling
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
            augmentation_type = f"scale_{scale_factor:.3f}"
        elif tta_iter == 3 and tta_cut_max > 0:
            # Time shift (circular shift instead of cutting, to maintain length);
            # the max(2, ...) guard keeps randint's upper bound valid for short trials
            shift_amount = np.random.randint(1, max(2, min(tta_cut_max + 1, x_augmented.shape[1] // 8)))
            # Circular shift: move the beginning to the end
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
            augmentation_type = f"shift_{shift_amount}"
        else:
            # Smoothing variation: perturb the Gaussian kernel std
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.5, default_smooth_std + smooth_variation)
            augmentation_type = f"smooth_std_{varied_smooth_std:.2f}"

        # Use autocast for efficiency (bfloat16 matches the input tensor's dtype)
        with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):

            # Apply Gaussian smoothing, using the varied kernel std only for the
            # smoothing-variation augmentation and the default std otherwise
            if augmentation_type.startswith("smooth_std"):
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=varied_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )
            else:
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=default_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )

            with torch.no_grad():
                # Get GRU logits
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )

                # Get LSTM logits
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )

                # Ensemble the two models by weighted averaging of their logits
                ensemble_logits = gru_weight * gru_logits + lstm_weight * lstm_logits
                all_ensemble_logits.append(ensemble_logits)
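
    # All augmentations above preserve sequence length (the time shift is circular
    # and the smoothing kernel size is fixed), so mismatched lengths should not occur
    # in practice; the truncation below is a defensive guard.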
    # TTA fusion: handle potentially different tensor shapes by truncating to the
    # minimum sequence length before averaging
    if len(all_ensemble_logits) > 1:
        # Find the minimum sequence length among all TTA samples
        min_length = min(logits.shape[1] for logits in all_ensemble_logits)

        # Truncate all tensors to the minimum length
        truncated_logits = []
        for logits in all_ensemble_logits:
            if logits.shape[1] > min_length:
                truncated_logits.append(logits[:, :min_length, :])
            else:
                truncated_logits.append(logits)

        # Stack and average across TTA samples
        final_logits = torch.mean(torch.stack(truncated_logits), dim=0)
    else:
        final_logits = all_ensemble_logits[0]

    # Convert logits from bfloat16 to float32 and move them to the CPU
    return final_logits.float().cpu().numpy()


# Load data for each session (using the GRU model args as the reference session list)
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data

        total_test_trials += len(test_data[session]["neural_features"])
        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc=f'TTA-E inference ({tta_samples} samples/trial)', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        data['pred_seq'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)

        for trial in range(len(data['neural_features'])):
            # Get neural input for the trial
            neural_input = data['neural_features'][trial]

            # Add batch dimension
            neural_input = np.expand_dims(neural_input, axis=0)

            # Convert to torch tensor
            neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
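            # bfloat16 here matches the autocast dtype used inside the decoding step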

            # Run TTA-E decoding step
            ensemble_logits = runTTAEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                tta_samples, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
            )
            data['logits'].append(ensemble_logits)

            pbar.update(1)
pbar.close()

# Convert logits to phoneme sequences and print them out
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
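        # Greedy CTC-style decode: per-frame argmax, collapse consecutive repeats,
        # then drop the blank token (index 0),
        # e.g. argmax frames [5, 5, 0, 5, 31] -> [5, 0, 5, 31] -> [5, 5, 31]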
        pred_seq = np.argmax(logits, axis=-1)
        # Collapse consecutive duplicates (this must happen before removing blanks,
        # since blanks separate genuine repeated phonemes)
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
        # Remove blanks (0)
        pred_seq = [p for p in pred_seq if p != 0]
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        # Add to data
        data['pred_seq'].append(pred_seq)

        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label: {sentence_label}')
            print(f'True sequence: {" ".join(true_seq)}')
        print(f'Predicted sequence: {" ".join(pred_seq)}')
        print()


# Language model inference via redis.
# Make sure that the standalone language model is running on the localhost redis ip;
# see README.md for instructions on how to run the language model.
r = redis.Redis(host='localhost', port=6379, db=0)
r.flushall()  # Clear all streams in redis
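# NOTE: flushall() wipes every key in every database of this redis instance, not
# just the language model streams; don't point this at a shared redis server.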

# Define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'

# Set timestamps for the last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
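# (Only the resetting, partial, and final timestamps are consumed below; the
# finalizing and updating ones are initialized but not used in this script.)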

lm_results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_sentence': [],
}

# Loop through all trials and put logits into the remote language model to get text predictions
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session in test_data.keys():
        for trial in range(len(test_data[session]['logits'])):
            # Get trial logits and rearrange them for the LM
            logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]

            # Reset the language model
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)

            # Put logits into the LM
            remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )

            # Finalize the remote LM
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )

            # Get the best candidate sentence
            best_candidate_sentence = lm_out['candidate_sentences'][0]

            # Store results
            lm_results['session'].append(session)
            lm_results['block'].append(test_data[session]['block_num'][trial])
            lm_results['trial'].append(test_data[session]['trial_num'][trial])
            if eval_type == 'val':
                lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
            else:
                lm_results['true_sentence'].append(None)
            lm_results['pred_sentence'].append(best_candidate_sentence)

            # Update progress bar
            pbar.update(1)
pbar.close()

# If using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0

    lm_results['edit_distance'] = []
    lm_results['num_words'] = []
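
    # Aggregate WER = (sum of word-level edit distances) / (total reference words),
    # accumulated over all trials rather than averaged per sentence.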

    for i in range(len(lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
        pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
        ed = editdistance.eval(true_sentence.split(), pred_sentence.split())

        total_true_length += len(true_sentence.split())
        total_edit_distance += ed

        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(len(true_sentence.split()))

        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence: {true_sentence}')
        print(f'Predicted sentence: {pred_sentence}')
        print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
        print()

    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')

# Write predicted sentences to a CSV file with a timestamp and the TTA-E settings in the filename
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_samples{tta_samples}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_path, index=False)

print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}, TTA samples = {tta_samples}')