sketchgraphs_models.nn.functional

Utility functions for computing specific neural-network operations, such as losses and sampling over segments.

Functions

sketchgraphs_models.nn.functional.segmented_cross_entropy(logits: torch.Tensor, target: torch.Tensor, scopes: torch.Tensor) → torch.Tensor

Segmented cross-entropy loss.

Computes the cross-entropy loss from unscaled logits for a segmented problem.

Parameters
  • logits (torch.Tensor) – unscaled logits by segment

  • target (torch.Tensor) – tensor of length n_segments, representing the index of the true label for each segment.

  • scopes (torch.Tensor) – tensor of shape [n_segments, 2], representing the segments as (start, length).

Returns

A tensor of length n_segments representing the cross-entropy loss at each segment.

Return type

torch.Tensor
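
To make the segment layout concrete, here is a minimal loop-based sketch of the quantity described above. It is not the library's implementation. It assumes that logits is a flat 1-D tensor holding all segments back to back, and that each entry of target is an index relative to the start of its segment. The name segmented_cross_entropy_reference is hypothetical.

```python
import torch


def segmented_cross_entropy_reference(logits, target, scopes):
    """Hypothetical reference: one cross-entropy value per segment.

    Assumes `logits` is 1-D and `target` holds segment-relative indices.
    """
    losses = []
    for (start, length), label in zip(scopes.tolist(), target.tolist()):
        segment = logits[start:start + length]
        # Cross-entropy from unscaled logits: logsumexp(segment) - segment[label].
        losses.append(torch.logsumexp(segment, dim=0) - segment[label])
    return torch.stack(losses)


# Two segments: logits[0:2] and logits[2:5].
logits = torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0])
scopes = torch.tensor([[0, 2], [2, 3]])
target = torch.tensor([1, 0])
loss = segmented_cross_entropy_reference(logits, target, scopes)
```

Under these assumptions, each entry of loss should match torch.nn.functional.cross_entropy applied to the corresponding segment on its own.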

sketchgraphs_models.nn.functional.segmented_multinomial(logits, scopes, generator=None)

Segmented multinomial sample.

Parameters
  • logits (torch.Tensor) – unscaled logits by segment

  • scopes (torch.Tensor) – tensor of shape [n_segments, 2] representing the segments as (start, length).

  • generator (torch.Generator, optional) – PRNG for sampling

Returns

A tensor of length n_segments representing the sampled values.

Return type

torch.Tensor
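
The sampling behaviour can be sketched with a per-segment loop. This is an illustration only, not the library's implementation. It assumes that logits is a flat 1-D tensor and that each sampled value is an index relative to the start of its segment. The name segmented_multinomial_reference is hypothetical.

```python
import torch


def segmented_multinomial_reference(logits, scopes, generator=None):
    """Hypothetical reference: draw one index per segment.

    Assumes `logits` is 1-D; returned indices are segment-relative.
    """
    samples = []
    for start, length in scopes.tolist():
        # Normalize the segment's unscaled logits into probabilities.
        probs = torch.softmax(logits[start:start + length], dim=0)
        # Draw a single index from the segment's distribution.
        samples.append(torch.multinomial(probs, 1, generator=generator))
    return torch.cat(samples)


# Two segments: logits[0:2] and logits[2:5].
logits = torch.tensor([0.0, 0.0, 1.0, 2.0, 3.0])
scopes = torch.tensor([[0, 2], [2, 3]])
generator = torch.Generator().manual_seed(0)
samples = segmented_multinomial_reference(logits, scopes, generator=generator)
```

Passing a seeded torch.Generator makes the draw reproducible. Leaving it as None falls back to PyTorch's default generator.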

sketchgraphs_models.nn.functional.segmented_multinomial_extended(logits, scopes, generator=None, also_return_probs=False)

Segmented multinomial sample with implicit element.

Parameters
  • logits (torch.Tensor) – logits for explicit outcomes by segment

  • scopes (torch.Tensor) – tensor of shape [n_segments, 2] representing the segments as (start, length).

  • generator (torch.Generator, optional) – PRNG for sampling

  • also_return_probs (bool, optional) – If true, returns a tuple that also includes the log-likelihood of the sample.

Returns

A tensor of length n_segments representing the sampled values. If also_return_probs is true, a tuple that also includes the log-likelihood of the sample.

Return type

torch.Tensor