# b2txt25/language_model/language-model-standalone.py
import redis
import argparse
import numpy as np
from datetime import datetime
import time
import os
import re
import logging
import torch
"""
Optional device backends:
- Accelerate: simplifies multi-GPU/TPU/CPU device placement
- torch_xla: TPU support (only available in TPU environments)
Both imports are optional to keep backward compatibility.
"""
try:
from accelerate import Accelerator # type: ignore
except Exception:
Accelerator = None # accelerate is optional
try:
import torch_xla.core.xla_model as xm # type: ignore
XLA_AVAILABLE = True
except Exception:
xm = None
XLA_AVAILABLE = False
import lm_decoder
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer
# set up logging
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)
# function for initializing the ngram decoder
def build_lm_decoder(
model_path,
max_active=7000,
min_active=200,
beam=17.,
lattice_beam=8.0,
acoustic_scale=1.5,
ctc_blank_skip_threshold=1.0,
length_penalty=0.0,
nbest=1,
):
decode_opts = lm_decoder.DecodeOptions(
max_active,
min_active,
beam,
lattice_beam,
acoustic_scale,
ctc_blank_skip_threshold,
length_penalty,
nbest
)
TLG_path = os.path.join(model_path, 'TLG.fst')
words_path = os.path.join(model_path, 'words.txt')
G_path = os.path.join(model_path, 'G.fst')
rescore_G_path = os.path.join(model_path, 'G_no_prune.fst')
if not os.path.exists(rescore_G_path):
rescore_G_path = ""
G_path = ""
if not os.path.exists(TLG_path):
raise ValueError('TLG file not found at {}'.format(TLG_path))
if not os.path.exists(words_path):
raise ValueError('words file not found at {}'.format(words_path))
decode_resource = lm_decoder.DecodeResource(
TLG_path,
G_path,
rescore_G_path,
words_path,
""
)
decoder = lm_decoder.BrainSpeechDecoder(decode_resource, decode_opts)
return decoder
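# Example usage (a sketch; the model directory path is hypothetical, and the calls
# mirror how the decoder is used later in this script):
#
#   decoder = build_lm_decoder('/path/to/lm_dir', acoustic_scale=0.3, nbest=100)
#   lm_decoder.DecodeNumpy(decoder, logits, np.zeros_like(logits), np.log(9.0))
#   best_sentence = decoder.result()[0].sentence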
# function for updating the ngram decoder parameters
def update_ngram_params(
    ngramDecoder,
    max_active=7000,
    min_active=200,
    beam=17.0,
    lattice_beam=8.0,
    acoustic_scale=1.5,
    ctc_blank_skip_threshold=1.0,
    length_penalty=0.0,
    nbest=100,
):
decode_opts = lm_decoder.DecodeOptions(
max_active,
min_active,
beam,
lattice_beam,
acoustic_scale,
ctc_blank_skip_threshold,
length_penalty,
nbest,
)
ngramDecoder.SetOpt(decode_opts)
# function for initializing the OPT model and tokenizer
def build_opt(
model_name='facebook/opt-6.7b',
cache_dir=None,
device=None,
accelerator=None,
):
    '''
    Load the OPT-6.7b model and tokenizer from Hugging Face.
    The model is loaded in 16-bit precision for faster inference (~13 GB of VRAM
    in float16 on GPU; bfloat16 on TPU; float32 on CPU) and placed on the
    resolved device.
    '''
# Resolve device automatically if not provided
if device is None:
if accelerator is not None:
device = accelerator.device
elif XLA_AVAILABLE and xm is not None:
# Use all available TPU cores for multi-core processing
device = xm.xla_device()
logging.info(f"TPU cores available: {xm.xrt_world_size()}")
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
# Choose appropriate dtype per device
try:
device_type = device.type # torch.device or XLA device
except AttributeError:
# Fallback for XLA device objects
device_type = str(device)
if XLA_AVAILABLE and (str(device).startswith('xla') or device_type == 'xla'):
load_dtype = torch.bfloat16 # TPU prefers bfloat16
elif device_type == 'cuda':
load_dtype = torch.float16
else:
load_dtype = torch.float32
# load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
model_name,
cache_dir=cache_dir,
torch_dtype=load_dtype,
)
# Device placement
if accelerator is not None:
model = accelerator.prepare(model)
else:
model = model.to(device)
# For TPU multi-core, ensure model is replicated across all cores
if XLA_AVAILABLE and str(device).startswith('xla') and xm is not None:
# This will be handled by torch_xla internally when using xla_device()
pass
# Set the model to evaluation mode
model.eval()
# ensure padding token
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer
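# Example usage (a sketch; the cache path is hypothetical, and fp16 OPT-6.7b
# needs roughly 13 GB of VRAM):
#
#   model, tokenizer = build_opt(
#       model_name='facebook/opt-6.7b',
#       cache_dir='/path/to/hf_cache',
#       device=torch.device('cuda'),
#   )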
# function for rescoring hypotheses with the causal LLM (OPT by default; the
# function name is kept as rescore_with_gpt2 for backward compatibility)
@torch.inference_mode()
def rescore_with_gpt2(
model,
tokenizer,
device,
hypotheses,
length_penalty
):
# set model to evaluation mode
model.eval()
inputs = tokenizer(hypotheses, return_tensors='pt', padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
# compute log-probabilities
log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
log_probs = log_probs.cpu().numpy()
input_ids = inputs['input_ids'].cpu().numpy()
attention_mask = inputs['attention_mask'].cpu().numpy()
batch_size, seq_len, _ = log_probs.shape
scores = []
for i in range(batch_size):
n_tokens = int(attention_mask[i].sum())
# sum log-probs of each token given the previous context
score = sum(
log_probs[i, t-1, input_ids[i, t]]
for t in range(1, n_tokens)
)
scores.append(score - n_tokens * length_penalty)
return scores
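# Each returned score is the causal-LM log-likelihood of the hypothesis,
#   sum over t of log P(token_t | tokens_<t),
# minus n_tokens * length_penalty. Illustrative call (score values are made up):
#
#   scores = rescore_with_gpt2(model, tokenizer, device,
#                              ['how are you', 'how are ewe'], length_penalty=0.0)
#   # -> e.g. [-9.1, -16.4]; less negative means more fluent under the LLM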
# function for rescoring the nbest list with the LLM and selecting the best hypothesis
def gpt2_lm_decode(
model,
tokenizer,
device,
nbest,
acoustic_scale,
length_penalty,
alpha,
returnConfidence=False,
current_context_str=None,
):
hypotheses = []
acousticScores = []
oldLMScores = []
for out in nbest:
# get the candidate sentence (hypothesis)
hyp = out[0].strip()
if len(hyp) == 0:
continue
# add context to the front of each sentence
if current_context_str is not None and len(current_context_str.split()) > 0:
hyp = current_context_str + ' ' + hyp
hyp = hyp.replace('>', '')
        hyp = hyp.replace('  ', ' ')
hyp = hyp.replace(' ,', ',')
hyp = hyp.replace(' .', '.')
hyp = hyp.replace(' ?', '?')
hypotheses.append(hyp)
acousticScores.append(out[1])
oldLMScores.append(out[2])
if len(hypotheses) == 0:
        logging.error('In gpt2_lm_decode, len(hypotheses) == 0')
return ("", []) if not returnConfidence else ("", [], 0.)
# convert to numpy arrays
acousticScores = np.array(acousticScores)
oldLMScores = np.array(oldLMScores)
# get new LM scores from LLM
try:
# first, try to rescore all at once
newLMScores = np.array(rescore_with_gpt2(model, tokenizer, device, hypotheses, length_penalty))
except Exception as e:
logging.error(f'Error during OPT rescore: {e}')
try:
# if that fails, try to rescore in batches (to avoid VRAM issues)
newLMScores = []
            batch_size = int(np.ceil(len(hypotheses) / 5))
            for i in range(0, len(hypotheses), batch_size):
                newLMScores.extend(rescore_with_gpt2(model, tokenizer, device, hypotheses[i:i + batch_size], length_penalty))
newLMScores = np.array(newLMScores)
except Exception as e:
logging.error(f'Error during OPT rescore: {e}')
newLMScores = np.zeros(len(hypotheses))
# remove context from start of each sentence
if current_context_str is not None and len(current_context_str.split()) > 0:
hypotheses = [h[(len(current_context_str)+1):] for h in hypotheses]
# calculate total scores
totalScores = (acoustic_scale * acousticScores) + ((1 - alpha) * oldLMScores) + (alpha * newLMScores)
# get the best hypothesis
maxIdx = np.argmax(totalScores)
bestHyp = hypotheses[maxIdx]
# create nbest output
nbest_out = []
min_len = np.min((len(nbest), len(newLMScores), len(totalScores)))
for i in range(min_len):
nbest_out.append(';'.join(map(str,[nbest[i][0], nbest[i][1], nbest[i][2], newLMScores[i], totalScores[i]])))
# return
if not returnConfidence:
return bestHyp, nbest_out
else:
totalScores = totalScores - np.max(totalScores)
probs = np.exp(totalScores)
return bestHyp, nbest_out, probs[maxIdx] / np.sum(probs)
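# The combined score per hypothesis is
#   total = acoustic_scale * acoustic + (1 - alpha) * ngram_lm + alpha * llm,
# so alpha interpolates between the n-gram LM (alpha=0) and the LLM (alpha=1).
# The optional confidence is the softmax probability of the winning hypothesis
# over all candidates' total scores.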
def connect_to_redis_server(redis_ip, redis_port, redis_password='admin01'):
    try:
        redis_conn = redis.Redis(host=redis_ip, port=redis_port, password=redis_password)
        redis_conn.ping()
    except redis.exceptions.ConnectionError:
        logging.warning("Can't connect to redis server (ConnectionError).")
        return None
    else:
        logging.info("Connected to redis.")
        return redis_conn
def get_current_redis_time_ms(redis_conn):
    # redis TIME returns (seconds, microseconds); convert to milliseconds
    t = redis_conn.time()
    return int(t[0] * 1000 + t[1] / 1000)
# function to get string differences between two sentences
def get_string_differences(cue, decoder_output):
    '''Word-level edit distance between cue and decoder_output with a backtrace.
    Returns (cost, path, indices_to_highlight), where path labels each word of
    decoder_output with a matched index, 'R' (replaced) or 'D' (not in cue), and
    indices_to_highlight gives (start, end) character spans of differing words.'''
    decoder_output_words = decoder_output.split()
    cue_words = cue.split()
@lru_cache(None)
def reverse_w_backtrace(i, j):
if i == 0:
return j, ['I'] * j
elif j == 0:
return i, ['D'] * i
elif i > 0 and j > 0 and decoder_output_words[i-1] == cue_words[j-1]:
cost, path = reverse_w_backtrace(i-1, j-1)
return cost, path + [i - 1]
else:
insertion_cost, insertion_path = reverse_w_backtrace(i, j-1)
deletion_cost, deletion_path = reverse_w_backtrace(i-1, j)
substitution_cost, substitution_path = reverse_w_backtrace(i-1, j-1)
if insertion_cost <= deletion_cost and insertion_cost <= substitution_cost:
return insertion_cost + 1, insertion_path + ['I']
elif deletion_cost <= insertion_cost and deletion_cost <= substitution_cost:
return deletion_cost + 1, deletion_path + ['D']
else:
return substitution_cost + 1, substitution_path + ['R']
cost, path = reverse_w_backtrace(len(decoder_output_words), len(cue_words))
# remove insertions from path
path = [p for p in path if p != 'I']
# Get the indices in decoder_output of the words that are different from cue
indices_to_highlight = []
current_index = 0
for label, word in zip(path, decoder_output_words):
if label in ['R','D']:
indices_to_highlight.append((current_index, current_index+len(word)))
current_index += len(word) + 1
return cost, path, indices_to_highlight
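# Example (word-level alignment; spans index characters in decoder_output):
#
#   cost, path, spans = get_string_differences('hello world', 'hello word')
#   # -> cost == 1, path == [0, 'R'], spans == [(6, 10)]  (highlights 'word')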
def remove_punctuation(sentence):
    # remove punctuation, keeping letters, hyphens, spaces, and apostrophes
    sentence = re.sub(r"[^a-zA-Z\- ']", '', sentence)
    sentence = sentence.replace('- ', ' ')
    sentence = sentence.replace('--', '')
    sentence = sentence.replace(" '", "'")
    sentence = sentence.lower().strip()
    sentence = ' '.join(sentence.split())
    return sentence
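# Examples:
#
#   remove_punctuation('Hello, World!')   # -> 'hello world'
#   remove_punctuation("It's a test.")    # -> "it's a test"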
# function to augment the nbest list by swapping words around, artificially increasing the number of candidates
def augment_nbest(nbest, top_candidates_to_augment=20, acoustic_scale=0.3, score_penalty_percent=0.01):
sentences = []
ac_scores = []
lm_scores = []
total_scores = []
for i in range(len(nbest)):
sentences.append(nbest[i][0].strip())
ac_scores.append(nbest[i][1])
lm_scores.append(nbest[i][2])
total_scores.append(acoustic_scale*nbest[i][1] + nbest[i][2])
# sort by total score
sorted_indices = np.argsort(total_scores)[::-1]
sentences = [sentences[i] for i in sorted_indices]
ac_scores = [ac_scores[i] for i in sorted_indices]
lm_scores = [lm_scores[i] for i in sorted_indices]
total_scores = [total_scores[i] for i in sorted_indices]
# new sentences and scores
new_sentences = []
new_ac_scores = []
new_lm_scores = []
new_total_scores = []
# swap words around
for i1 in range(np.min([len(sentences)-1, top_candidates_to_augment])):
words1 = sentences[i1].split()
for i2 in range(i1+1, np.min([len(sentences), top_candidates_to_augment])):
words2 = sentences[i2].split()
if len(words1) != len(words2):
continue
_, path1, _ = get_string_differences(sentences[i1], sentences[i2])
_, path2, _ = get_string_differences(sentences[i2], sentences[i1])
replace_indices1 = [i for i, p in enumerate(path2) if p == 'R']
replace_indices2 = [i for i, p in enumerate(path1) if p == 'R']
            for r1, r2 in zip(replace_indices1, replace_indices2):
                new_words1 = words1.copy()
                new_words2 = words2.copy()
                new_words1[r1] = words2[r2]
                new_words2[r2] = words1[r1]
                new_sentence1 = ' '.join(new_words1)
                new_sentence2 = ' '.join(new_words2)
                # score each new candidate with the penalized mean of its parents' scores
                mean_ac = np.mean([ac_scores[i1], ac_scores[i2]])
                mean_lm = np.mean([lm_scores[i1], lm_scores[i2]])
                penalized_ac = mean_ac - score_penalty_percent * np.abs(mean_ac)
                penalized_lm = mean_lm - score_penalty_percent * np.abs(mean_lm)
                for new_sentence in (new_sentence1, new_sentence2):
                    if new_sentence not in sentences and new_sentence not in new_sentences:
                        new_sentences.append(new_sentence)
                        new_ac_scores.append(penalized_ac)
                        new_lm_scores.append(penalized_lm)
                        new_total_scores.append(acoustic_scale * penalized_ac + penalized_lm)
# combine new sentences and scores with old
for i in range(len(new_sentences)):
sentences.append(new_sentences[i])
ac_scores.append(new_ac_scores[i])
lm_scores.append(new_lm_scores[i])
total_scores.append(new_total_scores[i])
# sort by total score
sorted_indices = np.argsort(total_scores)[::-1]
sentences = [sentences[i] for i in sorted_indices]
ac_scores = [ac_scores[i] for i in sorted_indices]
lm_scores = [lm_scores[i] for i in sorted_indices]
total_scores = [total_scores[i] for i in sorted_indices]
# return nbest
nbest_out = []
for i in range(len(sentences)):
nbest_out.append([sentences[i], ac_scores[i], lm_scores[i]])
return nbest_out
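# Example of the swap augmentation: for equal-length candidates that disagree in
# two positions, e.g. 'a big cat' and 'a fat hat', swapping one substitution at a
# time yields the new candidates 'a fat cat' and 'a big hat', each scored with
# the penalized mean of its parents' acoustic and LM scores.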
# main function
def main(args):
lm_path = args.lm_path
gpu_number = args.gpu_number
max_active = args.max_active
min_active = args.min_active
beam = args.beam
lattice_beam = args.lattice_beam
acoustic_scale = args.acoustic_scale
ctc_blank_skip_threshold = args.ctc_blank_skip_threshold
length_penalty = args.length_penalty
nbest = args.nbest
top_candidates_to_augment = args.top_candidates_to_augment
score_penalty_percent = args.score_penalty_percent
blank_penalty = args.blank_penalty
do_opt = args.do_opt # acoustic scale = 0.8, blank penalty = 7, alpha = 0.5
opt_cache_dir = args.opt_cache_dir
alpha = args.alpha
rescore = args.rescore
redis_ip = args.redis_ip
redis_port = args.redis_port
input_stream = args.input_stream
partial_output_stream = args.partial_output_stream
final_output_stream = args.final_output_stream
# expand user on paths
lm_path = os.path.expanduser(lm_path)
if not os.path.exists(lm_path):
raise ValueError(f'Language model path does not exist: {lm_path}')
if opt_cache_dir is not None:
opt_cache_dir = os.path.expanduser(opt_cache_dir)
# create a nice dict of params to put into redis
lm_args = {
'lm_path': lm_path,
'max_active': int(max_active),
'min_active': int(min_active),
'beam': float(beam),
'lattice_beam': float(lattice_beam),
'acoustic_scale': float(acoustic_scale),
'ctc_blank_skip_threshold': float(ctc_blank_skip_threshold),
'length_penalty': float(length_penalty),
'nbest': int(nbest),
'blank_penalty': float(blank_penalty),
'alpha': float(alpha),
'do_opt': int(do_opt),
'rescore': int(rescore),
'top_candidates_to_augment': int(top_candidates_to_augment),
'score_penalty_percent': float(score_penalty_percent),
}
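    # Consumers can read the latest args back from the stream, e.g.:
    #   r.xrevrange('remote_lm_args', count=1)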
# pick device (Accelerate -> TPU -> CUDA -> CPU)
accelerator = None
if Accelerator is not None:
try:
accelerator = Accelerator()
except Exception:
accelerator = None
if accelerator is not None:
device = accelerator.device
elif XLA_AVAILABLE and xm is not None:
# Use all available TPU cores for multi-core processing
device = xm.xla_device()
logging.info(f"TPU cores available: {xm.xrt_world_size()}")
elif torch.cuda.is_available():
device = torch.device(f"cuda:{gpu_number}")
else:
device = torch.device("cpu")
logging.info(f'Using device: {device}')
# initialize opt model
if do_opt:
logging.info(f"Building opt model from {opt_cache_dir}...")
start_time = time.time()
lm, lm_tokenizer = build_opt(
cache_dir=opt_cache_dir,
device=device,
accelerator=accelerator,
)
logging.info(f'OPT model successfully built in {(time.time()-start_time):0.4f} seconds.')
# initialize ngram decoder
logging.info(f'Initializing language model decoder from {lm_path}...')
start_time = time.time()
    ngramDecoder = build_lm_decoder(
        lm_path,
        max_active=max_active,
        min_active=min_active,
        beam=beam,
        lattice_beam=lattice_beam,
        acoustic_scale=acoustic_scale,
        ctc_blank_skip_threshold=ctc_blank_skip_threshold,
        length_penalty=length_penalty,
        nbest=nbest,
    )
logging.info(f'Language model successfully initialized in {(time.time()-start_time):0.4f} seconds.')
# connect to redis server
REDIS_STATE = -1
redis_password = getattr(args, 'redis_password', 'admin01')
logging.info(f'Attempting to connect to redis at {redis_ip}:{redis_port}...')
r = connect_to_redis_server(redis_ip, redis_port, redis_password)
    while r is None:
        logging.warning(f'At startup, could not connect to redis server at {redis_ip}:{redis_port}. Trying again in 3 seconds...')
        time.sleep(3)
        r = connect_to_redis_server(redis_ip, redis_port, redis_password)
logging.info(f'Successfully connected to redis server at {redis_ip}:{redis_port}.')
timeout_ms = 100
oldStr = ''
prev_loop_start_time = 0
# main loop
logging.info('Entering main loop...')
while True:
# make sure that the loop doesn't run too fast (max 1000 Hz)
loop_time = time.time() - prev_loop_start_time
if loop_time < 0.001:
time.sleep(0.001 - loop_time)
prev_loop_start_time = time.time()
        # the try/except ensures we're still connected to redis, reconnecting if not
try:
r.ping()
except redis.exceptions.ConnectionError:
if REDIS_STATE != 0:
                logging.error(f'Could not connect to the redis server at {redis_ip}:{redis_port}! I will keep trying...')
REDIS_STATE = 0
time.sleep(1)
continue
else:
if REDIS_STATE != 1:
logging.info('Successfully connected to the redis server.')
logits_last_entry_seen = get_current_redis_time_ms(r)
reset_last_entry_seen = get_current_redis_time_ms(r)
finalize_last_entry_seen = get_current_redis_time_ms(r)
update_params_last_entry_seen = get_current_redis_time_ms(r)
REDIS_STATE = 1
# if the 'remote_lm_args' stream is empty, add the current args
# (this makes sure it's re-added once redis is flushed at the start of a new block)
if r.xlen('remote_lm_args') == 0:
r.xadd('remote_lm_args', lm_args)
# check if we need to reset
lm_reset_stream = r.xread(
{'remote_lm_reset': reset_last_entry_seen},
count=1,
block=None,
)
if len(lm_reset_stream) > 0:
for entry_id, entry_data in lm_reset_stream[0][1]:
reset_last_entry_seen = entry_id
# Reset the language model and tell redis, then move on to the next loop
oldStr = ''
ngramDecoder.Reset()
r.xadd('remote_lm_done_resetting', {'done': 1})
logging.info('Reset the language model.')
continue
# check if we need to finalize
lm_finalize_stream = r.xread(
{'remote_lm_finalize': finalize_last_entry_seen},
count=1,
block=None,
)
if len(lm_finalize_stream) > 0:
for entry_id, entry_data in lm_finalize_stream[0][1]:
finalize_last_entry_seen = entry_id
            context_val = r.get('contextual_decoding_current_context')
            if context_val is not None:
                current_context_str = context_val.decode().strip()
                if len(current_context_str.split()) > 0:
                    logging.info('For LLM rescore, adding context str to the beginning of each candidate sentence:')
                    logging.info(f'\t"{current_context_str}"')
            else:
                current_context_str = ''
# Finalize decoding, add the output to the output stream, and then move on to the next loop
ngramDecoder.FinishDecoding()
oldStr = ''
# Optionally rescore with unpruned LM
if rescore:
startT = time.time()
ngramDecoder.Rescore()
logging.info('Rescore time: %.3f' % (time.time() - startT))
# if nbest > 1, augment those sentences and bias them toward certain words
if nbest > 1:
# append the sentence, acoustic score, and lm score to a list
nbest_out = []
for d in ngramDecoder.result():
nbest_out.append([d.sentence, d.ac_score, d.lm_score])
# generate some more candidate sentences by swapping words around
nbest_out_len = len(nbest_out)
nbest_out = augment_nbest(
nbest = nbest_out,
top_candidates_to_augment = top_candidates_to_augment,
acoustic_scale = acoustic_scale,
score_penalty_percent = score_penalty_percent,
)
logging.info(f'Augmented nbest from {nbest_out_len} to {len(nbest_out)} candidates.')
# Optionally rescore with a LLM
if do_opt:
startT = time.time()
gpt2_out = gpt2_lm_decode(
lm,
lm_tokenizer,
device,
nbest_out,
acoustic_scale,
alpha = alpha,
length_penalty = length_penalty,
current_context_str = current_context_str,
returnConfidence = True,
)
if isinstance(gpt2_out, tuple) and len(gpt2_out) == 3:
decoded_final, nbest_redis, confidences = gpt2_out
else:
decoded_final, nbest_redis = gpt2_out # type: ignore
confidences = 0.0
logging.info('OPT time: %.3f' % (time.time() - startT))
elif len(ngramDecoder.result()) > 0:
# Otherwise just output the best sentence
decoded_final = ngramDecoder.result()[0].sentence
# create nbest_redis with 0 values for LLM score
nbest_redis = []
for i in range(len(nbest_out)):
sentence = nbest_out[i][0].strip()
ac_score = nbest_out[i][1]
lm_score = nbest_out[i][2]
llm_score = 0.0
total_score = acoustic_scale * ac_score + lm_score
nbest_redis.append(';'.join(map(str,[sentence, ac_score, lm_score, llm_score, total_score])))
                else:
                    logging.error('No output from language model.')
                    decoded_final = ''
                    nbest_redis = []
logging.info(f'Final: {decoded_final}')
if nbest > 1:
r.xadd(final_output_stream, {'lm_response_final': decoded_final, 'scoring': ';'.join(nbest_redis), 'context_str': current_context_str})
else:
r.xadd(final_output_stream, {'lm_response_final': decoded_final})
logging.info('Finalized the language model.\n')
r.xadd('remote_lm_done_finalizing', {'done': 1})
continue
# check if we need to update the decoder params
update_params_stream = r.xread(
{'remote_lm_update_params': update_params_last_entry_seen},
count=1,
block=None,
)
if len(update_params_stream) > 0:
for entry_id, entry_data in update_params_stream[0][1]:
update_params_last_entry_seen = entry_id
max_active = int(entry_data.get(b'max_active', max_active))
min_active = int(entry_data.get(b'min_active', min_active))
beam = float(entry_data.get(b'beam', beam))
lattice_beam = float(entry_data.get(b'lattice_beam', lattice_beam))
acoustic_scale = float(entry_data.get(b'acoustic_scale', acoustic_scale))
ctc_blank_skip_threshold = float(entry_data.get(b'ctc_blank_skip_threshold', ctc_blank_skip_threshold))
length_penalty = float(entry_data.get(b'length_penalty', length_penalty))
nbest = int(entry_data.get(b'nbest', nbest))
blank_penalty = float(entry_data.get(b'blank_penalty', blank_penalty))
alpha = float(entry_data.get(b'alpha', alpha))
do_opt = int(entry_data.get(b'do_opt', do_opt))
rescore = int(entry_data.get(b'rescore', rescore))
top_candidates_to_augment = int(entry_data.get(b'top_candidates_to_augment', top_candidates_to_augment))
score_penalty_percent = float(entry_data.get(b'score_penalty_percent', score_penalty_percent))
# make sure that the update remote lm args are put into redis nicely
lm_args = {
'lm_path': lm_path,
'max_active': int(max_active),
'min_active': int(min_active),
'beam': float(beam),
'lattice_beam': float(lattice_beam),
'acoustic_scale': float(acoustic_scale),
'ctc_blank_skip_threshold': float(ctc_blank_skip_threshold),
'length_penalty': float(length_penalty),
'nbest': int(nbest),
'blank_penalty': float(blank_penalty),
'alpha': float(alpha),
'do_opt': int(do_opt),
'rescore': int(rescore),
'top_candidates_to_augment': int(top_candidates_to_augment),
'score_penalty_percent': float(score_penalty_percent),
}
r.xadd('remote_lm_args', lm_args)
# update ngram parameters
update_ngram_params(
ngramDecoder,
max_active = max_active,
min_active = min_active,
beam = beam,
lattice_beam = lattice_beam,
acoustic_scale = acoustic_scale,
ctc_blank_skip_threshold = ctc_blank_skip_threshold,
length_penalty = length_penalty,
nbest = nbest,
)
logging.info(
f'Updated language model params:' +
f'\n\tmax_active = {max_active}' +
f'\n\tmin_active = {min_active}' +
f'\n\tbeam = {beam}' +
f'\n\tlattice_beam = {lattice_beam}' +
f'\n\tacoustic_scale = {acoustic_scale}' +
f'\n\tctc_blank_skip_threshold = {ctc_blank_skip_threshold}' +
f'\n\tlength_penalty = {length_penalty}' +
f'\n\tnbest = {nbest}' +
f'\n\tblank_penalty = {blank_penalty}' +
f'\n\talpha = {alpha}' +
f'\n\tdo_opt = {do_opt}' +
f'\n\trescore = {rescore}' +
f'\n\ttop_candidates_to_augment = {top_candidates_to_augment}' +
f'\n\tscore_penalty_percent = {score_penalty_percent}'
)
r.xadd('remote_lm_done_updating_params', {'done': 1})
continue
# ------------------------------------------------------------------------------------------------------------------------
# ------------ The loop can only get down to here if we're not finalizing, resetting, or updating params -----------------
# ------------------------------------------------------------------------------------------------------------------------
# try to read logits from redis stream
try:
read_result = r.xread(
{input_stream: logits_last_entry_seen},
count = 1,
block = timeout_ms
)
except redis.exceptions.ConnectionError:
if REDIS_STATE != 0:
                logging.error(f'Could not connect to the redis server at {redis_ip}:{redis_port}! I will keep trying...')
REDIS_STATE = 0
time.sleep(1)
continue
if (len(read_result) >= 1):
# --------------- Read input stream --------------------------------
for entry_id, entry_data in read_result[0][1]:
logits_last_entry_seen = entry_id
logits = np.frombuffer(entry_data[b'logits'], dtype=np.float32)
# reshape logits to (T, 41)
logits = logits.reshape(-1, 41)
# --------------- Run language model -------------------------------
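            # note on arguments (as used in this script): logits is the (T, 41) array
            # of frame scores, the second argument is a same-shaped array of zeros, and
            # the blank penalty is passed in the log domain via np.log(blank_penalty)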
lm_decoder.DecodeNumpy(ngramDecoder,
logits,
np.zeros_like(logits),
np.log(blank_penalty))
# display partial decoded sentence if it exists
if len(ngramDecoder.result()) > 0:
decoded_partial = ngramDecoder.result()[0].sentence
newStr = f'Partial: {decoded_partial}'
if oldStr != newStr:
logging.info(newStr)
oldStr = newStr
else:
logging.info('Partial: [NONE]')
decoded_partial = ''
# print(ngramDecoder.result())
r.xadd(partial_output_stream, {'lm_response_partial': decoded_partial})
else:
# timeout if no data received for X ms
# logging.warning(F'No logits came in for {timeout_ms} ms.')
continue
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--lm_path', type=str, help='Path to language model folder')
parser.add_argument('--gpu_number', type=int, default=0, help='GPU number to use')
parser.add_argument('--max_active', type=int, default=7000, help='max_active param for LM')
parser.add_argument('--min_active', type=int, default=200, help='min_active param for LM')
parser.add_argument('--beam', type=float, default=17.0, help='beam param for LM')
parser.add_argument('--lattice_beam', type=float, default=8.0, help='lattice_beam param for LM')
parser.add_argument('--ctc_blank_skip_threshold', type=float, default=1., help='ctc_blank_skip_threshold param for LM')
parser.add_argument('--length_penalty', type=float, default=0.0, help='length_penalty param for LM')
parser.add_argument('--acoustic_scale', type=float, default=0.3, help='Acoustic scale for LM')
parser.add_argument('--nbest', type=int, default=100, help='# of candidate sentences for LM decoding')
parser.add_argument('--top_candidates_to_augment', type=int, default=20, help='# of top candidates to augment')
parser.add_argument('--score_penalty_percent', type=float, default=0.01, help='Score penalty percent for augmented candidates')
parser.add_argument('--blank_penalty', type=float, default=9.0, help='Blank penalty for LM')
parser.add_argument('--rescore', action='store_true', help='Use an unpruned ngram model for rescoring?')
parser.add_argument('--do_opt', action='store_true', help='Use the opt model for rescoring?')
parser.add_argument('--opt_cache_dir', type=str, default=None, help='path to opt cache')
parser.add_argument('--alpha', type=float, default=0.5, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
parser.add_argument('--redis_ip', type=str, default='hs.zchens.cn', help='IP of the redis stream (string)')
parser.add_argument('--redis_port', type=int, default=6379, help='Port of the redis stream (int)')
parser.add_argument('--redis_password', type=str, default='admin01', help='Password for the redis stream (string)')
parser.add_argument('--input_stream', type=str, default="remote_lm_input", help='Input stream containing logits')
parser.add_argument('--partial_output_stream', type=str, default="remote_lm_output_partial", help='Output stream containing partial decoded sentences')
parser.add_argument('--final_output_stream', type=str, default="remote_lm_output_final", help='Output stream containing final decoded sentences')
args = parser.parse_args()
main(args)