44 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			44 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding: utf-8 -*-
 | |
| 
 | |
| # Copyright 2019 Shigeki Karita
 | |
| #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 | |
| """Positionwise feed forward layer definition."""
 | |
| 
 | |
from typing import Optional

import torch
 | |
| 
 | |
| 
 | |
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    The same feed-forward network (linear -> activation -> dropout -> linear)
    is applied to each position of the sequence.
    The output dimension is the same as the input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module, optional): Activation function applied
            after the first linear layer. When omitted (``None``), a new
            ``torch.nn.ReLU`` is created for this layer.
    """

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: Optional[torch.nn.Module] = None):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        # Build the default activation per instance: a module instantiated in
        # the signature is evaluated once at definition time and would be the
        # very same object inside every layer that relies on the default.
        if activation is None:
            activation = torch.nn.ReLU()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        # Expand to the hidden size, apply the non-linearity, then dropout.
        hidden = self.dropout(self.activation(self.w_1(xs)))
        # Project back to the input dimension.
        return self.w_2(hidden)
 | 
