sketchgraphs_models.training

Utilities for model training.

This module implements the main training harness, which is shared by both the graph and autoconstrain models.

Classes

TrainingConfig(dataloader, tb_writer, …)

Named tuple holding configuration for training a given model.

TrainingHarness(model, opt, config_train, …)

This class implements the main training loop.

Functions

sketchgraphs_models.training.load_cuda_async(batch, device=None)

Loads a structured batch recursively onto the given torch device.
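The recursive loading described above might look roughly like the minimal sketch below. It is not this module's implementation. Several details are assumptions: the name `load_async_sketch`, the container types it handles (dicts, lists, tuples, namedtuples), and the duck-typed check for a `.to` method. That check stands in for an `isinstance(x, torch.Tensor)` test so the sketch runs without PyTorch installed. With a real `torch.Tensor`, the leaf call corresponds to `tensor.to(device, non_blocking=True)`.

```python
from typing import Any


def load_async_sketch(batch: Any, device: Any = None) -> Any:
    """Hypothetical sketch: recursively move every tensor-like leaf of `batch` to `device`."""
    if isinstance(batch, dict):
        # Recurse into mappings, keeping the same keys.
        return {k: load_async_sketch(v, device) for k, v in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        # Namedtuples must be rebuilt positionally to keep their type.
        return type(batch)(*(load_async_sketch(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(load_async_sketch(v, device) for v in batch)
    if hasattr(batch, "to"):
        # Assumption: anything with a `.to` method behaves like torch.Tensor.to.
        return batch.to(device, non_blocking=True)
    # Other leaves (ints, None, strings, ...) are passed through unchanged.
    return batch
```

Passing `non_blocking=True` only lets a host-to-GPU copy overlap with other work when the source tensor is in pinned (page-locked) memory, for example when it comes from a `DataLoader` created with `pin_memory=True`.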

sketchgraphs_models.training.map_structure_flat(structure, function, scalar_types=None)

Utility function for mapping a function over an arbitrary nested structure while preserving the shape of that structure.

Parameters
  • structure (object) – An arbitrary nested structure

  • function (function) – A function to apply to each leaf of the structure

  • scalar_types (Tuple, optional) – If not None, a tuple of types treated as scalars (leaves); the function is applied directly to values of these types instead of recursing into them

Returns

A structure with the same shape as the input, in which each leaf has been replaced by the result of applying the function to it.

Return type

object

Raises

ValueError – Raised if an element is neither of a scalar type nor a decomposable container.
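The behavior documented above can be illustrated with the following hypothetical sketch. It is not the module's code. The name `map_structure_flat_sketch` is invented for illustration, as is the default scalar set used when `scalar_types` is None. The choice of decomposable containers (mappings, lists, tuples, namedtuples) is also an assumption, since the documentation does not list them.

```python
import collections.abc
from typing import Any, Callable, Optional, Tuple

# Assumption: the real default leaf types are not documented; this set is illustrative.
_ASSUMED_DEFAULT_SCALARS = (int, float, bool, str, bytes, type(None))


def map_structure_flat_sketch(
    structure: Any,
    function: Callable[[Any], Any],
    scalar_types: Optional[Tuple[type, ...]] = None,
) -> Any:
    """Hypothetical sketch: apply `function` to every leaf, keeping the nesting."""
    if scalar_types is None:
        scalar_types = _ASSUMED_DEFAULT_SCALARS
    # Scalars are checked first so that e.g. strings are not treated as sequences.
    if isinstance(structure, scalar_types):
        return function(structure)
    recurse = lambda v: map_structure_flat_sketch(v, function, scalar_types)
    if isinstance(structure, tuple) and hasattr(structure, "_fields"):
        # Namedtuple: rebuild positionally so the result keeps its type.
        return type(structure)(*(recurse(v) for v in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(recurse(v) for v in structure)
    if isinstance(structure, collections.abc.Mapping):
        # Mappings are rebuilt as plain dicts with the same keys.
        return {k: recurse(v) for k, v in structure.items()}
    # Neither a scalar nor a decomposable container.
    raise ValueError(f"Unsupported type in structure: {type(structure).__name__}")
```

Checking namedtuples before plain tuples matters here: a namedtuple is also a tuple, and rebuilding it with a generator argument would fail.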