import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

from rnn_model import GRUDecoder
from evaluate_model_helpers import *

# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained model directory (relative to the current working directory).')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set. '
                         'If "test", ground truth is not available.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
parser.add_argument('--gpu_number', type=int, default=-1,
                    help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
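
# example invocation (assuming this script is saved as evaluate_model.py):
#   python evaluate_model.py --eval_type val --gpu_number 0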

# paths to model and data directories
# Note: these paths are relative to the current working directory
model_path = args.model_path
data_dir = args.data_dir

# define evaluation type
eval_type = args.eval_type  # can be 'val' or 'test'. if 'test', ground truth is not available

# load csv file
b2txt_csv_df = pd.read_csv(args.csv_path)

# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))

# set up gpu device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# define model
model = GRUDecoder(
    neural_dim=model_args['model']['n_input_features'],
    n_units=model_args['model']['n_units'],
    n_days=len(model_args['dataset']['sessions']),
    n_classes=model_args['dataset']['n_classes'],
    rnn_dropout=model_args['model']['rnn_dropout'],
    input_dropout=model_args['model']['input_network']['input_layer_dropout'],
    n_layers=model_args['model']['n_layers'],
    patch_size=model_args['model']['patch_size'],
    patch_stride=model_args['model']['patch_stride'],
)
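
# note: every hyperparameter above comes from the checkpoint's own args.yaml, so the
# architecture always matches the saved weights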

# load model weights
checkpoint = torch.load(
    os.path.join(model_path, 'checkpoint/best_checkpoint'),
    map_location=device,
    weights_only=False,
)
# strip the "module." prefix (added by DataParallel) and the "_orig_mod." prefix
# (added by torch.compile) so the keys match the plain model's parameter names
for key in list(checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    checkpoint['model_state_dict'][new_key] = checkpoint['model_state_dict'].pop(key)
model.load_state_dict(checkpoint['model_state_dict'])
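
# weights_only=False lets torch.load unpickle arbitrary objects stored in the
# checkpoint, so only load checkpoints from a source you trust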

# add model to device
model.to(device)

# set model to eval mode
model.eval()

# load data for each session
test_data = {}
total_test_trials = 0
for session in model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data

        total_test_trials += len(test_data[session]["neural_features"])
        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
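
# each test_data[session] dict now holds 'neural_features', 'block_num', and 'trial_num',
# plus 'sentence_label', 'seq_class_ids', and 'seq_len' when ground truth is available
# (see load_h5py_file in evaluate_model_helpers)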

# put neural data through the pretrained model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        input_layer = model_args['dataset']['sessions'].index(session)

        for trial in range(len(data['neural_features'])):
            # get neural input for the trial
            neural_input = data['neural_features'][trial]

            # add batch dimension
            neural_input = np.expand_dims(neural_input, axis=0)

            # convert to torch tensor
            neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

            # run decoding step
            logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
            data['logits'].append(logits)

            pbar.update(1)
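
# greedy CTC decode below: take the argmax class per frame, merge consecutive repeats,
# then drop CTC blanks (class 0). e.g. a frame-wise argmax of [0, 5, 5, 0, 12, 12]
# becomes [0, 5, 0, 12] and then [5, 12]; merging before dropping blanks matters, since
# [5, 0, 5] must decode to two phonemes (ids here are illustrative)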

# convert logits to phoneme sequences and print them out
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
        # remove consecutive duplicates, then remove blanks (0), so that repeated
        # phonemes separated by a blank are preserved
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
        pred_seq = [p for p in pred_seq if p != 0]
        # convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        # add to data
        data['pred_seq'].append(pred_seq)

        # print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
        print(f'Predicted sequence:  {" ".join(pred_seq)}')
        print()


# language model inference via redis
# make sure the standalone language model is running and reachable at the redis host
# configured below; see README.md for instructions on how to run it

def connect_to_redis_with_retry(host, port, password, db=0, max_retries=10, retry_delay=3):
    """Connect to Redis, retrying a fixed number of times before giving up."""
    for attempt in range(max_retries):
        try:
            print(f"Attempting to connect to Redis at {host}:{port} (attempt {attempt + 1}/{max_retries})...")
            r = redis.Redis(host=host, port=port, db=db, password=password)
            r.ping()  # test the connection
            print(f"Successfully connected to Redis at {host}:{port}")
            return r
        except redis.exceptions.ConnectionError as e:
            print(f"Redis connection failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                print("Max retries reached. Could not connect to Redis.")
                raise
        except Exception as e:
            print(f"Unexpected error connecting to Redis: {e}")
            if attempt < max_retries - 1:
                print(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise

r = connect_to_redis_with_retry('hs.zchens.cn', 6379, 'admin01')
r.flushall()  # clear all streams in redis

# define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'
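
# the helper functions used below write logits to remote_lm_input and read partial and
# final decoding results back from the two output streams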

# set timestamps for last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)

lm_results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_sentence': [],
}

# loop through all trials and put logits into the remote language model to get text predictions
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session in test_data.keys():
        for trial in range(len(test_data[session]['logits'])):
            # get trial logits and rearrange them for the LM
            logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]

            # reset language model
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)

            '''
            # update language model parameters
            remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
                r,
                remote_lm_done_updating_lastEntrySeen,
                acoustic_scale=0.35,
                blank_penalty=90.0,
                alpha=0.55,
            )
            '''
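
            # the disabled block above would retune the decoder's fusion parameters;
            # when it is left commented out, the LM server keeps whatever settings it
            # was started with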

            # put logits into LM
            remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )

            # finalize remote LM
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )

            # get the best candidate sentence
            best_candidate_sentence = lm_out['candidate_sentences'][0]

            # store results
            lm_results['session'].append(session)
            lm_results['block'].append(test_data[session]['block_num'][trial])
            lm_results['trial'].append(test_data[session]['trial_num'][trial])
            if eval_type == 'val':
                lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
            else:
                lm_results['true_sentence'].append(None)
            lm_results['pred_sentence'].append(best_candidate_sentence)

            # update progress bar
            pbar.update(1)
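

# WER = word-level edit distance / number of reference words. e.g. truth "i am here"
# vs. prediction "i was here today": one substitution plus one insertion gives edit
# distance 2 over 3 reference words, i.e. WER = 2/3 ≈ 66.7%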

# if using the validation set, let's calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0

    lm_results['edit_distance'] = []
    lm_results['num_words'] = []

    for i in range(len(lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
        pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
        ed = editdistance.eval(true_sentence.split(), pred_sentence.split())

        total_true_length += len(true_sentence.split())
        total_edit_distance += ed

        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(len(true_sentence.split()))

        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence:       {true_sentence}')
        print(f'Predicted sentence:  {pred_sentence}')
        print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
        print()

    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')


# write predicted sentences to a csv file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.csv')
ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_file, index=False)
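
# the resulting csv has two columns: 'id' (trial index in decoding order) and 'text'
# (the language model's best sentence for that trial)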