sketchgraphs_models.nn.MessagePassingNetwork

class sketchgraphs_models.nn.MessagePassingNetwork(depth, message_aggregation_network, transform_edge_messages=None)

Custom configurable message-passing network.

This class implements the main plumbing for a message-passing network, but exposes configuration points that make it easy to create different variants of the network.
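The control flow can be pictured with the minimal sketch below. It is not the library's implementation. It assumes that each of the `depth` iterations first sums the (optionally transformed) messages arriving at every node, then passes that sum together with the previous node embeddings to the aggregation network. The class `SketchMessagePassing` and the helper `sum_incoming` are hypothetical names. The sketch also assumes that row 0 of `incidence` holds the receiving node and row 1 the sending node; this page does not state the convention.

```python
import torch
from torch import nn


def sum_incoming(node_embedding, incidence, transform=None, transform_args=()):
    # Hypothetical helper: gather one message per edge from the sending node
    # (assumed to be row 1 of `incidence`).
    messages = node_embedding.index_select(0, incidence[1])
    if transform is not None:
        # Optional per-edge transformation, called with any extra arguments.
        messages = transform(messages, *transform_args)
    # Sum the messages into the receiving node (assumed to be row 0).
    # The output is sized from `messages`, since a transform may change the feature size.
    out = torch.zeros((node_embedding.shape[0],) + messages.shape[1:],
                      dtype=messages.dtype, device=messages.device)
    return out.index_add_(0, incidence[0], messages)


class SketchMessagePassing(nn.Module):
    """Hypothetical re-creation of the message-passing loop, for illustration only."""

    def __init__(self, depth, message_aggregation_network, transform_edge_messages=None):
        super().__init__()
        self.depth = depth
        self.message_aggregation_network = message_aggregation_network
        self.transform_edge_messages = transform_edge_messages

    def forward(self, node_embedding, incidence, edge_transform_args=None):
        args = edge_transform_args if edge_transform_args is not None else ()
        for _ in range(self.depth):
            messages = sum_incoming(node_embedding, incidence,
                                    self.transform_edge_messages, args)
            # The aggregation network receives (summed messages, previous embeddings).
            node_embedding = self.message_aggregation_network(messages, node_embedding)
        return node_embedding


# Usage on a 3-node path graph 0-1-2. Each undirected edge is stored in both
# directions, so the receiver/sender convention does not affect this result.
model = SketchMessagePassing(depth=2, message_aggregation_network=lambda m, h: m + h)
x = torch.tensor([[1.0], [2.0], [4.0]])
incidence = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
out = model(x, incidence)  # tensor([[10.], [16.], [13.]])
```

Keeping the update rule and the edge transformation as injected modules is what allows one loop to serve many network variants.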

__init__(depth, message_aggregation_network, transform_edge_messages=None)

Creates a new module representing the message passing network.

Parameters
  • depth (int) – Number of message passing iterations to execute.

  • message_aggregation_network (torch.nn.Module) – A module representing the model used to compute the node embeddings for the next step. It receives the array of messages corresponding to the sum of the propagated messages and the array of previous node embeddings.

  • transform_edge_messages (torch.nn.Module, optional) – A module representing the model used to transform edge messages at each step. See aggregate_by_incidence.
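One plausible choice for message_aggregation_network is a gated recurrent update, as in gated graph neural networks. The sketch below assumes only the calling convention described above: the module is called with the summed messages first and the previous node embeddings second, and it returns the new embeddings. The class name `GRUAggregation` is hypothetical and not part of sketchgraphs_models.

```python
import torch
from torch import nn


class GRUAggregation(nn.Module):
    """Hypothetical aggregation network that updates each node with a GRU cell."""

    def __init__(self, message_size, embedding_size):
        super().__init__()
        self.cell = nn.GRUCell(message_size, embedding_size)

    def forward(self, messages, node_embedding):
        # messages:       [num_nodes, message_size], summed incoming messages
        # node_embedding: [num_nodes, embedding_size], embeddings from the previous step
        return self.cell(messages, node_embedding)


agg = GRUAggregation(message_size=8, embedding_size=16)
new_embedding = agg(torch.randn(5, 8), torch.randn(5, 16))
print(new_embedding.shape)  # torch.Size([5, 16])
```

The output is fed back as the previous embeddings on the next iteration, so the module must preserve the embedding size. If no edge transformation changes the feature size, message_size is simply the node embedding size.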

forward(node_embedding, incidence, edge_transform_args=None)

Forward function for the message passing network.

Parameters
  • node_embedding (torch.Tensor) – Tensor of shape [num_nodes, ...] representing the data at each node in the graph.

  • incidence (torch.Tensor) – Tensor of shape [2, num_edges] representing the edge incidence of the graph.

  • edge_transform_args (any) – Optional tuple of additional arguments passed to the edge transformation network (transform_edge_messages).
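To make the role of edge_transform_args concrete, here is a hypothetical edge-transformation module that conditions each message on an integer label per edge, such as a constraint type. It assumes, without confirmation from this page, that the transformation network is invoked as `transform(messages, *edge_transform_args)` with one message row per edge, in the same column order as incidence. `EdgeTypeTransform` and `edge_types` are illustrative names only.

```python
import torch
from torch import nn


class EdgeTypeTransform(nn.Module):
    """Hypothetical edge transform: adds a learned vector per edge type to each message."""

    def __init__(self, num_edge_types, embedding_size):
        super().__init__()
        self.edge_type_embedding = nn.Embedding(num_edge_types, embedding_size)

    def forward(self, messages, edge_types):
        # messages:   [num_edges, embedding_size], one row per column of `incidence`
        # edge_types: [num_edges] integer labels, supplied through edge_transform_args
        return messages + self.edge_type_embedding(edge_types)


transform = EdgeTypeTransform(num_edge_types=3, embedding_size=4)
messages = torch.zeros(5, 4)
edge_types = torch.tensor([0, 2, 1, 0, 2])
edge_transform_args = (edge_types,)  # a tuple, unpacked into the call below
out = transform(messages, *edge_transform_args)
```

Because the extra arguments travel as a tuple, the same forward signature works whether the edge transform needs zero, one, or several per-edge tensors.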

Returns

The final node embedding values after the message passing has been carried out.

Return type

torch.Tensor