sketchgraphs_models.nn.functional
Utility functions for computing specialized neural-network functions on segmented inputs.
Functions
sketchgraphs_models.nn.functional.segmented_cross_entropy(logits: torch.Tensor, target: torch.Tensor, scopes: torch.Tensor) → torch.Tensor
Segmented cross-entropy loss.
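Computes the cross-entropy loss from unscaled logits for a segmented problem, in which logits is partitioned into the contiguous segments described by scopes and each segment is scored against its own true label.
- Parameters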
logits (torch.Tensor) – unscaled logits by segment
target (torch.Tensor) – tensor of length n_segments, representing the index of the true label for each segment.
scopes (torch.Tensor) – tensor of shape [n_segments, 2], representing the segments as (start, length).
- Returns
A tensor of length n_segments representing the cross-entropy loss at each segment.
- Return type
torch.Tensor
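To make the segmented semantics concrete, here is a minimal loop-based sketch. It is not the library's implementation: segmented_cross_entropy_reference and scopes_from_lengths are hypothetical names, and the sketch assumes logits is 1-D and that each target entry is an index relative to the start of its own segment.

```python
import torch


def scopes_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build the [n_segments, 2] tensor of (start, length) rows."""
    starts = torch.cumsum(lengths, dim=0) - lengths  # exclusive prefix sum of lengths
    return torch.stack([starts, lengths], dim=1)


def segmented_cross_entropy_reference(logits: torch.Tensor, target: torch.Tensor,
                                      scopes: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based sketch of segmented cross-entropy.

    Assumes a 1-D logits tensor and target indices relative to each segment's start.
    """
    losses = []
    for (start, length), t in zip(scopes.tolist(), target.tolist()):
        segment = logits[start:start + length]
        # Per-segment cross-entropy: logsumexp(segment) - segment[true label].
        losses.append(torch.logsumexp(segment, dim=0) - segment[t])
    if not losses:
        return logits.new_empty((0,))
    return torch.stack(losses)


# Two segments: logits[0:3] and logits[3:5].
logits = torch.tensor([1.0, 2.0, 3.0, 0.5, 0.5])
scopes = scopes_from_lengths(torch.tensor([3, 2]))  # [[0, 3], [3, 2]]
target = torch.tensor([2, 0])
loss = segmented_cross_entropy_reference(logits, target, scopes)
# loss[1] equals log(2): the second segment has two equal logits.
```

Each entry matches ordinary cross-entropy applied to one segment on its own. A vectorized version would avoid the Python loop, but the loop keeps the per-segment definition explicit.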
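sketchgraphs_models.nn.functional.segmented_multinomial(logits, scopes, generator=None)
Segmented multinomial sample: draws one sample per segment from the categorical distribution defined by that segment's unscaled logits.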
- Parameters
logits (torch.Tensor) – unscaled logits by segment
scopes (torch.Tensor) – tensor of shape [n_segments, 2], representing the segments as (start, length).
generator (torch.Generator, optional) – PRNG for sampling
- Returns
A tensor of length n_segments representing the sampled values.
- Return type
torch.Tensor
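A simple way to realize this behavior is to call torch.multinomial on the softmax of each segment's logits. The sketch below does that in a loop. It is a hypothetical reference (segmented_multinomial_reference is not a library function), and it assumes the returned values are indices relative to the start of each segment, which the documentation above does not state.

```python
from typing import Optional

import torch


def segmented_multinomial_reference(logits: torch.Tensor, scopes: torch.Tensor,
                                    generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical loop-based sketch: one categorical draw per segment.

    Assumption: the result holds indices relative to each segment's start.
    """
    samples = []
    for start, length in scopes.tolist():
        # Convert the segment's unscaled logits into probabilities.
        probs = torch.softmax(logits[start:start + length], dim=0)
        samples.append(torch.multinomial(probs, num_samples=1, generator=generator))
    if not samples:
        return torch.empty(0, dtype=torch.long, device=logits.device)
    return torch.cat(samples)


logits = torch.tensor([0.0, 0.0, 5.0, -1.0, 1.0])
scopes = torch.tensor([[0, 3], [3, 2]])
sample = segmented_multinomial_reference(logits, scopes, generator=torch.Generator().manual_seed(0))
```

Because every draw goes through the supplied generator, reseeding it reproduces the same samples.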
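sketchgraphs_models.nn.functional.segmented_multinomial_extended(logits, scopes, generator=None, also_return_probs=False)
Segmented multinomial sample with an implicit element: in addition to the explicit outcomes given by logits, each segment has one extra outcome that is not represented in logits.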
- Parameters
logits (torch.Tensor) – logits for explicit outcomes by segment
scopes (torch.Tensor) – tensor of shape [n_segments, 2], representing the segments as (start, length).
generator (torch.Generator, optional) – PRNG for sampling
also_return_probs (bool, optional) – if true, returns a tuple that also includes the log-likelihood of the sample.
- Returns
A tensor of length n_segments representing the sampled values.
- Return type
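The sketch below illustrates one possible reading of the implicit element. It is hypothetical (segmented_multinomial_extended_reference is not a library function) and rests on assumptions the documentation does not confirm: the implicit outcome has a fixed logit of 0, it is reported as index length within its segment, and the log-likelihood is returned per segment.

```python
from typing import Optional

import torch


def segmented_multinomial_extended_reference(logits: torch.Tensor, scopes: torch.Tensor,
                                             generator: Optional[torch.Generator] = None,
                                             also_return_probs: bool = False):
    """Hypothetical loop-based sketch with one implicit outcome per segment.

    Assumptions: the implicit outcome has logit 0 and is reported as index
    `length` of its segment; log-likelihoods are returned per segment.
    """
    samples, log_likelihoods = [], []
    for start, length in scopes.tolist():
        segment = logits[start:start + length]
        # Append the assumed implicit logit (0) after the explicit ones.
        extended = torch.cat([segment, segment.new_zeros(1)])
        log_probs = torch.log_softmax(extended, dim=0)
        idx = torch.multinomial(log_probs.exp(), num_samples=1, generator=generator)
        samples.append(idx)
        log_likelihoods.append(log_probs[idx])  # log-probability of the drawn outcome
    if not samples:
        sample = torch.empty(0, dtype=torch.long, device=logits.device)
        log_lik = logits.new_empty((0,))
    else:
        sample, log_lik = torch.cat(samples), torch.cat(log_likelihoods)
    if also_return_probs:
        return sample, log_lik
    return sample


# Segment 0 strongly prefers the implicit outcome, segment 1 its explicit one,
# and segment 2 is empty, so only the implicit outcome is available there.
logits = torch.tensor([-1e9, -1e9, 1e9])
scopes = torch.tensor([[0, 2], [2, 1], [3, 0]])
s, ll = segmented_multinomial_extended_reference(logits, scopes, also_return_probs=True)
```

Under these assumptions, the implicit outcome also gives empty segments a well-defined sample.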