typo fix, get corpus for each trial

This commit is contained in:
nckcard
2025-07-14 13:58:34 -07:00
parent 82274632af
commit e93cff1e2e
3 changed files with 16 additions and 4 deletions

View File

@@ -21,6 +21,8 @@ parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set. '
'If "test", ground truth is not available.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
parser.add_argument('--gpu_number', type=int, default=1,
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
@@ -33,6 +35,9 @@ data_dir = args.data_dir
# define evaluation type
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
# load csv file
b2txt_csv_df = pd.read_csv(args.csv_path)
# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
@@ -85,7 +90,7 @@ for session in model_args['dataset']['sessions']:
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
data = load_h5py_file(eval_file)
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])