Fix handling of output from gpt2_lm_decode in main
@@ -685,8 +685,7 @@ def main(args):
     # Optionally rescore with a LLM
     if do_opt:
         startT = time.time()
-
-        decoded_final, nbest_redis, confidences = gpt2_lm_decode(
+        gpt2_out = gpt2_lm_decode(
             lm,
             lm_tokenizer,
             device,
@@ -697,6 +696,11 @@ def main(args):
             current_context_str = current_context_str,
             returnConfidence = True,
         )
+        if isinstance(gpt2_out, tuple) and len(gpt2_out) == 3:
+            decoded_final, nbest_redis, confidences = gpt2_out
+        else:
+            decoded_final, nbest_redis = gpt2_out  # type: ignore
+            confidences = 0.0
         logging.info('OPT time: %.3f' % (time.time() - startT))
 
     elif len(ngramDecoder.result()) > 0:
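
For context, a minimal sketch of the return-shape ambiguity this commit guards against. The stub below is an assumption for illustration only (gpt2_lm_decode's real internals are not part of this diff); it models a helper that returns either a 2-tuple or, when returnConfidence=True, a 3-tuple:

from typing import List, Tuple, Union

# Hypothetical stand-in for gpt2_lm_decode: some code paths may honor
# returnConfidence and append a score, others may return only two values.
def gpt2_lm_decode_stub(
    nbest: List[str], returnConfidence: bool = False
) -> Union[Tuple[str, List[str]], Tuple[str, List[str], float]]:
    best = nbest[0]
    if returnConfidence:
        return best, nbest, 0.87  # 3-tuple when a confidence is computed
    return best, nbest            # 2-tuple otherwise

# The same unpacking pattern the commit adds to main():
gpt2_out = gpt2_lm_decode_stub(["hello world", "hello word"], returnConfidence=True)
if isinstance(gpt2_out, tuple) and len(gpt2_out) == 3:
    decoded_final, nbest_redis, confidences = gpt2_out
else:
    decoded_final, nbest_redis = gpt2_out  # type: ignore
    confidences = 0.0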