#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
"""ConvolutionModule definition."""

from typing import Optional, Tuple

import torch
from torch import nn
from typeguard import check_argument_types


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""
    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            activation (nn.Module): Activation applied after normalization.
            norm (str): Normalization type, "batch_norm" or "layer_norm".
            causal (bool): Whether to use causal convolution or not.
            bias (bool): Whether the conv layers have a bias term.
        """
        assert check_argument_types()
        super().__init__()

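        # pointwise_conv1 doubles the channel dim; the GLU in forward()
        # halves it back by gating one half of the features with the other.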
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution:
        # if self.lorder > 0, it's a causal convolution and the input will be
        # padded with self.lorder frames on the left in forward();
        # otherwise it's a symmetric convolution.
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
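        # E.g. with kernel_size=15: the causal branch left-pads 14 frames in
        # forward(), while the symmetric branch lets Conv1d pad 7 frames on
        # each side. groups=channels below gives one filter per channel,
        # i.e. a depthwise convolution.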
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

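        # BatchNorm1d runs on (batch, channels, time) directly, while
        # LayerNorm normalizes the last dim, so forward() transposes to
        # (batch, time, channels) around it.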
        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

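    # The forward pass follows the Conformer convolution block:
    # pointwise_conv1 -> GLU -> depthwise_conv -> norm -> activation
    # -> pointwise_conv2, with optional left-context caching for streaming.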
    def forward(
        self,
        x: torch.Tensor,
        mask_pad: Optional[torch.Tensor] = None,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time).
            cache (torch.Tensor): left context cache (#batch, channels,
                cache_t), only used in causal convolution.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
            torch.Tensor: New left context cache, only meaningful in
                causal convolution.
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad is not None:
            x.masked_fill_(~mask_pad, 0.0)

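        # Causal path: prepend self.lorder frames of left context (zeros or
        # the cache from the previous chunk) and keep the last self.lorder
        # frames as the cache for the next chunk.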
        if self.lorder > 0:
            if cache is None:
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)
                assert cache.size(1) == x.size(1)
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required.
            # However, for JIT export, here we just fake one tensor instead
            # of None.
            new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad is not None:
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
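

# A minimal usage sketch (not part of the original module), assuming made-up
# sizes batch=2, time=16, channels=64: run two chunks through a causal module
# and reuse the returned cache as the left context of the second chunk.
if __name__ == "__main__":
    conv = ConvolutionModule(channels=64, kernel_size=15, causal=True)
    x = torch.randn(2, 16, 64)  # (#batch, time, channels)
    mask_pad = torch.ones(2, 1, 16, dtype=torch.bool)  # no padded frames
    y, cache = conv(x, mask_pad)          # first chunk: zero left context
    y2, cache = conv(x, mask_pad, cache)  # second chunk reuses the cache
    print(y.shape, cache.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 64, 14])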
