sketchgraphs_models.nn.MessagePassingNetwork
class sketchgraphs_models.nn.MessagePassingNetwork(depth, message_aggregation_network, transform_edge_messages=None)
Custom configurable message-passing network.
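This class implements the main plumbing of a message-passing network (the repeated propagate-and-update loop), but exposes configurable points, namely the message aggregation network and the optional edge-message transform, so that different variants of the network are easy to create.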
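To make the plumbing concrete, here is a minimal pure-Python sketch of the loop such a network might run. It is not the library's implementation: the class name `MessagePassingSketch`, the use of one float per node instead of a `torch.Tensor`, and the convention that edge `e` sends a message from node `incidence[0][e]` to node `incidence[1][e]` are all assumptions made for illustration. The two configurable points mirror the constructor arguments.

```python
from typing import Any, Callable, List, Optional, Sequence

# Hypothetical stand-ins for the configurable modules (plain callables here).
UpdateFn = Callable[[List[float], List[float]], List[float]]
EdgeFn = Callable[..., List[float]]


class MessagePassingSketch:
    """Illustrative message-passing loop; NOT the sketchgraphs implementation."""

    def __init__(self, depth: int, message_aggregation_network: UpdateFn,
                 transform_edge_messages: Optional[EdgeFn] = None) -> None:
        self.depth = depth
        self.message_aggregation_network = message_aggregation_network
        self.transform_edge_messages = transform_edge_messages

    def forward(self, node_embedding: Sequence[float],
                incidence: Sequence[Sequence[int]],
                edge_transform_args: Optional[tuple] = None) -> List[float]:
        embedding = list(node_embedding)
        sources, targets = incidence  # assumed layout: [2, num_edges]
        extra: tuple = edge_transform_args if edge_transform_args is not None else ()
        for _ in range(self.depth):
            # One message per edge: the current embedding of the edge's source node.
            edge_messages = [embedding[s] for s in sources]
            if self.transform_edge_messages is not None:
                # Assumed: the extra-argument tuple is unpacked into the transform.
                edge_messages = self.transform_edge_messages(edge_messages, *extra)
            # Sum all messages arriving at each target node.
            summed = [0.0] * len(embedding)
            for t, m in zip(targets, edge_messages):
                summed[t] += m
            # The aggregation network sees (summed messages, previous embeddings).
            embedding = self.message_aggregation_network(summed, embedding)
        return embedding


# Path graph 0-1-2, with each undirected edge listed in both directions.
net = MessagePassingSketch(
    depth=2,
    message_aggregation_network=lambda msg, prev: [p + m for p, m in zip(prev, msg)],
)
result = net.forward([1.0, 0.0, 0.0], [[0, 1, 1, 2], [1, 0, 2, 1]])
print(result)  # → [2.0, 2.0, 1.0]
```

With this additive update, node 2 picks up the value that started at node 0, two hops away, only after the second iteration; this is one way to see what `depth` controls.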
__init__(depth, message_aggregation_network, transform_edge_messages=None)
Creates a new module representing the message-passing network.
- Parameters
depth (int) – Number of message-passing iterations to execute.
message_aggregation_network – A module representing the model used to compute the node embeddings for the next step. This model receives the array of messages (the sum of the propagated messages) and the array of previous node embeddings.
transform_edge_messages (torch.nn.Module, optional) – A module representing the model used to transform edge messages at each step. See aggregate_by_incidence.
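Because both modules are only ever called, any callable honoring the same contract could be plugged in. The sketch below shows two hypothetical examples on vector-valued embeddings, written as plain lists of floats rather than `torch.Tensor` objects: `gated_update` plays the role of `message_aggregation_network` and `weight_edge_messages` plays the role of `transform_edge_messages`. The names, the mixing rule, and the assumption that the extra-argument tuple is unpacked into the edge transform are all illustrative, not taken from the library.

```python
from typing import List, Sequence

Vector = List[float]


def gated_update(messages: Sequence[Vector], previous: Sequence[Vector],
                 alpha: float = 0.5) -> List[Vector]:
    """Hypothetical aggregation network: convex mix of previous embedding and summed messages."""
    return [[(1.0 - alpha) * p + alpha * m for p, m in zip(prev_row, msg_row)]
            for prev_row, msg_row in zip(previous, messages)]


def weight_edge_messages(messages: Sequence[Vector], weights: Sequence[float]) -> List[Vector]:
    """Hypothetical edge transform: scale the message on edge e by weights[e]."""
    return [[w * x for x in row] for row, w in zip(messages, weights)]


# Two nodes with 2-dimensional embeddings and a single edge 0 -> 1.
previous = [[2.0, 4.0], [0.0, 0.0]]
edge_transform_args = ([0.5],)          # assumed: a tuple unpacked into the edge transform
raw_messages = [previous[0]]            # the only edge carries node 0's embedding
edge_messages = weight_edge_messages(raw_messages, *edge_transform_args)
summed = [[0.0, 0.0], edge_messages[0]]  # node 1 receives the only message; node 0 none
updated = gated_update(summed, previous)
print(updated)  # → [[1.0, 2.0], [0.5, 1.0]]
```

A convex mix like this keeps embeddings on the same scale across iterations; any other rule that maps (messages, previous embeddings) to new embeddings of matching shape would fit the same slot.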
forward(node_embedding, incidence, edge_transform_args=None)
Forward function for the message-passing network.
- Parameters
node_embedding (torch.Tensor) – Tensor of shape [num_nodes, ...] representing the data at each node in the graph.
incidence (torch.Tensor) – Tensor of shape [2, num_edges] representing edge incidence in the graph.
edge_transform_args (any) – A tuple of further arguments to be passed to the edge transformation network.
- Returns
The final node embedding values after the message passing has been carried out.
- Return type
torch.Tensor
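The [2, num_edges] incidence layout may be easier to see on a tiny graph. This pure-Python sketch builds such an array from a list of edges and validates it against the shapes `forward` expects. The helper names `make_incidence` and `check_forward_inputs` are hypothetical, and both the convention that row 0 holds source nodes and row 1 target nodes and the choice to encode an undirected edge as two directed ones are assumptions, not something this page specifies.

```python
from typing import List, Sequence, Tuple


def make_incidence(edges: Sequence[Tuple[int, int]], undirected: bool = True) -> List[List[int]]:
    """Build a hypothetical [2, num_edges] incidence array from (source, target) pairs."""
    sources: List[int] = []
    targets: List[int] = []
    for s, t in edges:
        sources.append(s)
        targets.append(t)
        if undirected:
            # Add the reverse direction so messages flow both ways (assumed convention).
            sources.append(t)
            targets.append(s)
    return [sources, targets]


def check_forward_inputs(num_nodes: int, incidence: Sequence[Sequence[int]]) -> int:
    """Validate the incidence shape described for forward() and return num_edges."""
    if len(incidence) != 2:
        raise ValueError("incidence must have shape [2, num_edges]")
    if len(incidence[0]) != len(incidence[1]):
        raise ValueError("both incidence rows must have length num_edges")
    for node in (*incidence[0], *incidence[1]):
        if not 0 <= node < num_nodes:
            raise ValueError(f"node index {node} out of range for {num_nodes} nodes")
    return len(incidence[0])


# Path graph 0-1-2 with three nodes.
incidence = make_incidence([(0, 1), (1, 2)])
num_edges = check_forward_inputs(num_nodes=3, incidence=incidence)
print(incidence, num_edges)  # → [[0, 1, 1, 2], [1, 0, 2, 1]] 4
```

With a real `torch.Tensor`, the same checks would correspond to `incidence.shape[0] == 2` and node indices lying below `node_embedding.shape[0]`.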