sketchgraphs_models.graph.sample.GraphSamplingModel
class sketchgraphs_models.graph.sample.GraphSamplingModel(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner, feature_dimensions)
__init__(model_core, entity_label, entity_feature_readout, edge_post_embedding, edge_label, edge_feature_readout, edge_partner, feature_dimensions)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
static from_numerical_model(model: sketchgraphs_models.graph.model.GraphModel, feature_dimensions)
Creates a new sampling model from the given numerical model.
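from_numerical_model is a static factory: given a trained GraphModel, it builds a sampling model. A minimal sketch of this alternate-constructor pattern follows. It assumes (the documentation does not confirm this) that the trained model exposes one attribute per constructor component under the same name, and that the sampler reuses those objects rather than copying them. SamplingModelSketch, COMPONENTS and the SimpleNamespace stand-in are hypothetical names, not part of the library.

```python
from types import SimpleNamespace

# Hypothetical: component names copied from the documented constructor signature.
COMPONENTS = (
    "model_core", "entity_label", "entity_feature_readout",
    "edge_post_embedding", "edge_label", "edge_feature_readout", "edge_partner",
)


class SamplingModelSketch:
    """Toy stand-in for GraphSamplingModel that only stores its components."""

    def __init__(self, *args):
        # Same positional order as the documented constructor:
        # the seven components, then feature_dimensions.
        if len(args) != len(COMPONENTS) + 1:
            raise TypeError(f"expected {len(COMPONENTS) + 1} arguments, got {len(args)}")
        for name, component in zip(COMPONENTS, args[:-1]):
            setattr(self, name, component)
        self.feature_dimensions = args[-1]

    @staticmethod
    def from_numerical_model(model, feature_dimensions):
        # Assumption: the trained model has one attribute per component name,
        # and the sampler reuses those objects instead of copying them.
        return SamplingModelSketch(
            *(getattr(model, name) for name in COMPONENTS), feature_dimensions
        )


# A SimpleNamespace plays the role of a trained GraphModel here.
trained = SimpleNamespace(**{name: object() for name in COMPONENTS})
sampler = SamplingModelSketch.from_numerical_model(trained, feature_dimensions={})
print(sampler.edge_partner is trained.edge_partner)  # True: shared, not copied
```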
sample_edge_label(graph, targets, generator=None)
Samples edge labels.
This function samples the label type of each edge between the last node in the graph and its specified target node.
- Parameters
graph (GraphInfo) – The graph (or batch of graphs) for which to obtain edge labels
targets (torch.Tensor) – A tensor indicating the indices of the target nodes of the edges
generator (torch.Generator, optional) – Optional PRNG to use for sampling
- Returns
An integer tensor containing the sampled edge labels.
- Return type
torch.Tensor
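The core step behind sample_edge_label is drawing one label per edge from a categorical distribution, with the optional generator making the draws reproducible. The sketch below is framework-free: plain lists stand in for tensors, random.Random stands in for torch.Generator, and the per-edge label logits are assumed to have been computed already. sample_categorical and sample_edge_labels are hypothetical helpers, not library functions.

```python
import math
import random


def sample_categorical(logits, rng):
    """Draw one index from the softmax of `logits` using the given PRNG."""
    top = max(logits)
    # Subtracting the max keeps exp() from overflowing; choices() normalizes.
    weights = [math.exp(x - top) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def sample_edge_labels(edge_logits, generator=None):
    """Return one label per edge; edge_logits[i] scores the label types of edge i.

    Hypothetical stand-in: edge i joins the last node of graph i to its target.
    """
    rng = generator if generator is not None else random.Random()
    return [sample_categorical(logits, rng) for logits in edge_logits]


# Three edges (one per graph in a batch of three), four hypothetical label types.
edge_logits = [
    [2.0, 0.1, -1.0, 0.0],
    [-5.0, -5.0, 9.0, -5.0],
    [0.0, 0.0, 0.0, 0.0],
]
first = sample_edge_labels(edge_logits, random.Random(0))
second = sample_edge_labels(edge_logits, random.Random(0))
print(first == second)  # True: the same seed reproduces the same labels
```

Seeding a fresh generator with the same value reproduces the same draws, which is the usual reason to pass an explicit PRNG instead of relying on global random state.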
sample_edge_target(graph, generator=None)
Samples edge targets.
This function samples the target node of a potential new edge from the last node in the graph. The sampled target may be one past the end of the nodes, indicating that no new edges should be created.
- Parameters
graph (GraphInfo) – The graph (or batch of graphs) for which to obtain edge targets.
generator (torch.Generator, optional) – Optional PRNG to use for sampling
- Returns
An integer tensor containing the sampled edge target for each graph in the batch.
- Return type
torch.Tensor
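The "one past the end" convention means that a graph with n nodes has n + 1 possible targets, and index n acts as a stop signal. The following sketch shows how a caller might use it to keep adding edges from the last node until stop is drawn. It assumes a score function supplies the n + 1 logits; sample_new_edges, toy_scores and max_edges are hypothetical, and random.Random stands in for torch.Generator.

```python
import math
import random


def sample_edge_target(target_logits, rng):
    """Draw an index in [0, n]; index n (one past the last node) means "stop".

    target_logits holds n + 1 scores: one per existing node, then the stop score.
    """
    top = max(target_logits)
    weights = [math.exp(x - top) for x in target_logits]  # unnormalized softmax
    return rng.choices(range(len(target_logits)), weights=weights, k=1)[0]


def sample_new_edges(num_nodes, score_fn, rng, max_edges=10):
    """Hypothetical loop: add edges from the last node until "stop" is drawn."""
    targets = []
    for _ in range(max_edges):
        target = sample_edge_target(score_fn(num_nodes, targets), rng)
        if target == num_nodes:  # one past the end: create no more edges
            break
        targets.append(target)
    return targets


def toy_scores(num_nodes, targets_so_far):
    """Hypothetical scores: strongly prefer node 0 first, then strongly prefer stop."""
    if not targets_so_far:
        return [50.0] + [-50.0] * num_nodes
    return [-50.0] * num_nodes + [50.0]


print(sample_new_edges(4, toy_scores, random.Random(0)))  # [0]
```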
sample_entity_features(graph, target_type, features, generator=None)
Samples entity features.
- Parameters
graph (GraphInfo) – The graph (or batch of graphs) for which to obtain node features
target_type (TargetType) – The target type / label of the entity for which to sample features
features (torch.Tensor) – An integer tensor corresponding to the entity features
generator (torch.Generator, optional) – Optional PRNG to use for sampling
- Returns
An integer tensor containing the sampled node features
- Return type
torch.Tensor
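sample_entity_features takes a target_type because different entity types carry different sets of features. As an illustration only, the sketch below assumes a hypothetical per-type schema (EntityType, FEATURE_SIZES) and samples each feature independently from its own logits. It ignores the features argument, whose exact role the documentation does not spell out, and uses random.Random in place of torch.Generator.

```python
import math
import random
from enum import Enum


class EntityType(Enum):
    """Hypothetical entity types, for illustration only."""
    LINE = "line"
    CIRCLE = "circle"


# Hypothetical schema: how many discrete values each feature can take, per type.
FEATURE_SIZES = {
    EntityType.LINE: [4, 4, 2],
    EntityType.CIRCLE: [8, 2],
}


def sample_categorical(logits, rng):
    """Draw one index from the softmax of `logits`."""
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def sample_entity_features(target_type, feature_logits, generator=None):
    """Sample one integer per feature of `target_type`, each from its own logits.

    Features are sampled independently of each other (an assumption of this sketch).
    """
    rng = generator if generator is not None else random.Random()
    sizes = FEATURE_SIZES[target_type]
    if [len(logits) for logits in feature_logits] != sizes:
        raise ValueError(f"logit sizes do not match the schema for {target_type}")
    return [sample_categorical(logits, rng) for logits in feature_logits]


uniform = [[0.0] * size for size in FEATURE_SIZES[EntityType.CIRCLE]]
features = sample_entity_features(EntityType.CIRCLE, uniform, random.Random(3))
print(len(features))  # 2: one value per CIRCLE feature
```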
sample_node_label(graph, generator=None)
Samples node labels.
This function samples a node label for the last node in each graph.
- Parameters
graph (GraphInfo) – The graph (or batch of graphs) for which to obtain node labels.
generator (torch.Generator, optional) – Optional PRNG to use for sampling.
- Returns
An integer tensor containing a sampled node label for each graph in the batch.
- Return type
torch.Tensor
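Taken together, these methods suggest an autoregressive loop that grows each graph one node at a time. The toy loop below shows one plausible ordering: a node label, then edge targets until stop, with an edge label for each target. Uniform random choices stand in for the model's predicted distributions, and entity features are omitted for brevity. The ordering, the end-of-graph label 0 and generate_graph itself are assumptions for illustration, not the library's generation procedure.

```python
import random

# Hypothetical vocabulary sizes for illustration only.
NUM_NODE_LABELS = 3  # label 0 is treated as a hypothetical "end of graph" marker
NUM_EDGE_LABELS = 2


def generate_graph(rng, max_nodes=5, max_edges_per_node=3):
    """Toy autoregressive loop; each step mirrors one documented sampling method.

    Uniform random choices stand in for the model's predicted distributions.
    """
    node_labels, edges = [], []
    for _ in range(max_nodes):
        label = rng.randrange(NUM_NODE_LABELS)  # cf. sample_node_label
        if label == 0:
            break
        node_labels.append(label)
        last = len(node_labels) - 1
        for _ in range(max_edges_per_node):
            target = rng.randrange(len(node_labels) + 1)  # cf. sample_edge_target
            if target == len(node_labels):  # one past the end: stop adding edges
                break
            edge_label = rng.randrange(NUM_EDGE_LABELS)  # cf. sample_edge_label
            edges.append((last, target, edge_label))
    return node_labels, edges


labels, edges = generate_graph(random.Random(7))
print(labels, edges)
```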