252 lines
8.7 KiB
Python
252 lines
8.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2019 Shigeki Karita
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
import torch
|
|
|
|
|
|
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): device on which the mask is allocated,
            e.g. torch.device("cpu") or torch.device("cuda")

    Returns:
        torch.Tensor: bool mask of shape (size, size) where ``ret[i, j]``
            is True iff ``j <= i``.

    Examples:
        >>> subsequent_mask(3)
        tensor([[ True, False, False],
                [ True,  True, False],
                [ True,  True,  True]])
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    # Zero out everything above the diagonal in place (no out= aliasing).
    return ret.tril_()
|
|
|
|
|
|
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder.

    Row ``i`` may attend to every column inside its own chunk plus the
    chunks to its left (all of them, or only ``num_left_chunks`` of them).
    The mask is built with broadcasting instead of a per-row Python loop.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk; must be positive when size > 0
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size)

    Raises:
        ValueError: if ``size > 0`` and ``chunk_size <= 0``.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    if size == 0:
        return torch.zeros(0, 0, device=device, dtype=torch.bool)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    pos = torch.arange(size, device=device)
    # First frame of the chunk each row belongs to (== (i // chunk_size) *
    # chunk_size), computed via remainder to avoid tensor floor division.
    chunk_start = pos - pos % chunk_size
    # Exclusive end of each row's chunk; may exceed size, which is harmless
    # because column indices never reach it.
    chunk_end = chunk_start + chunk_size
    cols = pos.unsqueeze(0)  # (1, size)
    ret = cols < chunk_end.unsqueeze(1)  # (size, size)
    if num_left_chunks >= 0:
        first_visible = (chunk_start - num_left_chunks * chunk_size).clamp(
            min=0)
        ret &= cols >= first_visible.unsqueeze(1)
    return ret
|
|
|
|
|
|
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int) -> torch.Tensor:
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: ``masks & chunk_mask`` of shape (B, L, L) when a chunk
            mask applies; otherwise ``masks`` itself is returned unchanged.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            # torch.randint's `high` is exclusive and must exceed `low`;
            # clamping it keeps very short batches (max_len <= 1) from
            # crashing while drawing the same values for max_len >= 2.
            chunk_size = torch.randint(1, max(max_len, 2), (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    # Same exclusive-high constraint: when no left chunk
                    # fits (max_left_chunks == 0) this yields 0.
                    num_left_chunks = torch.randint(
                        0, max(max_left_chunks, 1), (1, )).item()
    elif static_chunk_size > 0:
        chunk_size = static_chunk_size
        num_left_chunks = num_decoding_left_chunks
    else:
        # No chunk restriction: only the padding mask applies.
        return masks
    chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                        num_left_chunks,
                                        xs.device)  # (L, L)
    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    return masks & chunk_masks  # (B, L, L)
|
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the returned mask. When <= 0 (the default)
            the width is ``lengths.max()``; an empty batch then yields a
            zero-width mask.

    Returns:
        torch.Tensor: bool mask (B, max_len), True at padded positions.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        # max() of an empty tensor raises, so guard the empty-batch case.
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    # (1, max_len) vs (B, 1) broadcasts to (B, max_len).
    return seq_range.unsqueeze(0) >= lengths.unsqueeze(-1)
|
|
|
|
|
|
def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To keep the padding part from passing values to context-dependent
    blocks such as attention or convolution, this padding part is
    masked.

    This pad_mask is used in both encoder and decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: bool mask (B, max(lengths)), True at non-padded
            positions; the logical negation of make_pad_mask(lengths).

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)
|
|
|
|
|
|
def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    Keep exactly one live branch for every finished hypothesis.

    For each row whose ``flag`` is True, the first beam slot is set to 0 and
    every other slot to -inf; rows whose ``flag`` is False are left as-is.
    ``score`` is modified in place and also returned.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    slot = torch.arange(score.size(-1), device=flag.device)
    is_first_slot = slot == 0  # (beam_size,)
    # (N, 1) combined with (beam_size,) broadcasts to (N, beam_size).
    kill = flag & ~is_first_slot
    keep = flag & is_first_slot
    score.masked_fill_(kill, -float('inf'))
    score.masked_fill_(keep, 0)
    return score
|
|
|
|
|
|
def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    ``pred`` is modified in place and also returned.

    Args:
        pred (torch.Tensor): A int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
        eos (int): token id written into every slot of a finished row.

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    # masked_fill_ broadcasts the (N, 1) flag across all beam slots, so no
    # explicit repeat is needed.
    return pred.masked_fill_(flag, eos)
|