sketchgraphs_models.graph.sample.GraphSamplingModel

class sketchgraphs_models.graph.sample.GraphSamplingModel(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner, feature_dimensions)
__init__(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner, feature_dimensions)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

static from_numerical_model(model: sketchgraphs_models.graph.model.GraphModel, feature_dimensions)

Creates a new sampling model from the given numerical model.

sample_edge_label(graph, targets, generator=None)

Samples edge labels.

This function samples an edge label type for each given edge between the last node in the graph and its specified target node.

Parameters
  • graph (GraphInfo) – The graph (or batch of graphs) for which to obtain edge labels

  • targets (torch.Tensor) – A tensor indicating the index of the target node of each edge to label

  • generator (torch.Generator, optional) – Optional PRNG to use for sampling

Returns

An integer tensor containing the sampled edge labels.

Return type

torch.Tensor
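The edge-label step amounts to drawing one categorical sample per requested edge. The framework-free sketch below assumes the model produces unnormalized scores (logits) over edge label types for each (last node, target) pair. The helper `sample_labels_from_logits` and the logit values are hypothetical, plain lists stand in for `torch.Tensor`, and a seeded `random.Random` plays the role of the optional `generator`.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_labels_from_logits(
    logits: Sequence[Sequence[float]],
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Draw one categorical sample per row of logits (hypothetical helper).

    Each row holds unnormalized log-scores over edge label types for one
    (last node, target) edge. A seeded rng makes the draws reproducible,
    analogous to passing a seeded torch.Generator as `generator`.
    """
    rng = rng or random.Random()
    labels = []
    for row in logits:
        # Softmax numerators, shifted by the row max for numerical stability.
        m = max(row)
        weights = [math.exp(x - m) for x in row]
        # random.choices normalizes the weights internally.
        labels.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return labels


# One row per requested edge; three hypothetical edge label types.
edge_logits = [[2.0, 0.1, -1.0], [-5.0, 3.0, 0.0]]
first = sample_labels_from_logits(edge_logits, random.Random(0))
second = sample_labels_from_logits(edge_logits, random.Random(0))
assert first == second  # same seed, same samples
```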

sample_edge_target(graph, generator=None)

Samples edge targets.

This function samples, for each graph, the target node of a potential edge from the last node. The sampled target may be one past the index of the last node, indicating that no new edge should be created.

Parameters
  • graph (GraphInfo) – The graph (or batch of graphs) for which to obtain edge targets.

  • generator (torch.Generator, optional) – Optional PRNG to use for sampling

Returns

An integer tensor containing the sampled edge target for each graph in the batch.

Return type

torch.Tensor
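Because of the "one past the end" convention, a caller has to treat a sampled target equal to the node count as a stop signal rather than as a node index. A minimal sketch of that check, assuming the per-graph node counts are available; `interpret_edge_targets` is a hypothetical helper and plain lists stand in for tensors.

```python
from typing import List, Optional, Sequence


def interpret_edge_targets(
    targets: Sequence[int], num_nodes: Sequence[int]
) -> List[Optional[int]]:
    """Map sampled edge targets to node indices, or None for "no new edge".

    targets[i] is the sampled target for graph i of the batch and
    num_nodes[i] is that graph's node count. A target equal to
    num_nodes[i] (one past the last valid index) is the stop sentinel.
    """
    result: List[Optional[int]] = []
    for target, n in zip(targets, num_nodes):
        if not 0 <= target <= n:
            raise ValueError(f"target {target} out of range for {n} nodes")
        result.append(None if target == n else target)
    return result


# Graph 0 has 4 nodes and links its last node to node 1;
# graph 1 has 3 nodes and sampled the sentinel 3, so it adds no edge.
print(interpret_edge_targets([1, 3], [4, 3]))  # [1, None]
```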

sample_entity_features(graph, target_type, features, generator=None)

Samples entity features.

Parameters
  • graph (GraphInfo) – The graph (or batch of graphs) for which to obtain node features

  • target_type (TargetType) – The target type / label of the entity for which to sample features

  • features (torch.Tensor) – An integer tensor corresponding to the entity features

  • generator (torch.Generator, optional) – Optional PRNG to use for sampling

Returns

An integer tensor containing the sampled node features.

Return type

torch.Tensor

sample_node_label(graph, generator=None)

Samples node labels.

This function samples a node label for the last node in each graph.

Parameters
  • graph (GraphInfo) – The graph (or batch of graphs) for which to obtain node labels.

  • generator (torch.Generator, optional) – Optional PRNG to use for sampling.

Returns

An integer tensor containing a sampled node label for each graph in the batch.

Return type

torch.Tensor
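The node-label, edge-target, and edge-label samplers appear designed to be called in sequence when growing a graph one node at a time. The sketch below is one plausible driver loop, written under assumptions the docstrings above do not confirm: a hypothetical `STOP_LABEL` ends generation, the new node is appended as a placeholder before its label is sampled (matching "the last node in each graph"), and a hypothetical `max_edges_per_node` cap bounds the edge loop. `ToyGraph` and `StubSampler` are illustrative stand-ins for `GraphInfo` and `GraphSamplingModel`; they draw uniformly at random, handle one graph rather than a batch, and skip `sample_entity_features`.

```python
import random
from dataclasses import dataclass, field
from typing import List, Tuple

STOP_LABEL = 0  # hypothetical node label meaning "end of graph"


@dataclass
class ToyGraph:
    """Hypothetical stand-in for GraphInfo: node labels plus labeled edges."""

    node_labels: List[int] = field(default_factory=list)
    edges: List[Tuple[int, int, int]] = field(default_factory=list)  # (src, dst, label)


class StubSampler:
    """Hypothetical sampler mimicking the call order of GraphSamplingModel.

    Draws uniformly at random instead of querying a trained model.
    """

    def __init__(self, rng: random.Random, num_node_labels: int = 4, num_edge_labels: int = 3):
        self.rng = rng
        self.num_node_labels = num_node_labels
        self.num_edge_labels = num_edge_labels

    def sample_node_label(self, graph: ToyGraph) -> int:
        # Label for the last (placeholder) node of the graph.
        return self.rng.randrange(self.num_node_labels)

    def sample_edge_target(self, graph: ToyGraph) -> int:
        # Any existing node index, or len(nodes) as the "no new edge" sentinel.
        return self.rng.randrange(len(graph.node_labels) + 1)

    def sample_edge_label(self, graph: ToyGraph, target: int) -> int:
        # Label for the edge between the last node and `target`.
        return self.rng.randrange(self.num_edge_labels)


def generate(sampler: StubSampler, max_nodes: int = 10, max_edges_per_node: int = 5) -> ToyGraph:
    """Grow a graph one node at a time: node label, then edges back to earlier nodes."""
    graph = ToyGraph()
    while len(graph.node_labels) < max_nodes:
        graph.node_labels.append(-1)  # placeholder for the new last node
        label = sampler.sample_node_label(graph)
        if label == STOP_LABEL:
            graph.node_labels.pop()  # discard the placeholder and stop
            break
        graph.node_labels[-1] = label
        last = len(graph.node_labels) - 1
        # Hypothetical cap on edges per node, guarding against long loops.
        for _ in range(max_edges_per_node):
            target = sampler.sample_edge_target(graph)
            if target == len(graph.node_labels):  # one past the end: no more edges
                break
            graph.edges.append((last, target, sampler.sample_edge_label(graph, target)))
    return graph


toy = generate(StubSampler(random.Random(0)))
print(toy.node_labels, toy.edges)
```

Passing the same seeded `random.Random` reproduces the same graph, which mirrors the purpose of the optional `generator` argument on each method.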