50 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			50 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
|   | #!/usr/bin/env python | ||
|  | # Copyright (c) Facebook, Inc. and its affiliates. | ||
|  | # All rights reserved. | ||
|  | # | ||
|  | # This source code is licensed under the license found in the | ||
|  | # https://github.com/pytorch/fairseq/blob/master/LICENSE | ||
|  | 
 | ||
|  | from __future__ import absolute_import, division, print_function, unicode_literals | ||
|  | 
 | ||
|  | import argparse | ||
|  | import sys | ||
|  | 
 | ||
|  | import sentencepiece as spm | ||
|  | 
 | ||
|  | 
 | ||
def main():
    """Decode SentencePiece-encoded text back into plain text.

    Command-line options:
        --model: path of the SentencePiece model to load (required).
        --input: file to decode; standard input is read when omitted.
        --input_format: "piece" when each line holds whitespace-separated
            pieces (default), "id" when it holds whitespace-separated
            integer ids.

    One decoded line is printed to standard output per input line.

    Raises:
        NotImplementedError: if the input format is neither "piece" nor "id".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True,
                        help="sentencepiece model to use for decoding")
    parser.add_argument("--input", default=None, help="input file to decode")
    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    def tok2int(tok):
        # remap reference-side <unk> (represented as <<unk>>) to 0
        return int(tok) if tok != "<<unk>>" else 0

    if args.input_format == "piece":
        def decode(tokens):
            return "".join(sp.DecodePieces(tokens))
    elif args.input_format == "id":
        def decode(tokens):
            # Tokens split from a text line are strings; convert them so
            # that DecodeIds receives integer ids.
            return "".join(sp.DecodeIds([tok2int(tok) for tok in tokens]))
    else:
        raise NotImplementedError

    def decode_stream(handle):
        # Print one decoded line per input line.
        for line in handle:
            print(decode(line.split()))

    if args.input is None:
        # stdin is owned by the interpreter, so it is not closed here.
        decode_stream(sys.stdin)
    else:
        # The context manager closes the input file even if decoding fails.
        with open(args.input, "r", encoding="utf-8") as handle:
            decode_stream(handle)
|  | 
 | ||
|  | 
 | ||
if __name__ == "__main__":
    # Run the command-line decoder only when this file is executed as a
    # script, not when it is imported as a module.
    main()