sketchgraphs_models.training
Utilities for model training.
This module implements the main training harness, which is shared by both the graph model and the autoconstrain model.
Classes
- Named tuple holding the configuration for training a given model.
- Class implementing the main training loop.
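Only the descriptions of the two classes survive above, so the following is a rough, pure-Python illustration of how a named-tuple configuration and a harness class that owns the training loop typically fit together. The names `ExampleTrainingConfig` and `ExampleTrainingHarness`, their fields, and the `step_fn` callback are all hypothetical and are not this module's actual API.

```python
from typing import Callable, Iterator, List, NamedTuple


class ExampleTrainingConfig(NamedTuple):
    """Hypothetical immutable configuration record (field names are illustrative)."""
    batch_size: int
    num_epochs: int
    learning_rate: float


class ExampleTrainingHarness:
    """Hypothetical harness: holds the config and drives the epoch/batch loop."""

    def __init__(self, config: ExampleTrainingConfig,
                 step_fn: Callable[[list, float], float]):
        self.config = config
        # step_fn performs one optimisation step on a batch and returns its loss.
        self.step_fn = step_fn

    def _batches(self, data: list) -> Iterator[list]:
        # Split the data into consecutive batches; the last one may be smaller.
        size = self.config.batch_size
        for start in range(0, len(data), size):
            yield data[start:start + size]

    def train(self, data: list) -> List[float]:
        """Run every epoch and return the mean batch loss of each epoch."""
        if not data:
            raise ValueError('cannot train on an empty dataset')
        epoch_losses = []
        for _ in range(self.config.num_epochs):
            losses = [self.step_fn(batch, self.config.learning_rate)
                      for batch in self._batches(data)]
            epoch_losses.append(sum(losses) / len(losses))
        return epoch_losses
```

Keeping the configuration in a named tuple makes it immutable and easy to log or compare, while the harness class isolates the loop itself from model-specific logic, which is one plausible way a single harness could serve several models.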
Functions
- sketchgraphs_models.training.load_cuda_async(batch, device=None)
Loads a structured batch recursively onto the given torch device.
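A minimal sketch of what recursively loading a structured batch can look like, assuming the batch is built from dicts, lists, tuples and named tuples whose leaves are tensors. The function name `load_structure_to_device`, the `non_blocking` parameter, and the choice of `'cuda'` when `device` is `None` are assumptions for illustration, not the documented behaviour of `load_cuda_async`. To stay runnable without PyTorch, the sketch duck-types leaves on their `.to` method, which `torch.Tensor.to(device, non_blocking=...)` provides.

```python
def load_structure_to_device(batch, device=None, non_blocking=True):
    """Recursively copy every tensor-like leaf of ``batch`` onto ``device``.

    Hypothetical sketch: any leaf exposing ``.to(device, non_blocking=...)``
    (for example ``torch.Tensor``) is moved; other leaves are returned as-is.
    """
    if device is None:
        # Assumption: with no device given, target the current CUDA device.
        device = 'cuda'
    if isinstance(batch, dict):
        return {key: load_structure_to_device(value, device, non_blocking)
                for key, value in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        # Named tuples are rebuilt positionally so they keep their type.
        return type(batch)(*(load_structure_to_device(v, device, non_blocking)
                             for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(load_structure_to_device(v, device, non_blocking)
                           for v in batch)
    if hasattr(batch, 'to'):
        # With torch, non_blocking=True only lets the copy overlap host work
        # when the source tensor lives in pinned (page-locked) memory.
        return batch.to(device, non_blocking=non_blocking)
    return batch
```

With PyTorch installed, calling this on a dict of tensors created with `.pin_memory()` and passing `torch.device('cuda')` would return a dict of the same shape whose tensors are on the GPU.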
- sketchgraphs_models.training.map_structure_flat(structure, function, scalar_types=None)
Maps a function over the leaves of an arbitrary nested structure, preserving the structure.
- Parameters
structure (object) – An arbitrary nested structure
function (function) – A function to apply to each leaf of the structure
scalar_types (Tuple, optional) – If not None, a tuple of types treated as scalars (leaves); the function is applied directly to values of these types.
- Returns
A structure of the same shape, with the function applied to each leaf.
- Raises
ValueError – If an element's type is neither a scalar type nor decomposable.
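A minimal sketch of the mapping behaviour described above, assuming "decomposable" means a dict, list, tuple or named tuple. The name `map_structure_flat_sketch` and the default tuple `DEFAULT_SCALAR_TYPES` used when `scalar_types` is `None` are hypothetical; the actual default in `sketchgraphs_models.training` is not stated here.

```python
# Hypothetical default used when scalar_types is None (an assumption).
DEFAULT_SCALAR_TYPES = (int, float, complex, str, bytes, type(None))


def map_structure_flat_sketch(structure, function, scalar_types=None):
    """Apply ``function`` to every leaf of ``structure``, keeping its shape."""
    if scalar_types is None:
        scalar_types = DEFAULT_SCALAR_TYPES
    if isinstance(structure, scalar_types):
        # Leaf: apply the function directly.
        return function(structure)
    if isinstance(structure, dict):
        # Mappings are rebuilt as plain dicts with the same keys.
        return {key: map_structure_flat_sketch(value, function, scalar_types)
                for key, value in structure.items()}
    if isinstance(structure, tuple) and hasattr(structure, '_fields'):
        # Named tuples need positional construction to keep their type.
        return type(structure)(*(map_structure_flat_sketch(v, function, scalar_types)
                                 for v in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure_flat_sketch(v, function, scalar_types)
                               for v in structure)
    raise ValueError(
        f'Cannot map over {type(structure).__name__}: '
        'it is neither a scalar type nor decomposable')
```

Checking `scalar_types` first lets a caller mark a container-like type (for example a tensor class) as a leaf so that it is passed to `function` whole instead of being decomposed.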