typo fix, get corpus for each trial

This commit is contained in:
nckcard
2025-07-14 13:58:34 -07:00
parent 82274632af
commit e93cff1e2e
3 changed files with 16 additions and 4 deletions

View File

@@ -26,7 +26,7 @@ def _extract_transcription(input):
trans += chr(input[c])
return trans
def load_h5py_file(file_path):
def load_h5py_file(file_path, b2txt_csv_df):
data = {
'neural_features': [],
'n_time_steps': [],
@@ -36,7 +36,8 @@ def load_h5py_file(file_path):
'sentence_label': [],
'session': [],
'block_num': [],
'trial_num': []
'trial_num': [],
'corpus': [],
}
# Open the hdf5 file for that day
with h5py.File(file_path, 'r') as f:
@@ -57,6 +58,11 @@ def load_h5py_file(file_path):
block_num = g.attrs['block_num']
trial_num = g.attrs['trial_num']
# match this trial up with the csv to get the corpus name
year, month, day = session.split('.')[1:]
date = f'{year}-{month}-{day}'
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
corpus_name = row['Corpus'].values[0]
data['neural_features'].append(neural_features)
data['n_time_steps'].append(n_time_steps)
@@ -67,6 +73,7 @@ def load_h5py_file(file_path):
data['session'].append(session)
data['block_num'].append(block_num)
data['trial_num'].append(trial_num)
data['corpus'].append(corpus_name)
return data
def rearrange_speech_logits_pt(logits):