sketchgraphs_models.graph.train.data_loading

This module contains the main functions used to load the required data from disk for training.

Functions

sketchgraphs_models.graph.train.data_loading.initialize_datasets(args, distributed_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)

Initialize datasets and dataloaders.

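Parameters

  • args – Training arguments used to locate and configure the datasets

  • distributed_config (Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo]) – Distributed training configuration, or None (the default) for non-distributed training

Returns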

  • torch.utils.data.DataLoader – Training dataloader

  • torch.utils.data.DataLoader – Testing dataloader (may be None)

  • int – Number of batches per training epoch

  • dataset.EntityFeatureMapping – Feature mapping in use for entities

  • dataset.EdgeFeatureMapping – Feature mapping in use for constraints
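The "number of batches per training epoch" depends on dataset size, batch size and, when distributed_config is given, on how samples are split across processes. The sketch below is a hypothetical helper, batches_per_epoch, that reproduces the standard PyTorch arithmetic: torch.utils.data.DistributedSampler gives each replica ceil(N / world_size) samples, and a DataLoader then groups them into batches. It is an illustration under those assumptions only; initialize_datasets may compute this value differently (for example, when a weighted sampler defines the epoch length).

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division without floating-point rounding."""
    return -(-a // b)


def batches_per_epoch(num_samples: int, batch_size: int,
                      world_size: int = 1, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches each process sees per epoch.

    Assumes DistributedSampler-style splitting (each of the `world_size`
    replicas receives ceil(num_samples / world_size) samples, padded if
    needed), followed by a DataLoader with `batch_size`. `drop_last` here
    means the DataLoader discards a final partial batch.
    """
    if batch_size <= 0 or world_size <= 0:
        raise ValueError("batch_size and world_size must be positive")
    per_replica = ceil_div(num_samples, world_size)
    if drop_last:
        return per_replica // batch_size
    return ceil_div(per_replica, batch_size)


# 10 samples, batch size 3, single process: batches of 3, 3, 3, 1.
print(batches_per_epoch(10, 3))
```

With two processes, each replica receives 5 of the 10 samples, so each process runs ceil(5 / 3) = 2 batches per epoch.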

sketchgraphs_models.graph.train.data_loading.load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None, entity_features=True, edge_features=True, force_entity_categorical_features=False)

sketchgraphs_models.graph.train.data_loading.load_dataset_and_weights_with_mapping(dataset_file, node_feature_mapping, edge_feature_mapping, seed=None)

sketchgraphs_models.graph.train.data_loading.load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features=True, edge_features=True)

sketchgraphs_models.graph.train.data_loading.make_dataloader_train(collate_fn, ds_train, weights, batch_size, num_epochs, num_workers, distributed_config=None)
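make_dataloader_train takes both a dataset and a weights argument, which suggests that training examples are sampled according to per-example weights. The dependency-free sketch below shows that general technique: weighted sampling with replacement, grouped into batches and passed through a collate function. The name weighted_batches and its num_batches parameter are hypothetical; this is not the module's implementation, and whether make_dataloader_train uses torch.utils.data.WeightedRandomSampler (PyTorch's built-in equivalent) is an assumption not confirmed here.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")
B = TypeVar("B")


def weighted_batches(dataset: Sequence[T], weights: Sequence[float],
                     batch_size: int, num_batches: int,
                     collate_fn: Callable[[List[T]], B],
                     seed: Optional[int] = None) -> Iterator[B]:
    """Hypothetical sketch: yield `num_batches` collated batches.

    Each element is drawn with replacement, with probability proportional
    to its entry in `weights` (weights need not sum to 1).
    """
    if len(dataset) != len(weights):
        raise ValueError("dataset and weights must have the same length")
    rng = random.Random(seed)  # seeded for reproducible sampling
    indices = range(len(dataset))
    for _ in range(num_batches):
        # random.choices samples k indices with replacement, using relative weights.
        batch_idx = rng.choices(indices, weights=weights, k=batch_size)
        yield collate_fn([dataset[i] for i in batch_idx])


# Only "b" has non-zero weight, so every sampled element is "b".
for batch in weighted_batches(["a", "b", "c"], [0.0, 1.0, 0.0],
                              batch_size=4, num_batches=2, collate_fn=list, seed=0):
    print(batch)
```

Because sampling is with replacement, an "epoch" in this scheme is a fixed number of draws rather than one pass over every example; in PyTorch the same behaviour comes from WeightedRandomSampler(weights, num_samples, replacement=True).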