sketchgraphs_models.graph.train.data_loading
This module contains the main functions used to load the required data from disk for training.
Functions
- sketchgraphs_models.graph.train.data_loading.initialize_datasets(args, distributed_config: Optional[sketchgraphs_models.distributed_utils.DistributedTrainingInfo] = None)
Initialize datasets and dataloaders.
- Parameters
args (dict) – Dictionary containing all the dataset configurations.
distributed_config (distributed_utils.DistributedTrainingInfo, optional) – If not None, configuration options for distributed training.
- Returns
torch.utils.data.DataLoader – Training dataloader
torch.utils.data.DataLoader – Testing dataloader (may be None)
int – Number of batches per training epoch
dataset.EntityFeatureMapping – Feature mapping in use for entities
dataset.EdgeFeatureMapping – Feature mapping in use for constraints
- sketchgraphs_models.graph.train.data_loading.load_dataset_and_weights(dataset_file, auxiliary_file, quantization, seed=None, entity_features=True, edge_features=True, force_entity_categorical_features=False)
- sketchgraphs_models.graph.train.data_loading.load_dataset_and_weights_with_mapping(dataset_file, node_feature_mapping, edge_feature_mapping, seed=None)
- sketchgraphs_models.graph.train.data_loading.load_sequences_and_mappings(dataset_file, auxiliary_file, quantization, entity_features=True, edge_features=True)
- sketchgraphs_models.graph.train.data_loading.make_dataloader_train(collate_fn, ds_train, weights, batch_size, num_epochs, num_workers, distributed_config=None)
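
The signature of make_dataloader_train takes per-example weights and a batch size but does not describe how they are used. The sketch below is one plausible reading, not the module's actual implementation: it assumes that weights are relative sampling probabilities, that indices are drawn with replacement, and that the number of batches per epoch is the dataset size divided by the batch size, rounded up. The helper names weighted_batches and batches_per_epoch are hypothetical and exist only for illustration; the sketch uses only the Python standard library.

```python
import math
import random


def weighted_batches(weights, batch_size, seed=None):
    """Yield lists of dataset indices drawn with replacement.

    Hypothetical helper: assumes each weight is a relative sampling
    probability for the example at that position.
    """
    rng = random.Random(seed)
    num_samples = len(weights)
    # Draw one epoch's worth of indices, favouring high-weight examples.
    indices = rng.choices(range(num_samples), weights=weights, k=num_samples)
    # Slice the drawn indices into consecutive batches; the last may be short.
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


def batches_per_epoch(num_samples, batch_size):
    """Number of batches per epoch, assuming the final short batch is kept."""
    return math.ceil(num_samples / batch_size)


# Toy weights: examples 4 and 5 are much more likely to be drawn.
weights = [0.1, 0.1, 0.1, 0.1, 5.0, 5.0, 0.1, 0.1, 0.1, 0.1]
batches = list(weighted_batches(weights, batch_size=4, seed=0))
print(len(batches), batches_per_epoch(len(weights), 4))
```

In a PyTorch setting, torch.utils.data.WeightedRandomSampler provides the same kind of weighted index sampling; whether this module uses it is not stated in the documentation above.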