Fix handling of output from gpt2_lm_decode in main

Zchen
2025-10-05 11:01:21 +08:00
parent ee081a9151
commit 6f1b1d827d


@@ -685,8 +685,7 @@ def main(args):
     # Optionally rescore with a LLM
     if do_opt:
         startT = time.time()
-        decoded_final, nbest_redis, confidences = gpt2_lm_decode(
+        gpt2_out = gpt2_lm_decode(
             lm,
             lm_tokenizer,
             device,
@@ -697,6 +696,11 @@ def main(args):
             current_context_str = current_context_str,
             returnConfidence = True,
         )
+        if isinstance(gpt2_out, tuple) and len(gpt2_out) == 3:
+            decoded_final, nbest_redis, confidences = gpt2_out
+        else:
+            decoded_final, nbest_redis = gpt2_out  # type: ignore
+            confidences = 0.0
         logging.info('OPT time: %.3f' % (time.time() - startT))
     elif len(ngramDecoder.result()) > 0:
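
The change guards against gpt2_lm_decode returning either a 2-tuple or a 3-tuple. A minimal sketch of the same unpacking pattern, using a hypothetical stand-in fake_lm_decode in place of the real function (the stand-in's behavior is an assumption for illustration only):

    from typing import List, Tuple, Union

    # Hypothetical stand-in: returns a 3-tuple when a confidence is available,
    # otherwise a 2-tuple (assumed behavior, not the real gpt2_lm_decode).
    def fake_lm_decode(return_confidence: bool) -> Union[Tuple[str, List[str]],
                                                          Tuple[str, List[str], float]]:
        nbest = ["hello world", "hello word"]
        if return_confidence:
            return "hello world", nbest, 0.87
        return "hello world", nbest

    out = fake_lm_decode(return_confidence=True)
    if isinstance(out, tuple) and len(out) == 3:
        decoded_final, nbest_redis, confidences = out
    else:
        decoded_final, nbest_redis = out
        confidences = 0.0  # fallback when no confidence is returned

    print(decoded_final, confidences)

This mirrors the committed fix: downstream code can always read decoded_final, nbest_redis, and confidences, with confidences defaulting to 0.0 when the decoder does not supply one.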