# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import argparse import copy import logging import os import sys import torch import yaml from torch.utils.data import DataLoader from wenet.dataset.dataset import AudioDataset, CollateFunc from wenet.transformer.asr_model import init_asr_model from wenet.utils.checkpoint import load_checkpoint if __name__ == '__main__': parser = argparse.ArgumentParser(description='recognize with your model') parser.add_argument('--config', required=True, help='config file') parser.add_argument('--test_data', required=True, help='test data file') parser.add_argument('--gpu', type=int, default=-1, help='gpu id for this rank, -1 for cpu') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--dict', required=True, help='dict file') parser.add_argument('--beam_size', type=int, default=10, help='beam size for search') parser.add_argument('--penalty', type=float, default=0.0, help='length penalty') parser.add_argument('--result_file', required=True, help='asr result file') parser.add_argument('--batch_size', type=int, default=16, help='asr result file') parser.add_argument('--mode', choices=[ 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' ], default='attention', help='decoding mode') parser.add_argument('--ctc_weight', type=float, default=0.0, help='ctc weight for attention rescoring decode mode') parser.add_argument('--decoding_chunk_size', type=int, default=-1, help='''decoding chunk size, <0: for decoding, use full chunk. >0: for decoding, use fixed chunk size as set. 0: used for training, it's prohibited here''') parser.add_argument('--num_decoding_left_chunks', type=int, default=-1, help='number of left chunks for decoding') parser.add_argument('--simulate_streaming', action='store_true', help='simulate streaming inference') parser.add_argument('--reverse_weight', type=float, default=0.0, help='''right to left weight for attention rescoring decode mode''') args = parser.parse_args() print(args) logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s') os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' ] and args.batch_size > 1: logging.fatal( 'decoding mode {} must be running with batch_size == 1'.format( args.mode)) sys.exit(1) with open(args.config, 'r') as fin: configs = yaml.load(fin, Loader=yaml.FullLoader) raw_wav = configs['raw_wav'] # Init dataset and data loader # Init dataset and data loader test_collate_conf = copy.deepcopy(configs['collate_conf']) test_collate_conf['spec_aug'] = False test_collate_conf['spec_sub'] = False test_collate_conf['feature_dither'] = False test_collate_conf['speed_perturb'] = False if raw_wav: test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav) dataset_conf = configs.get('dataset_conf', {}) dataset_conf['batch_size'] = args.batch_size dataset_conf['batch_type'] = 'static' dataset_conf['sort'] = False test_dataset = AudioDataset(args.test_data, **dataset_conf, raw_wav=raw_wav) test_data_loader = DataLoader(test_dataset, collate_fn=test_collate_func, shuffle=False, batch_size=1, num_workers=0) # Init asr model from configs model = init_asr_model(configs) # Load dict char_dict = {} with open(args.dict, 'r') as fin: for line in fin: arr = line.strip().split() assert len(arr) == 2 char_dict[int(arr[1])] = arr[0] eos = len(char_dict) - 1 load_checkpoint(model, args.checkpoint) use_cuda = args.gpu >= 0 and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) model.eval() with torch.no_grad(), open(args.result_file, 'w') as fout: for batch_idx, batch in enumerate(test_data_loader): keys, feats, target, feats_lengths, target_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) target_lengths = target_lengths.to(device) if args.mode == 'attention': hyps = model.recognize( feats, feats_lengths, beam_size=args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming) hyps = [hyp.tolist() for hyp in hyps] elif args.mode == 'ctc_greedy_search': hyps = model.ctc_greedy_search( feats, feats_lengths, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming) # ctc_prefix_beam_search and attention_rescoring only return one # result in List[int], change it to List[List[int]] for compatible # with other batch decoding mode elif args.mode == 'ctc_prefix_beam_search': assert (feats.size(0) == 1) hyp = model.ctc_prefix_beam_search( feats, feats_lengths, args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, simulate_streaming=args.simulate_streaming) hyps = [hyp] elif args.mode == 'attention_rescoring': assert (feats.size(0) == 1) hyp = model.attention_rescoring( feats, feats_lengths, args.beam_size, decoding_chunk_size=args.decoding_chunk_size, num_decoding_left_chunks=args.num_decoding_left_chunks, ctc_weight=args.ctc_weight, simulate_streaming=args.simulate_streaming, reverse_weight=args.reverse_weight) hyps = [hyp] for i, key in enumerate(keys): content = '' for w in hyps[i]: if w == eos: break content += char_dict[w] logging.info('{} {}'.format(key, content)) fout.write('{} {}\n'.format(key, content))