sketchgraphs_models.graph.model.EdgePartnerNetwork

class sketchgraphs_models.graph.model.EdgePartnerNetwork(readout_net)

Predicts, for each node, the probability of a new edge between that node and the last node in the graph.

__init__(readout_net)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(node_embedding, graph_embedding, graph)

Computes, for each node, the logit associated with introducing an edge between that node and the target node of its graph.

Note that the target specified by target_idx[i] is compared against the nodes of the i-th graph.
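The pairing described above can be illustrated with a minimal pure-Python sketch. This is not the library's implementation. It assumes a hypothetical batch layout in which the nodes of all graphs are stored back to back and `node_counts` gives the number of nodes per graph. It also assumes that the target is the last node of each graph and that the per-node features are the concatenation of the candidate node, the target node and the graph embedding. The names `edge_partner_logits`, `node_counts` and `readout` are hypothetical, and `readout` stands in for `readout_net`.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def edge_partner_logits(
    node_embedding: Sequence[Vector],
    graph_embedding: Sequence[Vector],
    node_counts: Sequence[int],
    readout: Callable[[Vector], float],
) -> List[float]:
    """Return one logit per node for an edge to the last node of its graph.

    node_embedding: flat list of node vectors, graphs stored back to back.
    graph_embedding: one vector per graph.
    node_counts: hypothetical per-graph node counts describing the batch layout.
    readout: hypothetical stand-in for readout_net, mapping features to a logit.
    """
    logits: List[float] = []
    offset = 0
    for g, count in enumerate(node_counts):
        if count == 0:
            # An empty graph has no target node and contributes no logits.
            continue
        # Assumption: the target is the last node of graph g.
        target = node_embedding[offset + count - 1]
        for i in range(offset, offset + count):
            # Concatenate candidate node, target node and graph features.
            features = list(node_embedding[i]) + list(target) + list(graph_embedding[g])
            logits.append(readout(features))
        offset += count
    return logits
```

For example, with `readout=sum`, node vectors `[[1.0], [2.0], [3.0], [4.0], [5.0]]`, graph vectors `[[0.0], [10.0]]` and `node_counts=[3, 2]`, the first three nodes are scored against node 3 and the last two against node 5.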

Parameters
  • node_embedding – tensor representing the node embeddings.

  • graph_embedding – tensor representing the graph embedding.

  • graph – object describing the graph structure.

Returns

The log-probability at each node for the corresponding edge

Return type

torch.Tensor
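Because the result is described as a log-probability at each node, the per-node logits have to be normalized somewhere. The following minimal pure-Python sketch shows one plausible way, a log-softmax taken separately over the nodes of each graph. Whether the actual module normalizes per graph in exactly this way is an assumption. The function name `per_graph_log_softmax` and the `node_counts` batch layout (graphs stored back to back) are hypothetical.

```python
import math
from typing import List, Sequence


def per_graph_log_softmax(logits: Sequence[float], node_counts: Sequence[int]) -> List[float]:
    """Normalize logits into log-probabilities separately within each graph.

    node_counts: hypothetical per-graph node counts; logits are stored back to back.
    """
    out: List[float] = []
    offset = 0
    for count in node_counts:
        segment = logits[offset:offset + count]
        if segment:
            # Subtract the segment maximum before exponentiating for numerical stability.
            m = max(segment)
            log_z = m + math.log(sum(math.exp(x - m) for x in segment))
            out.extend(x - log_z for x in segment)
        offset += count
    return out
```

For example, `per_graph_log_softmax(logits, [3, 2])` normalizes the first three logits together and the last two together, so the exponentiated values within each graph sum to 1.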