From e93cff1e2e12323cef2665f24eb1b35adb6eaa92 Mon Sep 17 00:00:00 2001 From: nckcard Date: Mon, 14 Jul 2025 13:58:34 -0700 Subject: [PATCH] typo fix, get corpus for each trial --- README.md | 2 +- model_training/evaluate_model.py | 7 ++++++- model_training/evaluate_model_helpers.py | 11 +++++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index af32565..94632e1 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ The data used in this repository (which can be downloaded from [Dryad](https://d - `t15_copyTask.pkl`: This file contains the online Copy Task results required for generating Figure 2. - `t15_personalUse.pkl`: This file contains the Conversation Mode data required for generating Figure 4. - `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task. - - There are more than 11,300 sentences from 45 sessions spanning 20 months. Each trial of data includes: + - There are 10,948 sentences from 45 sessions spanning 20 months. Each trial of data includes: - The session date, block number, and trial number - 512 neural features (2 features [-4.5 RMS threshold crossings and spike band power] per electrode, 256 electrodes), binned at 20 ms resolution. The data were recorded from the speech motor cortex via four high-density microelectrode arrays (64 electrodes each). The 512 features are ordered as follows in all data files: - 0-64: ventral 6v threshold crossings diff --git a/model_training/evaluate_model.py b/model_training/evaluate_model.py index 4e8b282..81e79c2 100644 --- a/model_training/evaluate_model.py +++ b/model_training/evaluate_model.py @@ -21,6 +21,8 @@ parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final', parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'], help='Evaluation type: "val" for validation set, "test" for test set. ' 'If "test", ground truth is not available.') +parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv', + help='Path to the CSV file with metadata about the dataset (relative to the current working directory).') parser.add_argument('--gpu_number', type=int, default=1, help='GPU number to use for RNN model inference. Set to -1 to use CPU.') args = parser.parse_args() @@ -33,6 +35,9 @@ data_dir = args.data_dir # define evaluation type eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available +# load csv file +b2txt_csv_df = pd.read_csv(args.csv_path) + # load model args model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml')) @@ -85,7 +90,7 @@ for session in model_args['dataset']['sessions']: if f'data_{eval_type}.hdf5' in files: eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5') - data = load_h5py_file(eval_file) + data = load_h5py_file(eval_file, b2txt_csv_df) test_data[session] = data total_test_trials += len(test_data[session]["neural_features"]) diff --git a/model_training/evaluate_model_helpers.py b/model_training/evaluate_model_helpers.py index 755b631..b16dd06 100644 --- a/model_training/evaluate_model_helpers.py +++ b/model_training/evaluate_model_helpers.py @@ -26,7 +26,7 @@ def _extract_transcription(input): trans += chr(input[c]) return trans -def load_h5py_file(file_path): +def load_h5py_file(file_path, b2txt_csv_df): data = { 'neural_features': [], 'n_time_steps': [], @@ -36,7 +36,8 @@ def load_h5py_file(file_path): 'sentence_label': [], 'session': [], 'block_num': [], - 'trial_num': [] + 'trial_num': [], + 'corpus': [], } # Open the hdf5 file for that day with h5py.File(file_path, 'r') as f: @@ -57,6 +58,11 @@ def load_h5py_file(file_path): block_num = g.attrs['block_num'] trial_num = g.attrs['trial_num'] + # match this trial up with the csv to get the corpus name + year, month, day = session.split('.')[1:] + date = f'{year}-{month}-{day}' + row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)] + corpus_name = row['Corpus'].values[0] data['neural_features'].append(neural_features) data['n_time_steps'].append(n_time_steps) @@ -67,6 +73,7 @@ def load_h5py_file(file_path): data['session'].append(session) data['block_num'].append(block_num) data['trial_num'].append(trial_num) + data['corpus'].append(corpus_name) return data def rearrange_speech_logits_pt(logits):