297 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			297 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import torch | ||
|  | import numpy as np | ||
|  | import h5py | ||
|  | import time | ||
|  | import re | ||
|  | 
 | ||
|  | from data_augmentations import gauss_smooth | ||
|  | 
 | ||
# Label for each logit index, in model output order: the blank symbol first,
# then the 39 phoneme labels, then ' | ' last (the slot that the comments in
# rearrange_speech_logits_pt call SIL).
LOGIT_TO_PHONEME = [
    'BLANK',
    *(
        'AA AE AH AO AW '
        'AY B CH D DH '
        'EH ER EY F G '
        'HH IH IY JH K '
        'L M N NG OW '
        'OY P R S SH '
        'T TH UH UW V '
        'W Y Z ZH'
    ).split(),
    ' | ',
]
|  | 
 | ||
|  | def _extract_transcription(input): | ||
|  |     endIdx = np.argwhere(input == 0)[0, 0] | ||
|  |     trans = '' | ||
|  |     for c in range(endIdx): | ||
|  |         trans += chr(input[c]) | ||
|  |     return trans | ||
|  | 
 | ||
def load_h5py_file(file_path, b2txt_csv_df):
    """Read every trial group of one HDF5 file into parallel per-trial lists.

    Args:
        file_path: path of the HDF5 file to open read-only.
        b2txt_csv_df: table with 'Date', 'Block number' and 'Corpus' columns,
            used to look up the corpus name of each trial.

    Returns:
        dict mapping each field name to a list with one entry per trial, in
        the order the groups are listed by the file. Optional fields
        ('seq_class_ids', 'seq_len', 'transcriptions', 'sentence_label') hold
        None for trials that do not provide them.

    Raises:
        IndexError: if no row of ``b2txt_csv_df`` matches a trial's date and
            block number.
    """
    field_names = (
        'neural_features',
        'n_time_steps',
        'seq_class_ids',
        'seq_len',
        'transcriptions',
        'sentence_label',
        'session',
        'block_num',
        'trial_num',
        'corpus',
    )
    data = {name: [] for name in field_names}

    with h5py.File(file_path, 'r') as f:
        for key in list(f.keys()):
            g = f[key]

            trial = {
                'neural_features': g['input_features'][:],
                'n_time_steps': g.attrs['n_time_steps'],
                'seq_class_ids': g['seq_class_ids'][:] if 'seq_class_ids' in g else None,
                'seq_len': g.attrs['seq_len'] if 'seq_len' in g.attrs else None,
                'transcriptions': g['transcription'][:] if 'transcription' in g else None,
                'sentence_label': g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None,
                'session': g.attrs['session'],
                'block_num': g.attrs['block_num'],
                'trial_num': g.attrs['trial_num'],
            }

            # The session string is split on '.' and its last three parts are
            # used as year, month, day to build the csv's 'Date' key.
            year, month, day = trial['session'].split('.')[1:]
            same_date = b2txt_csv_df['Date'] == f'{year}-{month}-{day}'
            same_block = b2txt_csv_df['Block number'] == trial['block_num']
            trial['corpus'] = b2txt_csv_df[same_date & same_block]['Corpus'].values[0]

            # Append only after every lookup succeeded so the lists stay aligned.
            for name, value in trial.items():
                data[name].append(value)
    return data
|  | 
 | ||
def rearrange_speech_logits_pt(logits):
    """Reorder the class axis from [BLANK, phonemes..., SIL] to [BLANK, SIL, phonemes...].

    Args:
        logits: array indexed as logits[:, :, class], i.e. at least 3-D with
            the classes on the last axis.

    Returns:
        A new NumPy array with the last class moved to index 1 and the
        remaining classes shifted right by one.
    """
    blank = logits[:, :, :1]
    sil = logits[:, :, -1:]
    phonemes = logits[:, :, 1:-1]
    return np.concatenate([blank, sil, phonemes], axis=-1)
|  | 
 | ||
def runSingleDecodingStep(x, input_layer, model, model_args, device):
    """Run one decoding step: smooth the input, run the model, return logits.

    Args:
        x: input handed to ``gauss_smooth`` (presumably a tensor of neural
            features -- TODO confirm shape against the caller).
        input_layer: value wrapped in a one-element tensor and passed to the
            model as ``day_idx``.
        model: callable accepting ``x``, ``day_idx``, ``states`` and
            ``return_state`` and returning a ``(logits, state)`` pair.
        model_args: config dict; reads ``use_amp`` and the smoothing settings
            under ``dataset.data_transforms``.
        device: device used for smoothing and for the ``day_idx`` tensor.

    Returns:
        The model logits as a float32 NumPy array on the CPU. The class order
        is left as the model produced it; ``rearrange_speech_logits_pt`` is
        not applied here.
    """
    # bfloat16 autocast (only active when model_args['use_amp'] is truthy).
    with torch.autocast(device_type="cuda", enabled=model_args['use_amp'], dtype=torch.bfloat16):
        transforms = model_args['dataset']['data_transforms']
        smoothed = gauss_smooth(
            inputs=x,
            device=device,
            smooth_kernel_std=transforms['smooth_kernel_std'],
            smooth_kernel_size=transforms['smooth_kernel_size'],
            padding='valid',
        )

        with torch.no_grad():
            day_idx = torch.tensor([input_layer], device=device)
            # The model starts from no initial state; the returned state is discarded.
            logits, _state = model(
                x=smoothed,
                day_idx=day_idx,
                states=None,
                return_state=True,
            )

    # Upcast (e.g. from bfloat16) to float32 before leaving torch.
    return logits.float().cpu().numpy()
|  | 
 | ||
def remove_punctuation(sentence):
    """Normalise a sentence to lowercase letters, hyphens, apostrophes and single spaces.

    Every character other than ASCII letters, '-', space and "'" is dropped.
    Then a hyphen followed by a space becomes a space, double hyphens are
    removed, a space before an apostrophe is removed, and runs of spaces are
    collapsed (leading/trailing spaces are stripped).
    """
    cleaned = re.sub(r'[^a-zA-Z\- \']', '', sentence).lower()
    # Applied in this exact order; none of the patterns contain letters, so
    # lower-casing up front does not affect them.
    for old, new in (('- ', ' '), ('--', ''), (" '", "'")):
        cleaned = cleaned.replace(old, new)
    return ' '.join(cleaned.split())
|  | 
 | ||
def get_current_redis_time_ms(redis_conn):
    """Return the time reported by ``redis_conn.time()`` as whole milliseconds.

    ``redis_conn.time()`` is expected to return a pair whose first item is
    seconds and whose second item is the microseconds within that second;
    the fractional millisecond is truncated.
    """
    seconds, microseconds = redis_conn.time()
    return int(seconds * 1000 + microseconds / 1000)
|  | 
 | ||
|  | 
 | ||
|  | ######### language model helper functions ########## | ||
|  | 
 | ||
def reset_remote_language_model(
        r,
        remote_lm_done_resetting_lastEntrySeen,
    ):
    """Ask the remote language model to reset and wait for its acknowledgement.

    Adds an entry to the 'remote_lm_reset' stream, then reads the
    'remote_lm_done_resetting' stream (10 s blocking reads, retried forever)
    for an entry newer than ``remote_lm_done_resetting_lastEntrySeen``.

    Args:
        r: Redis-like connection providing ``xadd`` and ``xread``.
        remote_lm_done_resetting_lastEntrySeen: id of the last entry already
            consumed from the acknowledgement stream.

    Returns:
        The id of the acknowledgement entry that was read.
    """
    r.xadd('remote_lm_reset', {'done': 0})
    time.sleep(0.001)

    last_seen = remote_lm_done_resetting_lastEntrySeen
    while True:
        reply = r.xread(
            {'remote_lm_done_resetting': last_seen},
            count=1,
            block=10000,
        )
        if len(reply) != 0:
            break
        print(f'Still waiting for remote lm reset from ts {last_seen}...')

    # reply[0] is (stream name, entries); remember the newest entry id.
    for entry_id, _ in reply[0][1]:
        last_seen = entry_id

    return last_seen
|  | 
 | ||
|  | 
 | ||
def update_remote_lm_params(
        r,
        remote_lm_done_updating_lastEntrySeen,
        acoustic_scale=0.35,
        blank_penalty=90.0,
        alpha=0.55,
    ):
    """Send new decoding parameters to the remote language model and wait for the ack.

    Adds ``acoustic_scale``, ``blank_penalty`` and ``alpha`` as one entry to
    the 'remote_lm_update_params' stream, then reads the
    'remote_lm_done_updating_params' stream (10 s blocking reads, retried
    forever) for an entry newer than ``remote_lm_done_updating_lastEntrySeen``.
    No other decoder settings are sent by this function.

    Args:
        r: Redis-like connection providing ``xadd`` and ``xread``.
        remote_lm_done_updating_lastEntrySeen: id of the last entry already
            consumed from the acknowledgement stream.
        acoustic_scale: value sent under the 'acoustic_scale' key.
        blank_penalty: value sent under the 'blank_penalty' key.
        alpha: value sent under the 'alpha' key.

    Returns:
        The id of the acknowledgement entry that was read.
    """
    params = dict(
        acoustic_scale=acoustic_scale,
        blank_penalty=blank_penalty,
        alpha=alpha,
    )
    r.xadd('remote_lm_update_params', params)
    time.sleep(0.001)

    last_seen = remote_lm_done_updating_lastEntrySeen
    while True:
        reply = r.xread(
            {'remote_lm_done_updating_params': last_seen},
            block=10000,
            count=1,
        )
        if len(reply) != 0:
            break
        print(f'Still waiting for remote lm to update parameters from ts {last_seen}...')

    # reply[0] is (stream name, entries); remember the newest entry id.
    for entry_id, _ in reply[0][1]:
        last_seen = entry_id

    return last_seen
|  | 
 | ||
|  | 
 | ||
def send_logits_to_remote_lm(
        r,
        remote_lm_input_stream,
        remote_lm_output_partial_stream,
        remote_lm_output_partial_lastEntrySeen,
        logits,
    ):
    """Send logits to the remote language model and return its partial decode.

    The logits are converted to float32 and added as raw bytes under the
    'logits' key of ``remote_lm_input_stream``. The function then reads
    ``remote_lm_output_partial_stream`` (10 s blocking reads, retried forever)
    for an entry newer than ``remote_lm_output_partial_lastEntrySeen``.

    Args:
        r: Redis-like connection providing ``xadd`` and ``xread``.
        remote_lm_input_stream: name of the stream the logits are written to.
        remote_lm_output_partial_stream: name of the stream to read the
            partial response from.
        remote_lm_output_partial_lastEntrySeen: id of the last entry already
            consumed from the output stream.
        logits: array-like of logits to send.

    Returns:
        ``(last_entry_id, decoded)`` where ``decoded`` is the entry's
        ``b'lm_response_partial'`` field decoded to ``str``.
    """
    payload = np.float32(logits).tobytes()
    r.xadd(remote_lm_input_stream, {'logits': payload})

    last_seen = remote_lm_output_partial_lastEntrySeen
    while True:
        reply = r.xread(
            {remote_lm_output_partial_stream: last_seen},
            block=10000,
            count=1,
        )
        if len(reply) != 0:
            break
        print(f'Still waiting for remote lm partial output from ts {last_seen}...')

    # reply[0] is (stream name, entries); keep the newest id and its text.
    for entry_id, entry_data in reply[0][1]:
        last_seen = entry_id
        decoded = entry_data[b'lm_response_partial'].decode()

    return last_seen, decoded
|  | 
 | ||
|  | 
 | ||
def finalize_remote_lm(
        r,
        remote_lm_output_final_stream,
        remote_lm_output_final_lastEntrySeen,
    ):
    """Ask the remote language model to finalize and collect its ranked candidates.

    Adds an entry to the 'remote_lm_finalize' stream, then reads
    ``remote_lm_output_final_stream`` (10 s blocking reads, retried forever)
    for an entry newer than ``remote_lm_output_final_lastEntrySeen``. The
    entry's ``b'scoring'`` field is a ';'-separated flat list read five
    values at a time: sentence, acoustic score, n-gram score, LLM score,
    total score.

    Args:
        r: Redis-like connection providing ``xadd`` and ``xread``.
        remote_lm_output_final_stream: name of the stream to read the final
            output from.
        remote_lm_output_final_lastEntrySeen: id of the last entry already
            consumed from that stream.

    Returns:
        ``(last_entry_id, lm_out)``. ``lm_out`` maps 'candidate_sentences',
        'candidate_acoustic_scores', 'candidate_ngram_scores',
        'candidate_llm_scores' and 'candidate_total_scores' to equally long
        lists, sorted by total score (highest first) with repeated sentences
        removed. When no scored candidate was received, each list holds a
        single placeholder ('' for the sentence, 0 for every score).
    """
    r.xadd('remote_lm_finalize', {'done': 0})
    time.sleep(0.005)

    last_seen = remote_lm_output_final_lastEntrySeen
    while True:
        reply = r.xread(
            {remote_lm_output_final_stream: last_seen},
            block=10000,
            count=1,
        )
        if len(reply) != 0:
            break
        print(f'Still waiting for remote lm final output from ts {last_seen}...')

    # reply[0] is (stream name, entries); each entry is parsed, the last one wins.
    for entry_id, entry_data in reply[0][1]:
        last_seen = entry_id
        fields = entry_data[b'scoring'].decode().split(';')
        sentences = [str(value) for value in fields[0::5]]
        acoustic_scores = [float(value) for value in fields[1::5]]
        ngram_scores = [float(value) for value in fields[2::5]]
        llm_scores = [float(value) for value in fields[3::5]]
        total_scores = [float(value) for value in fields[4::5]]

    if len(sentences) == 0 or len(total_scores) == 0:
        # Edge case: the language model returned no scored candidates.
        print('No candidate sentences were received from the language model.')
        sentences = ['']
        acoustic_scores = [0]
        ngram_scores = [0]
        llm_scores = [0]
        total_scores = [0]
    else:
        # Rank by total score, higher is better.
        ranking = np.argsort(total_scores)[::-1]
        sentences = [sentences[i] for i in ranking]
        acoustic_scores = [acoustic_scores[i] for i in ranking]
        ngram_scores = [ngram_scores[i] for i in ranking]
        llm_scores = [llm_scores[i] for i in ranking]
        total_scores = [total_scores[i] for i in ranking]

    # Keep only the first (best-ranked) occurrence of each sentence.
    already_seen = set()
    keep = []
    for position, sentence in enumerate(sentences):
        if sentence not in already_seen:
            already_seen.add(sentence)
            keep.append(position)

    lm_out = {
        'candidate_sentences': [sentences[i] for i in keep],
        'candidate_acoustic_scores': [acoustic_scores[i] for i in keep],
        'candidate_ngram_scores': [ngram_scores[i] for i in keep],
        'candidate_llm_scores': [llm_scores[i] for i in keep],
        'candidate_total_scores': [total_scores[i] for i in keep],
    }

    return last_seen, lm_out