sketchgraphs_models.graph.train.harness.GraphModelHarness

class sketchgraphs_models.graph.train.harness.GraphModelHarness(model, opt, node_feature_dimension, edge_feature_dimension, config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None, profile_enabled=False, additional_model_information=None)

This class is the main harness for training graph models.

The harness is responsible for coordinating all the procedures that surround training, such as learning rate scheduling, data loading, and logging.
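The coordination role described above can be sketched as a plain-Python loop. This is a minimal illustration, not the SketchGraphs implementation. The `ToyHarness` class, its `train` method, the `step_fn` and `log_every` parameters, and the in-memory `log` list are all hypothetical. The scheduler is assumed to expose a `step()` method, in the style of PyTorch learning-rate schedulers.

```python
class ToyHarness:
    """Hypothetical harness: drives the loop and delegates per-batch work."""

    def __init__(self, step_fn, scheduler=None, log_every=2):
        self.step_fn = step_fn        # callable(batch, global_step) -> (losses, accuracy)
        self.scheduler = scheduler    # assumed to expose step(), like torch LR schedulers
        self.log_every = log_every    # hypothetical logging interval, in global steps
        self.log = []                 # stands in for a real logger

    def single_step(self, batch, global_step):
        return self.step_fn(batch, global_step)

    def on_epoch_end(self, epoch, global_step):
        self.log.append(("epoch_end", epoch, global_step))

    def train(self, dataloader, num_epochs):
        global_step = 0
        for epoch in range(num_epochs):
            for batch in dataloader:                      # data loading
                losses, accuracy = self.single_step(batch, global_step)
                if self.scheduler is not None:
                    self.scheduler.step()                 # learning rate scheduling
                if global_step % self.log_every == 0:     # periodic logging
                    self.log.append(("step", global_step, losses, accuracy))
                global_step += 1
            self.on_epoch_end(epoch, global_step)         # end-of-epoch hook
        return global_step


class CountingScheduler:
    """Hypothetical scheduler that only counts how often it is stepped."""

    def __init__(self):
        self.calls = 0

    def step(self):
        self.calls += 1


def fake_step(batch, global_step):
    # Toy "model": the loss is just the sum of the batch entries.
    return {"loss": float(sum(batch))}, {"acc": 1.0}


sched = CountingScheduler()
h = ToyHarness(fake_step, scheduler=sched, log_every=2)
total = h.train([[1, 2], [3], [4, 5]], num_epochs=2)
print(total, sched.calls)  # 6 6
```

Keeping the per-batch work in `single_step` and the bookkeeping in the loop is what lets a harness reuse the same scheduling and logging code for different models.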

__init__(model, opt, node_feature_dimension, edge_feature_dimension, config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None, profile_enabled=False, additional_model_information=None)

Creates a new harness for the given model.

Parameters

on_epoch_end(epoch, global_step)

This method is called at the end of each epoch.
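An end-of-epoch hook is a natural place for periodic work such as checkpointing. The sketch below is hypothetical: `BaseHarness`, `CheckpointingHarness`, `save_every`, and the in-memory `saved` list are illustrative stand-ins, since this page does not say what the real hook does. It also assumes `epoch` is 0-based.

```python
class BaseHarness:
    """Hypothetical stand-in for a harness base class with an epoch hook."""

    def on_epoch_end(self, epoch, global_step):
        pass  # default: do nothing


class CheckpointingHarness(BaseHarness):
    def __init__(self, save_every=2):
        self.save_every = save_every
        self.saved = []  # records (epoch, global_step) instead of writing files

    def on_epoch_end(self, epoch, global_step):
        super().on_epoch_end(epoch, global_step)
        # Assuming 0-based epochs, this saves after the 2nd, 4th, ... epoch.
        if (epoch + 1) % self.save_every == 0:
            self.saved.append((epoch, global_step))


h = CheckpointingHarness(save_every=2)
for epoch in range(5):
    h.on_epoch_end(epoch, global_step=(epoch + 1) * 100)
print(h.saved)  # [(1, 200), (3, 400)]
```

Passing `global_step` alongside `epoch` lets the hook tag a checkpoint with the exact optimization step it was taken at.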

single_step(batch, global_step)

Implements a single step of model training or evaluation.

Parameters
  • batch (dict) – Input batch from the dataloader

  • global_step (int) – Global step for this batch

Returns

  • losses (dict) – Dictionary of computed losses

  • accuracy (dict) – Dictionary of computed accuracy metrics
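Because `single_step` returns two dictionaries keyed by metric name, per-step results can be averaged over an epoch key by key. The sketch below assumes each dictionary maps a metric name to a plain scalar (the real values may be tensors). The metric names (`node_loss`, `edge_loss`, `node_acc`), the `fake_single_step` stand-in, and the `average_metrics` helper are hypothetical, not taken from the harness.

```python
from collections import defaultdict


def fake_single_step(batch, global_step):
    """Hypothetical stand-in returning (losses, accuracy) like single_step."""
    losses = {"node_loss": batch["node_loss"], "edge_loss": batch["edge_loss"]}
    accuracy = {"node_acc": batch["node_acc"]}
    return losses, accuracy


def average_metrics(dicts):
    """Average a list of {name: scalar} dictionaries key by key."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for d in dicts:
        for name, value in d.items():
            totals[name] += value
            counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}


batches = [
    {"node_loss": 1.0, "edge_loss": 0.5, "node_acc": 0.5},
    {"node_loss": 3.0, "edge_loss": 1.5, "node_acc": 1.0},
]
all_losses, all_acc = [], []
for step, batch in enumerate(batches):
    losses, accuracy = fake_single_step(batch, step)
    all_losses.append(losses)
    all_acc.append(accuracy)

print(average_metrics(all_losses))  # {'node_loss': 2.0, 'edge_loss': 1.0}
print(average_metrics(all_acc))     # {'node_acc': 0.75}
```

Counting per key, rather than dividing by the number of steps, keeps the average correct even if some steps omit a metric.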