typo fix, get corpus for each trial
This commit is contained in:
@@ -21,6 +21,8 @@ parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
|
||||
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
|
||||
help='Evaluation type: "val" for validation set, "test" for test set. '
|
||||
'If "test", ground truth is not available.')
|
||||
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
|
||||
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
|
||||
parser.add_argument('--gpu_number', type=int, default=1,
|
||||
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
|
||||
args = parser.parse_args()
|
||||
@@ -33,6 +35,9 @@ data_dir = args.data_dir
|
||||
# define evaluation type
|
||||
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
|
||||
|
||||
# load csv file
|
||||
b2txt_csv_df = pd.read_csv(args.csv_path)
|
||||
|
||||
# load model args
|
||||
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
|
||||
|
||||
@@ -85,7 +90,7 @@ for session in model_args['dataset']['sessions']:
|
||||
if f'data_{eval_type}.hdf5' in files:
|
||||
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
|
||||
|
||||
data = load_h5py_file(eval_file)
|
||||
data = load_h5py_file(eval_file, b2txt_csv_df)
|
||||
test_data[session] = data
|
||||
|
||||
total_test_trials += len(test_data[session]["neural_features"])
|
||||
|
Reference in New Issue
Block a user