# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step may only attend to the steps on its left.

    In the encoder, full attention is used when streaming is not required
    and the sequence is not long. In that case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu", "cuda" or an existing tensor's device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret, out=ret)
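

# A minimal, hypothetical usage sketch: how a causal mask like the one above
# is typically applied to decoder self-attention scores. The helper name
# `_demo_causal_attention` and all shapes are illustrative only.
def _demo_causal_attention():
    length = 3
    scores = torch.rand(length, length)  # raw attention scores (L, L)
    mask = subsequent_mask(length)       # lower-triangular bool mask (L, L)
    # Future positions get -inf so softmax assigns them zero weight.
    scores = scores.masked_fill(~mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    assert torch.allclose(attn.sum(-1), torch.ones(length))
    return attn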


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size.

    This is for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full left context
            >=0: use num_left_chunks left chunks
        device (torch.device): "cpu", "cuda" or an existing tensor's device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret
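

# A minimal, hypothetical sketch of the two left-context modes. The helper
# name `_demo_chunk_mask` is illustrative only.
def _demo_chunk_mask():
    # Full left context: every step sees all previous chunks.
    full = subsequent_chunk_mask(6, chunk_size=2, num_left_chunks=-1)
    # Limited left context: each step sees its own chunk plus one chunk back.
    limited = subsequent_chunk_mask(6, chunk_size=2, num_left_chunks=1)
    # Step 4 lives in the third chunk; with num_left_chunks=1 it can no
    # longer attend to chunk 0.
    assert bool(full[4, 0]) and not bool(limited[4, 0])
    return full, limited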


def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int) -> torch.Tensor:
    """Apply optional chunk mask for the encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training
        decoding_chunk_size (int): decoding chunk size for dynamic chunk
            0: default for training, use random dynamic chunk
            <0: for decoding, use full chunk
            >0: for decoding, use the fixed chunk size as set
        static_chunk_size (int): chunk size for static chunk
            training/decoding if it's greater than 0; if use_dynamic_chunk
            is true, this parameter is ignored
        num_decoding_left_chunks (int): number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Chunk size is either in [1, 25] or the full context (max_len).
            # Since we use 4x subsampling and allow up to 1s (100 frames) of
            # delay, the maximum chunk is 100 / 4 = 25 frames.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
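

# A minimal, hypothetical sketch of two of the masking modes on a dummy
# batch. The helper name `_demo_optional_chunk_mask` and all shapes are
# illustrative only; make_non_pad_mask is defined later in this module.
def _demo_optional_chunk_mask():
    xs = torch.zeros(2, 8, 4)  # (B, L, D)
    masks = make_non_pad_mask(torch.tensor([8, 5])).unsqueeze(1)  # (B, 1, L)
    # Static chunk mode: fixed chunk size of 2, full left context.
    static = add_optional_chunk_mask(xs, masks,
                                     use_dynamic_chunk=False,
                                     use_dynamic_left_chunk=False,
                                     decoding_chunk_size=0,
                                     static_chunk_size=2,
                                     num_decoding_left_chunks=-1)
    # No chunking at all: the padding mask is returned unchanged.
    plain = add_optional_chunk_mask(xs, masks, False, False, 0, 0, -1)
    assert static.shape == (2, 8, 8) and plain.shape == (2, 1, 8)
    return static, plain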


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    max_len = int(lengths.max().item())
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
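

# A minimal, hypothetical sketch: using the pad mask to zero out padded
# frames in a batch of features. The helper name `_demo_make_pad_mask` is
# illustrative only.
def _demo_make_pad_mask():
    lengths = torch.tensor([5, 3, 2])
    feats = torch.rand(3, 5, 4)   # (B, L, D)
    pad = make_pad_mask(lengths)  # (B, L), True at padded positions
    feats = feats.masked_fill(pad.unsqueeze(-1), 0.0)
    assert feats[1, 3:].abs().sum() == 0  # frames past length 3 are zeroed
    return feats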


def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable batch
    computing, padding is needed to make all sequences the same size. To
    avoid the padding part passing values to context-dependent blocks such
    as attention or convolution, this padding part is masked.

    This pad_mask is used in both the encoder and the decoder.

    1 for the non-padded part and 0 for the padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)
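

# A minimal, hypothetical sketch: turning per-utterance lengths into the
# (B, 1, L) attention mask shape consumed by the encoder. The helper name
# `_demo_make_non_pad_mask` is illustrative only.
def _demo_make_non_pad_mask():
    lengths = torch.tensor([5, 3, 2])
    masks = make_non_pad_mask(lengths).unsqueeze(1)  # (B, 1, L)
    # Each utterance keeps exactly `length` valid positions.
    assert masks.sum(-1).squeeze(1).tolist() == [5, 3, 2]
    return masks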


def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    gives one branch a zero score and the rest a -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score
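

# A minimal, hypothetical sketch with beam_size=3: after a hypothesis
# finishes, exactly one branch stays alive with score 0 and the rest get
# -inf, so the finished hypothesis's total score stops changing. The helper
# name `_demo_mask_finished_scores` is illustrative only.
def _demo_mask_finished_scores():
    score = torch.zeros(2, 3)               # (batch*beam, beam)
    flag = torch.tensor([[True], [False]])  # first hypothesis is finished
    score = mask_finished_scores(score, flag)
    assert score[0].tolist() == [0.0, float('-inf'), float('-inf')]
    assert score[1].tolist() == [0.0, 0.0, 0.0]
    return score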


def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
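

# A minimal, hypothetical sketch: every branch of a finished hypothesis is
# forced to the <eos> token id. The helper name `_demo_mask_finished_preds`
# and the token ids are illustrative only.
def _demo_mask_finished_preds():
    eos = 2
    pred = torch.tensor([[5, 7, 9], [4, 6, 8]])  # (batch*beam, beam)
    flag = torch.tensor([[True], [False]])       # first hypothesis finished
    pred = mask_finished_preds(pred, flag, eos)
    assert pred.tolist() == [[2, 2, 2], [4, 6, 8]]
    return pred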