Fix handling of output from gpt2_lm_decode in main
This commit is contained in:
@@ -685,8 +685,7 @@ def main(args):
|
|||||||
# Optionally rescore with a LLM
|
# Optionally rescore with a LLM
|
||||||
if do_opt:
|
if do_opt:
|
||||||
startT = time.time()
|
startT = time.time()
|
||||||
|
gpt2_out = gpt2_lm_decode(
|
||||||
decoded_final, nbest_redis, confidences = gpt2_lm_decode(
|
|
||||||
lm,
|
lm,
|
||||||
lm_tokenizer,
|
lm_tokenizer,
|
||||||
device,
|
device,
|
||||||
@@ -697,6 +696,11 @@ def main(args):
|
|||||||
current_context_str = current_context_str,
|
current_context_str = current_context_str,
|
||||||
returnConfidence = True,
|
returnConfidence = True,
|
||||||
)
|
)
|
||||||
|
if isinstance(gpt2_out, tuple) and len(gpt2_out) == 3:
|
||||||
|
decoded_final, nbest_redis, confidences = gpt2_out
|
||||||
|
else:
|
||||||
|
decoded_final, nbest_redis = gpt2_out # type: ignore
|
||||||
|
confidences = 0.0
|
||||||
logging.info('OPT time: %.3f' % (time.time() - startT))
|
logging.info('OPT time: %.3f' % (time.time() - startT))
|
||||||
|
|
||||||
elif len(ngramDecoder.result()) > 0:
|
elif len(ngramDecoder.result()) > 0:
|
||||||
|
Reference in New Issue
Block a user