sketchgraphs_models.graph.train.harness.GraphModelHarness
-
class sketchgraphs_models.graph.train.harness.GraphModelHarness(model, opt, node_feature_dimension, edge_feature_dimension, config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None, profile_enabled=False, additional_model_information=None)
This class is the main harness for training graph models.
The harness coordinates the procedures that surround training, such as learning-rate scheduling, data loading, and logging.
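To make the division of labor concrete, here is a minimal, framework-free sketch of the harness pattern. It is an illustration under stated assumptions, not the SketchGraphs implementation: the class names `MinimalHarness` and `SumHarness`, the `train` loop, and the per-step `scheduler` callback are all hypothetical. The harness owns the loop (batches, a global step counter, scheduling, logging), and the model-specific work lives in `single_step` and `on_epoch_end`, the two hooks documented on this page.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class MinimalHarness:
    """Hypothetical, simplified harness; not the SketchGraphs implementation.

    Owns the loop around training: iterating over batches, advancing a
    global step counter, calling an optional per-step scheduler callback,
    recording per-step results, and invoking an end-of-epoch hook.
    """

    def __init__(self, scheduler: Optional[Callable[[int], None]] = None):
        self.scheduler = scheduler  # hypothetical: called with the global step
        self.log: List[Tuple[int, int, Any]] = []  # stands in for real logging

    def single_step(self, batch: Any, global_step: int) -> Any:
        raise NotImplementedError  # model-specific work goes in subclasses

    def on_epoch_end(self, epoch: int, global_step: int) -> None:
        pass  # default hook does nothing

    def train(self, batches: Sequence[Any], num_epochs: int) -> int:
        """Run num_epochs passes over batches; return the final global step."""
        global_step = 0
        for epoch in range(num_epochs):
            for batch in batches:
                result = self.single_step(batch, global_step)
                self.log.append((epoch, global_step, result))
                if self.scheduler is not None:
                    self.scheduler(global_step)
                global_step += 1
            self.on_epoch_end(epoch, global_step)
        return global_step


class SumHarness(MinimalHarness):
    """Toy subclass whose 'step' just sums the numbers in a batch."""

    def single_step(self, batch, global_step):
        return sum(batch)
```

For example, `SumHarness().train([[1, 2], [3]], num_epochs=2)` runs four steps, and the first log entry is `(0, 0, 3)`. Keeping the loop in the base class means a new model only needs to supply the two hooks.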
-
__init__(model, opt, node_feature_dimension, edge_feature_dimension, config_train, config_eval=None, scheduler=None, output_dir=None, dist_config=None, profile_enabled=False, additional_model_information=None)
Creates a new harness for the given model.
- Parameters
model (torch.nn.Module) – The torch model to train.
opt (torch.optim.Optimizer) – The optimizer to use during training.
config_train (TrainingConfig) – The configuration to use for training.
config_eval (TrainingConfig, optional) – The configuration to use for evaluation.
dist_config (DistributedTrainingInfo, optional) – The configuration used for distributed training.
-
on_epoch_end(epoch, global_step)
This method is called at the end of each epoch.
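As an illustration of the kind of work an end-of-epoch hook can take on, the sketch below decays a learning rate once per epoch and records where a checkpoint would be written. Everything here is an assumption for illustration: the class `EpochEndHook`, the step-decay factor `gamma`, and the checkpoint file-name pattern are hypothetical and do not describe what `GraphModelHarness.on_epoch_end` actually does. Only the `(epoch, global_step)` signature mirrors the documented method.

```python
import os
from typing import List, Optional


class EpochEndHook:
    """Hypothetical on_epoch_end-style hook (not the SketchGraphs code).

    At the end of each epoch it multiplies the learning rate by a constant
    factor and, if an output directory is set, records a checkpoint path.
    """

    def __init__(self, lr: float, gamma: float = 0.5,
                 output_dir: Optional[str] = None):
        self.lr = lr
        self.gamma = gamma
        self.output_dir = output_dir
        self.checkpoints: List[str] = []

    def on_epoch_end(self, epoch: int, global_step: int) -> None:
        self.lr *= self.gamma  # step-decay schedule applied once per epoch
        if self.output_dir is not None:
            # Only the path is built; saving model state is outside the sketch.
            name = f"model_state_{epoch}_{global_step}.pt"
            self.checkpoints.append(os.path.join(self.output_dir, name))
```

After two calls with `gamma=0.5`, a starting rate of `1.0` becomes `0.25`, and two checkpoint paths have been recorded. Passing `global_step` alongside `epoch` lets the hook label artifacts by the exact number of steps taken, which stays meaningful even when epochs differ in length.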
-
single_step(batch, global_step)
Implements a single step of model evaluation or training.
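Because one step routine serves both evaluation and training, the usual structure is a shared forward pass and loss, with the parameter update applied only in training mode. The framework-free sketch below shows that structure under stated assumptions: `ScalarModel` (a one-parameter model y = w * x), the `training` flag, and the learning rate `lr` are hypothetical and not part of SketchGraphs. The `global_step` argument of the documented method is omitted because the sketch does no logging.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ScalarModel:
    """Hypothetical one-parameter model y = w * x, for illustration only."""
    w: float = 0.0


def single_step(model: ScalarModel, batch: List[Tuple[float, float]],
                training: bool, lr: float = 0.1) -> float:
    """Return the batch's mean squared error; update w only when training."""
    n = len(batch)
    # Forward pass and loss, identical in both modes: mean of (w*x - y)^2.
    loss = sum((model.w * x - y) ** 2 for x, y in batch) / n
    if training:
        # Analytic gradient of the loss with respect to w: mean of 2*x*(w*x - y).
        grad = sum(2 * x * (model.w * x - y) for x, y in batch) / n
        model.w -= lr * grad  # plain gradient-descent update
    return loss
```

With `w = 0` and the batch `[(1.0, 2.0)]`, both modes report a loss of `4.0`; only the training call moves `w` (to `0.4`). The loss is computed before the update, so evaluation and training report comparable numbers for the same parameters.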
-