Source code for dowhy.causal_prediction.dataloaders.fast_data_loader

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """DataLoader wrapper that yields an endless stream of batches.

    Batches are sampled with replacement, optionally weighted per example by
    ``weights``. Iterating this loader never terminates, so callers must bound
    the number of batches they consume.
    """

    def __init__(self, dataset, weights, batch_size, num_workers):
        super().__init__()

        if weights is not None:
            sampler = torch.utils.data.WeightedRandomSampler(weights, replacement=True, num_samples=batch_size)
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        if weights is None:
            weights = torch.ones(len(dataset))

        batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=batch_size, drop_last=True)

        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=_InfiniteSampler(batch_sampler))
        )
        self._length = len(batch_sampler)

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length

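# Usage note (added for clarity; names below are hypothetical): because
# InfiniteDataLoader.__iter__ never terminates, the iteration is usually
# bounded externally, e.g.
#
#     loader = InfiniteDataLoader(dataset, weights=None, batch_size=32, num_workers=0)
#     for _, (x, y) in zip(range(n_steps), loader):
#         ...

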
class FastDataLoader:
    """DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch."""

    def __init__(self, dataset, batch_size, num_workers):
        super().__init__()

        batch_sampler = torch.utils.data.BatchSampler(
            torch.utils.data.RandomSampler(dataset, replacement=False), batch_size=batch_size, drop_last=False
        )

        self._infinite_iterator = iter(
            torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_sampler=_InfiniteSampler(batch_sampler))
        )
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
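

# Minimal usage sketch (illustrative addition, not part of the original module;
# the toy dataset below is hypothetical). It shows the intended consumption
# pattern of both loaders.
if __name__ == "__main__":
    features = torch.randn(64, 3)
    labels = torch.randint(0, 2, (64,))
    dataset = torch.utils.data.TensorDataset(features, labels)

    # InfiniteDataLoader: pull a fixed number of batches; the loader itself
    # never stops yielding.
    train_loader = InfiniteDataLoader(dataset, weights=None, batch_size=8, num_workers=0)
    train_iter = iter(train_loader)
    for _ in range(5):
        x, y = next(train_iter)

    # FastDataLoader: behaves like a regular epoch-length loader (len() batches
    # per pass) while keeping its worker processes alive across epochs.
    eval_loader = FastDataLoader(dataset, batch_size=8, num_workers=0)
    for x, y in eval_loader:
        pass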