typo fix, get corpus for each trial
This commit is contained in:
@@ -26,7 +26,7 @@ def _extract_transcription(input):
|
||||
trans += chr(input[c])
|
||||
return trans
|
||||
|
||||
def load_h5py_file(file_path):
|
||||
def load_h5py_file(file_path, b2txt_csv_df):
|
||||
data = {
|
||||
'neural_features': [],
|
||||
'n_time_steps': [],
|
||||
@@ -36,7 +36,8 @@ def load_h5py_file(file_path):
|
||||
'sentence_label': [],
|
||||
'session': [],
|
||||
'block_num': [],
|
||||
'trial_num': []
|
||||
'trial_num': [],
|
||||
'corpus': [],
|
||||
}
|
||||
# Open the hdf5 file for that day
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
@@ -57,6 +58,11 @@ def load_h5py_file(file_path):
|
||||
block_num = g.attrs['block_num']
|
||||
trial_num = g.attrs['trial_num']
|
||||
|
||||
# match this trial up with the csv to get the corpus name
|
||||
year, month, day = session.split('.')[1:]
|
||||
date = f'{year}-{month}-{day}'
|
||||
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
|
||||
corpus_name = row['Corpus'].values[0]
|
||||
|
||||
data['neural_features'].append(neural_features)
|
||||
data['n_time_steps'].append(n_time_steps)
|
||||
@@ -67,6 +73,7 @@ def load_h5py_file(file_path):
|
||||
data['session'].append(session)
|
||||
data['block_num'].append(block_num)
|
||||
data['trial_num'].append(trial_num)
|
||||
data['corpus'].append(corpus_name)
|
||||
return data
|
||||
|
||||
def rearrange_speech_logits_pt(logits):
|
||||
|
Reference in New Issue
Block a user