Source code for openddi.models.ConvLSTM

import torch
import torch.nn as nn

class ConvLSTM(nn.Module):
    """Conv1d + LSTM encoder with a pairwise classification head.

    Each row of the node-feature matrix is treated as a single-channel 1D
    signal. It is downsampled by two Conv1d/ReLU/BatchNorm/MaxPool blocks, the
    remaining positions are fed as a sequence through two stacked LSTMs, and
    the output at the last position becomes the node embedding. The embeddings
    of the two nodes in each requested pair are concatenated and projected to
    class logits by a linear layer.
    """

    def __init__(
        self,
        feature: int,
        hidden1: int,
        hidden2: int,
        num_relations: int,
        num_classes: int,
        dropout: float = 0.2,
        timesteps: int = 1,
    ):
        """Build the layers.

        Args:
            feature: Input feature dimension. Only used to size ``liner1``,
                which ``forward`` does not call.
            hidden1: Output channels of the first convolution and hidden size
                of the first LSTM.
            hidden2: Output channels of the second convolution and hidden size
                of the second LSTM.
            num_relations: Number of relation types. Accepted for interface
                compatibility; not used by this model.
            num_classes: Number of output classes (size of the logit vector).
            dropout: Dropout probability applied after each LSTM.
            timesteps: Accepted for interface compatibility; not used by this
                model.
        """
        super().__init__()
        self.num_classes = int(num_classes)
        self.feature = int(feature)
        self.hidden1 = int(hidden1)
        self.hidden2 = int(hidden2)
        # Stored for reference only; no layer in this class is sized by it.
        self.hidden3 = int(hidden2 / 2)
        self.dropout_ratio = dropout

        # Not used by forward(); retained so the module's parameter set (and
        # therefore its state_dict keys and seeded initialisation) is unchanged.
        self.liner1 = nn.Linear(self.feature * 2, int(self.feature))

        # First convolution block: 1 input channel -> hidden1 channels.
        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=self.hidden1,
            kernel_size=8,
            stride=8,
            padding=2,
        )
        self.bn1 = nn.BatchNorm1d(self.hidden1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

        # Second convolution block: hidden1 -> hidden2 channels.
        self.conv2 = nn.Conv1d(
            in_channels=self.hidden1,
            out_channels=self.hidden2,
            kernel_size=4,
            stride=4,
            padding=1,
        )
        self.bn2 = nn.BatchNorm1d(self.hidden2)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

        # Global pooling layers; constructed but not called in forward().
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        # Stacked LSTMs operating on (batch, seq, channels) tensors.
        self.lstm1 = nn.LSTM(
            input_size=self.hidden2,
            hidden_size=self.hidden1,
            batch_first=True,
        )
        self.dropout1 = nn.Dropout(self.dropout_ratio)
        self.lstm2 = nn.LSTM(
            input_size=self.hidden1,
            hidden_size=self.hidden2,
            batch_first=True,
        )
        self.dropout2 = nn.Dropout(self.dropout_ratio)

        # Fully connected head over the two concatenated node embeddings.
        self.fc = nn.Linear(self.hidden2 * 2, self.num_classes)

    @staticmethod
    def _as_index(values, device) -> torch.Tensor:
        """Return ``values`` as a ``torch.long`` index tensor on ``device``."""
        if not isinstance(values, torch.Tensor):
            # Materialise arbitrary iterables (tuples, generators, ...) first.
            # Tensors are passed through directly to avoid a per-element
            # tensor -> Python -> tensor round trip.
            values = list(values)
        return torch.as_tensor(values, dtype=torch.long, device=device)

    def forward(self, data_o, idx):
        """Compute class logits for the requested node pairs.

        Args:
            data_o: Object exposing node features as ``data_o.x``. Because
                ``conv1`` takes one input channel, ``x`` is expected to be 2-D
                ``(num_nodes, feature_len)``; with the fixed kernel/stride
                settings ``feature_len`` must be at least 92 for the second
                pooling layer to produce an output. Other attributes of
                ``data_o`` (e.g. edge indices or types) are not read.
            idx: Indexable whose first two items hold the row indices of the
                first and second node of every pair (tensors or iterables of
                integers of equal length).

        Returns:
            Tensor of shape ``(num_pairs, num_classes)`` with the raw logits.
        """
        x_o = data_o.x
        a_idx = self._as_index(idx[0], x_o.device)
        b_idx = self._as_index(idx[1], x_o.device)

        # Every node row becomes a single-channel signal: (N, L) -> (N, 1, L).
        x = x_o.unsqueeze(1)

        # First convolution block (activation is applied before batch norm).
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.pool1(x)

        # Second convolution block.
        x = self.conv2(x)
        x = self.relu(x)
        x = self.bn2(x)
        x = self.pool2(x)

        # (N, channels, seq) -> (N, seq, channels) for the batch_first LSTMs.
        x = x.permute(0, 2, 1)
        lstm_out, _ = self.lstm1(x)
        lstm_out = self.dropout1(lstm_out)
        lstm_out, _ = self.lstm2(lstm_out)
        lstm_out = self.dropout2(lstm_out)
        # Keep only the output at the last sequence position as the embedding.
        lstm_out = lstm_out[:, -1, :]

        # Gather both endpoints of each pair and classify the concatenation.
        pair_features = torch.cat([lstm_out[a_idx], lstm_out[b_idx]], dim=1)
        return self.fc(pair_features)