# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import AudioDataset, CollateFunc
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint

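# Example invocation (all file names below are illustrative, not fixed paths):
#   python recognize.py --config train.yaml --test_data format.data \
#       --checkpoint final.pt --dict units.txt --result_file hyp.txt \
#       --mode attention_rescoring --batch_size 1
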
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam size for search')
    parser.add_argument('--penalty',
                        type=float,
                        default=0.0,
                        help='length penalty')
    parser.add_argument('--result_file', required=True, help='asr result file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='batch size for decoding')
    parser.add_argument('--mode',
                        choices=[
                            'attention', 'ctc_greedy_search',
                            'ctc_prefix_beam_search', 'attention_rescoring'
                        ],
                        default='attention',
                        help='decoding mode')
    parser.add_argument('--ctc_weight',
                        type=float,
                        default=0.0,
                        help='ctc weight for attention rescoring decode mode')
    parser.add_argument('--decoding_chunk_size',
                        type=int,
                        default=-1,
                        help='''decoding chunk size,
                                <0: for decoding, use full chunk.
                                >0: for decoding, use fixed chunk size as set.
                                0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks',
                        type=int,
                        default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming',
                        action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight',
                        type=float,
                        default=0.0,
                        help='''right to left weight for attention rescoring
                                decode mode''')
    args = parser.parse_args()
    print(args)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

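    # ctc_prefix_beam_search and attention_rescoring decode one utterance at
    # a time (note the feats.size(0) == 1 asserts below), so reject larger
    # batches up front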
    if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
                     ] and args.batch_size > 1:
        logging.fatal(
            'decoding mode {} must be running with batch_size == 1'.format(
                args.mode))
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    raw_wav = configs['raw_wav']
    # Init dataset and data loader
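    # Turn off every augmentation switch so test features are deterministic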
    test_collate_conf = copy.deepcopy(configs['collate_conf'])
    test_collate_conf['spec_aug'] = False
    test_collate_conf['spec_sub'] = False
    test_collate_conf['feature_dither'] = False
    test_collate_conf['speed_perturb'] = False
    if raw_wav:
        test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav)
    dataset_conf = configs.get('dataset_conf', {})
    dataset_conf['batch_size'] = args.batch_size
    dataset_conf['batch_type'] = 'static'
    dataset_conf['sort'] = False
    test_dataset = AudioDataset(args.test_data,
                                **dataset_conf,
                                raw_wav=raw_wav)
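    # AudioDataset yields pre-batched items (batch_size is set above), so the
    # DataLoader keeps batch_size=1 to avoid batching twice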
    test_data_loader = DataLoader(test_dataset,
                                  collate_fn=test_collate_func,
                                  shuffle=False,
                                  batch_size=1,
                                  num_workers=0)

    # Init asr model from configs
    model = init_asr_model(configs)

    # Load dict
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
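    # The dict file is expected to list the shared <sos/eos> symbol last,
    # so its id doubles as the end-of-sentence marker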
    eos = len(char_dict) - 1

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)

    model.eval()
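    # Decode every batch without gradient tracking and write one
    # '<utterance-key> <transcript>' line per utterance to the result file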
    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for batch_idx, batch in enumerate(test_data_loader):
            keys, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            target_lengths = target_lengths.to(device)
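            # Dispatch on --mode: 'attention' and 'ctc_greedy_search' decode
            # the whole batch, the other two modes run one utterance at a time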
            if args.mode == 'attention':
                hyps = model.recognize(
                    feats,
                    feats_lengths,
                    beam_size=args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
                hyps = [hyp.tolist() for hyp in hyps]
            elif args.mode == 'ctc_greedy_search':
                hyps = model.ctc_greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
            # ctc_prefix_beam_search and attention_rescoring only return a
            # single result as List[int], so wrap it in a List[List[int]] to
            # stay compatible with the other batch decoding modes
            elif args.mode == 'ctc_prefix_beam_search':
                assert (feats.size(0) == 1)
                hyp = model.ctc_prefix_beam_search(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
                hyps = [hyp]
            elif args.mode == 'attention_rescoring':
                assert (feats.size(0) == 1)
                hyp = model.attention_rescoring(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    ctc_weight=args.ctc_weight,
                    simulate_streaming=args.simulate_streaming,
                    reverse_weight=args.reverse_weight)
                hyps = [hyp]
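            # Map token ids back to symbols, truncating at eos, and emit the
            # hypothesis for each utterance key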
            for i, key in enumerate(keys):
                content = ''
                for w in hyps[i]:
                    if w == eos:
                        break
                    content += char_dict[w]
                logging.info('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))