import torch
from torch import nn


class LSTMDecoder(nn.Module):
    '''
    Defines the LSTM decoder

    This class combines day-specific input layers, an LSTM, and an output classification layer
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout = 0.0,
                 input_dropout = 0.0,
                 n_layers = 5,
                 patch_size = 0,
                 patch_stride = 0,
                 ):
        '''
        neural_dim    (int)   - number of channels in a single timestep (e.g. 512)
        n_units       (int)   - number of hidden units in each recurrent layer - equal to the size of the hidden state
        n_days        (int)   - number of days in the dataset
        n_classes     (int)   - number of classes
        rnn_dropout   (float) - fraction of recurrent units to drop out during training
        input_dropout (float) - fraction of input units to drop out during training
        n_layers      (int)   - number of recurrent layers
        patch_size    (int)   - number of timesteps to concatenate on the initial input layer - a value of 0 disables this "input concat" step
        patch_stride  (int)   - number of timesteps to stride over when concatenating the initial input
        '''
        super(LSTMDecoder, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout

        self.patch_size = patch_size
        self.patch_stride = patch_stride

        # Parameters for the day-specific input layers
        self.day_layer_activation = nn.Softsign()  # basically a shallower tanh

        # Set weights for day layers to be identity matrices so the model can learn its own day-specific transformations
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )

        self.day_layer_dropout = nn.Dropout(input_dropout)

        self.input_size = self.neural_dim

        # If we are using "strided inputs", then the input size of the first recurrent layer will actually be neural_dim * patch_size
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        self.lstm = nn.LSTM(
            input_size = self.input_size,
            hidden_size = self.n_units,
            num_layers = self.n_layers,
            dropout = self.rnn_dropout,
            batch_first = True,  # The first dim of our input is the batch dim
            bidirectional = False,
        )

        # Set recurrent units to have orthogonal param init and input layers to have xavier init
        for name, param in self.lstm.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param)
            elif "bias" in name:
                # Initialize biases to zero first
                nn.init.zeros_(param)
                # Set forget gate bias to 1.0 to prevent vanishing gradients
                # LSTM bias structure: [input_gate, forget_gate, cell_gate, output_gate]
                # Each gate has hidden_size parameters
                hidden_size = param.size(0) // 4
                param.data[hidden_size:2*hidden_size].fill_(1.0)  # forget gate bias = 1.0

        # Prediction head. Weight init to Xavier
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden state (h0); the cell state c0 is zero-initialized in forward()
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, day_idx, states = None, return_state = False):
        '''
        x        (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx  (tensor) - tensor of day indices corresponding to the day of each example in the batch x
        '''

        # Apply the day-specific layer to (hopefully) project neural data from the different days into the same latent space
        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        # Apply dropout to the output of the day-specific layer
        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        # (Optionally) perform the input concat operation
        if self.patch_size > 0:

            x = x.unsqueeze(1)                       # [batches, 1, timesteps, feature_dim]
            x = x.permute(0, 3, 1, 2)                # [batches, feature_dim, 1, timesteps]

            # Extract patches using unfold (sliding window)
            x_unfold = x.unfold(3, self.patch_size, self.patch_stride)  # [batches, feature_dim, 1, num_patches, patch_size]

            # Remove the dummy height dimension and rearrange dimensions
            x_unfold = x_unfold.squeeze(2)           # [batches, feature_dim, num_patches, patch_size]
            x_unfold = x_unfold.permute(0, 2, 3, 1)  # [batches, num_patches, patch_size, feature_dim]

            # Flatten the last two dimensions (patch_size and feature_dim)
            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)

        # Determine initial hidden states
        if states is None:
            h0 = self.h0.expand(self.n_layers, x.shape[0], self.n_units).contiguous()
            c0 = torch.zeros_like(h0)  # Initialize cell state to zeros
            states = (h0, c0)

        # Pass input through RNN
        output, hidden_states = self.lstm(x, states)

        # Compute logits
        logits = self.out(output)

        if return_state:
            return logits, hidden_states

        return logits
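

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only; the dimensions, dropout rates, and
# patch settings below are assumptions chosen for demonstration, not values
# taken from the original training configuration).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = LSTMDecoder(
        neural_dim = 512,      # channels per timestep (assumed)
        n_units = 256,         # hidden units per LSTM layer (assumed)
        n_days = 4,            # number of recording days in the dataset (assumed)
        n_classes = 41,        # number of output classes (assumed)
        rnn_dropout = 0.2,
        input_dropout = 0.1,
        n_layers = 5,
        patch_size = 4,        # concatenate 4 timesteps per input patch
        patch_stride = 2,      # slide the patch window by 2 timesteps
    )

    batch_size, n_timesteps = 8, 100
    x = torch.randn(batch_size, n_timesteps, 512)    # dummy neural data
    day_idx = torch.randint(0, 4, (batch_size,))     # one day index per trial

    logits = model(x, day_idx)
    # With patch_size=4 and patch_stride=2, the 100 timesteps collapse to
    # floor((100 - 4) / 2) + 1 = 49 patches, so logits has shape [8, 49, 41].
    print(logits.shape)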