sketchgraphs_models.nn

This module provides utilities and generic building blocks for graph neural networks.

Classes

ConcatenateLinear(left_size, right_size, …)

A torch module which concatenates several inputs and mixes them using a linear layer.
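As an illustration only, here is a minimal sketch of the concatenate-then-mix idea for two inputs. The class name ConcatenateLinearSketch and the output_size parameter are hypothetical, because the documented signature elides everything after right_size.

```python
import torch
from torch import nn


class ConcatenateLinearSketch(nn.Module):
    """Hypothetical sketch: concatenate two inputs, then mix them with one linear layer.

    `output_size` is an assumed parameter; the real signature elides the rest.
    """

    def __init__(self, left_size: int, right_size: int, output_size: int):
        super().__init__()
        self.linear = nn.Linear(left_size + right_size, output_size)

    def forward(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
        # Join the features along the last dimension, then apply a single affine map.
        return self.linear(torch.cat([left, right], dim=-1))


layer = ConcatenateLinearSketch(left_size=3, right_size=5, output_size=4)
out = layer(torch.randn(2, 3), torch.randn(2, 5))  # shape [2, 4]
```

Concatenating and then applying one linear layer is equivalent to applying a separate weight block to each input and summing the results, so the layer can mix information from both inputs without an explicit pairwise interaction term.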

MessagePassingNetwork(depth, …[, …])

Custom configurable message-passing network.
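For readers unfamiliar with message passing, the following is a generic sketch of what `depth` rounds of message passing over an incidence list can look like. It is not the configuration used by MessagePassingNetwork: only `depth` comes from the documented signature, while the class name, `hidden_size`, the shared weights, and the concatenate-then-linear update rule are illustrative assumptions.

```python
import torch
from torch import nn


class MessagePassingSketch(nn.Module):
    """Hypothetical sketch: `depth` rounds of sum-aggregated message passing.

    Only `depth` comes from the documented signature; `hidden_size`, the
    weight sharing across rounds, and the update rule are assumptions.
    """

    def __init__(self, depth: int, hidden_size: int):
        super().__init__()
        self.depth = depth
        self.message = nn.Linear(hidden_size, hidden_size)
        self.update = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, node_features: torch.Tensor, incidence: torch.Tensor) -> torch.Tensor:
        # incidence: [2, k] index tensor; row 0 = receiving node, row 1 = sending node.
        h = node_features
        for _ in range(self.depth):
            # Compute a message from every sending node, one per edge.
            messages = self.message(h)[incidence[1]]
            # Sum the incoming messages at each receiving node.
            aggregated = torch.zeros_like(h).index_add_(0, incidence[0], messages)
            # Combine each node's state with its aggregated messages.
            h = torch.relu(self.update(torch.cat([h, aggregated], dim=-1)))
        return h
```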

Sequential(*args)

Similar to torch.nn.Sequential, except that it can pass values through modules that take multiple input arguments and return tuples.
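One plausible way to get this behaviour is to unpack a tuple returned by one module into the positional arguments of the next. The sketch below assumes that convention; TupleSequential, Split and Add are hypothetical names, not part of this module.

```python
import torch
from torch import nn


class TupleSequential(nn.Module):
    """Hypothetical sketch of a Sequential that threads tuples through its modules.

    If a module returns a tuple, its elements become separate positional
    arguments of the next module; otherwise the single output is passed on.
    """

    def __init__(self, *modules: nn.Module):
        super().__init__()
        self.layers = nn.ModuleList(modules)

    def forward(self, *args):
        for layer in self.layers:
            out = layer(*args)
            args = out if isinstance(out, tuple) else (out,)
        # Unwrap a single output so the result matches torch.nn.Sequential.
        return args[0] if len(args) == 1 else args


class Split(nn.Module):
    # Example module with one input and a tuple of two outputs.
    def forward(self, x):
        return x, 2 * x


class Add(nn.Module):
    # Example module that takes two input arguments.
    def forward(self, a, b):
        return a + b


seq = TupleSequential(Split(), Add())
result = seq(torch.ones(3))  # tensor([3., 3., 3.])
```

With plain torch.nn.Sequential, the Add module would receive the tuple as a single argument and fail; the unpacking step is what lets multi-input modules sit anywhere in the chain.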

Functions

sketchgraphs_models.nn.aggregate_by_incidence(values: torch.Tensor, incidence: torch.Tensor, transform_edge_messages=None, transform_edge_messages_args=None, output_size=None)

Aggregates values according to an incidence matrix.

Effectively computes the following operation:

output[i] = values[incidence[1, incidence[0] == i]].sum(axis=0)

This operation is essentially a naive implementation of multiplying values by a sparse matrix stored in COO format. Optimization opportunity: implement it with cuSPARSE.

Parameters
  • values (torch.Tensor) – A tensor of rank at least 2

  • incidence (torch.Tensor) – A [2, k] tensor of indices: row 0 gives the output row each entry is added to, and row 1 gives the row of values it is taken from.

  • transform_edge_messages (function, optional) – An arbitrary function that transforms edge messages.

  • transform_edge_messages_args (any) – Arbitrary set of arguments that are passed to the transform_edge_messages function.

  • output_size (List[int], optional) – If not None, the size of the output tensor. Otherwise, the output tensor is assumed to have the same size as values.

Returns

The output tensor, of the same rank as values.

Return type

torch.Tensor
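A minimal sketch of the documented formula, assuming the semantics given above. It uses torch.Tensor.index_add_ to scatter the gathered rows and omits transform_edge_messages, whose exact call convention is not specified here. The function name is hypothetical.

```python
import torch


def aggregate_by_incidence_sketch(values, incidence, output_size=None):
    """Hypothetical sketch: output[i] = values[incidence[1, incidence[0] == i]].sum(0).

    transform_edge_messages is deliberately left out of this sketch.
    """
    if output_size is None:
        # Default documented above: the output has the same size as values.
        output_size = values.shape
    output = values.new_zeros(output_size)
    # Gather one row of values per incidence column, then add it to its target row.
    output.index_add_(0, incidence[0], values[incidence[1]])
    return output


values = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
incidence = torch.tensor([[0, 0, 2],
                          [1, 2, 0]])
out = aggregate_by_incidence_sketch(values, incidence)
# out[0] = values[1] + values[2], out[1] has no entries, out[2] = values[0]

# Same result as multiplying by a COO matrix with a 1 at each (incidence[0], incidence[1]).
n = values.shape[0]
adjacency = torch.sparse_coo_tensor(incidence, torch.ones(incidence.shape[1]), (n, n))
assert torch.allclose(torch.sparse.mm(adjacency, values), out)
```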

sketchgraphs_models.nn.autograd_range(name)

Creates an autograd range for PyTorch autograd profiling.
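A labelled range can be sketched with torch.autograd.profiler.record_function, which marks a named region in the profiler output. This is one plausible approach and not necessarily how autograd_range is implemented; autograd_range_sketch is a hypothetical name.

```python
import contextlib

import torch


@contextlib.contextmanager
def autograd_range_sketch(name: str):
    """Hypothetical sketch: label a code region in the PyTorch autograd profiler."""
    with torch.autograd.profiler.record_function(name):
        yield


with torch.autograd.profiler.profile() as prof:
    with autograd_range_sketch("encode"):
        total = torch.ones(4).sum()

# prof.key_averages().table() now lists an "encode" row covering the region.
```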

Modules

sketchgraphs_models.nn.data_util

Utilities and extensions to work with the torch.utils.data package.

sketchgraphs_models.nn.distributed

Utility modules for distributed and parallel training.

sketchgraphs_models.nn.functional

Utility functions for computing specific nn functions.

sketchgraphs_models.nn.summary

This module implements utilities to compute summary statistics.