"""Unility functions for Transformer."""
|
||
|
|
||
|
import math
|
||
|
from typing import Tuple, List
|
||
|
|
||
|
import torch
|
||
|
from torch.nn.utils.rnn import pad_sequence
|
||
|
|
||
|
IGNORE_ID = -1
|
||
|
|
||
|
|
||
|
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad


def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor): (B, Lmax + 1)
        ys_out (torch.Tensor): (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor([sos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse each padded sequence in the batch.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0.0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
                             for y, i in zip(ys_pad, ys_lens)], True,
                            pad_value)
    return r_ys_pad


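# Editor's sketch (not part of the original file): reverse_pad_list is
# typically used to prepare targets for a right-to-left decoder pass; the
# tensors below are illustrative.
#
#     ys_pad = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 0]])
#     ys_lens = torch.tensor([4, 3])
#     reverse_pad_list(ys_pad, ys_lens, 0.0)
#     # tensor([[4, 3, 2, 1],
#     #         [7, 6, 5, 0]], dtype=torch.int32)

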
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


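# Editor's sketch (not part of the original file): a minimal th_accuracy call
# with illustrative shapes. Logits are flattened to (B * Lmax, D); targets
# stay (B, Lmax) with ignore_label marking padding.
#
#     logits = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # (1 * 3, 2)
#     targets = torch.tensor([[0, 1, -1]])                         # (1, 3)
#     th_accuracy(logits, targets, ignore_label=-1)  # -> 1.0 (pad ignored)

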
def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from wenet.transformer.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
        "gelu": torch.nn.GELU
    }

    return activation_funcs[act]()


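# Editor's sketch (not part of the original file): get_activation returns an
# instantiated module, so the result can be dropped straight into a layer.
#
#     act = get_activation("relu")    # torch.nn.ReLU()
#     act(torch.tensor([-1.0, 2.0]))  # tensor([0., 2.])

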
def get_subsample(config):
    """Return the subsampling rate of the encoder input layer."""
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8


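# Editor's sketch (not part of the original file): get_subsample reads the
# encoder input layer from a config dict; the dict below is illustrative.
#
#     config = {"encoder_conf": {"input_layer": "conv2d"}}
#     get_subsample(config)  # -> 4 (conv2d reduces the frame rate by 4)

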
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    """Collapse repeated labels and remove blanks (id 0) from a CTC path."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        prev = cur
        # Skip over consecutive duplicates of the current label.
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


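# Editor's sketch (not part of the original file): the usual CTC collapse,
# assuming blank id 0 as the function itself does.
#
#     remove_duplicates_and_blank([0, 3, 3, 0, 4, 4, 0])  # -> [3, 4]

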
def log_add(args: List[float]) -> float:
    """Stable log add: log(sum(exp(a) for a in args))."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp
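

# Editor's sketch (not part of the original file): log_add computes
# log(sum(exp(a) for a in args)) without overflow, e.g. when merging
# hypothesis scores in log space.
#
#     log_add([math.log(0.5), math.log(0.25)])  # -> log(0.75) ~= -0.2877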