# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which runs in an auto-regressive
    mode, so each step may attend only to the steps to its left.

    In the encoder, full attention is used when streaming is not required
    and the sequence is not long; in that case no attention mask is needed.

    When streaming is required, chunk-based attention is used in the
    encoder. See subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret, out=ret)
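
# A minimal usage sketch (an illustration, not part of the original module):
# in a decoder the causal mask above is typically combined with the padding
# mask before attention. `ys_lens`, a (B,) tensor of target lengths, is a
# hypothetical name used only for this example.
#
#     tgt_mask = make_non_pad_mask(ys_lens).unsqueeze(1)  # (B, 1, L)
#     m = subsequent_mask(tgt_mask.size(-1),
#                         tgt_mask.device).unsqueeze(0)   # (1, L, L)
#     tgt_mask = tgt_mask & m                             # (B, L, L)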


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret
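
# An illustrative call (added for clarity, not in the original file):
# limiting the history to one left chunk drops the oldest chunks, e.g.
#
#     >>> subsequent_chunk_mask(6, 2, num_left_chunks=1)
#     [[1, 1, 0, 0, 0, 0],
#      [1, 1, 0, 0, 0, 0],
#      [1, 1, 1, 1, 0, 0],
#      [1, 1, 1, 1, 0, 0],
#      [0, 0, 1, 1, 1, 1],
#      [0, 0, 1, 1, 1, 1]]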


def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk
            for training
        decoding_chunk_size (int): decoding chunk size for dynamic chunk
            0: default for training, use random dynamic chunk
            <0: for decoding, use full chunk
            >0: for decoding, use the fixed chunk size as set
        static_chunk_size (int): chunk size for static chunk
            training/decoding if it's greater than 0; if use_dynamic_chunk
            is true, this parameter is ignored
        num_decoding_left_chunks (int): number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context (max_len).
            # Since we use 4x subsampling and allow up to 1s (100 frames)
            # of delay, the maximum chunk is 100 / 4 = 25 frames.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
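
# A usage sketch (the argument values are hypothetical, chosen only to
# illustrate the branches above): dynamic-chunk training usually passes
# decoding_chunk_size=0 so a random chunk size is drawn per batch, while
# streaming decoding pins the chunk size.
#
#     # training: random dynamic chunks, random left context
#     chunk_masks = add_optional_chunk_mask(xs, masks,
#                                           use_dynamic_chunk=True,
#                                           use_dynamic_left_chunk=True,
#                                           decoding_chunk_size=0,
#                                           static_chunk_size=0,
#                                           num_decoding_left_chunks=-1)
#     # streaming decoding: fixed 16-frame chunks, unlimited history
#     chunk_masks = add_optional_chunk_mask(xs, masks,
#                                           use_dynamic_chunk=True,
#                                           use_dynamic_left_chunk=False,
#                                           decoding_chunk_size=16,
#                                           static_chunk_size=0,
#                                           num_decoding_left_chunks=-1)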


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    max_len = int(lengths.max().item())
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
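
# Illustration (added for clarity; the lengths are hypothetical): with
# lengths = torch.tensor([5, 3, 2]), seq_range_expand broadcasts
# [0, 1, 2, 3, 4] over each row and is compared against that row's length,
# so positions at index >= length become True (the padded part):
#
#     >>> make_pad_mask(torch.tensor([5, 3, 2]))
#     tensor([[False, False, False, False, False],
#             [False, False, False,  True,  True],
#             [False, False,  True,  True,  True]])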


def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To keep the padded part from passing values into
    context-dependent blocks such as attention or convolution, the
    padded part is masked out.

    This pad_mask is used in both the encoder and the decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)


def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    gives one branch a score of zero and the rest a score of -inf.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score
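
# A small worked example (added; the values are hypothetical): for
# beam_size=3 and a finished hypothesis (flag True), the first branch is
# pinned to 0 and the others to -inf, so only one copy survives top-k
# selection in beam search:
#
#     >>> score = torch.zeros(1, 3)
#     >>> flag = torch.tensor([[True]])
#     >>> mask_finished_scores(score, flag)
#     tensor([[0., -inf, -inf]])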


def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
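
# Companion example to mask_finished_scores (added; the values are
# hypothetical): finished rows have every branch overwritten with the
# <eos> id, unfinished rows are left untouched:
#
#     >>> pred = torch.tensor([[7, 8, 9], [4, 5, 6]])
#     >>> flag = torch.tensor([[True], [False]])
#     >>> mask_finished_preds(pred, flag, eos=2)
#     tensor([[2, 2, 2],
#             [4, 5, 6]])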