# Source code for dowhy.causal_prediction.dataloaders.misc

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""misc helper functions
"""

import hashlib
from collections import Counter, OrderedDict

import numpy as np
import torch


def make_weights_for_balanced_classes(dataset):
    """Compute per-sample weights that balance the classes of ``dataset``.

    Each sample is weighted by ``1 / (count_of_its_class * n_classes)``, so the
    weights belonging to each class sum to ``1 / n_classes`` and, for a
    non-empty dataset, all weights sum to 1.

    :param dataset: iterable of ``(x, y)`` pairs whose label ``y`` can be
        converted with ``int()`` (e.g. a Python int or a one-element tensor).
    :returns: 1-D tensor of the default floating dtype holding one weight per
        sample, in iteration order (empty if ``dataset`` is empty).
    """
    # Iterate the dataset once, converting each label to a hashable int.
    classes = [int(y) for _, y in dataset]
    counts = Counter(classes)
    n_classes = len(counts)
    weight_per_class = {y: 1 / (count * n_classes) for y, count in counts.items()}
    # Build the tensor in one call rather than assigning elements one by one
    # from Python; dtype matches what ``torch.zeros`` would have produced.
    return torch.tensor(
        [weight_per_class[y] for y in classes], dtype=torch.get_default_dtype()
    )


class _SplitDataset(torch.utils.data.Dataset): """Used by split_dataset""" def __init__(self, underlying_dataset, keys): super(_SplitDataset, self).__init__() self.underlying_dataset = underlying_dataset self.keys = keys def __getitem__(self, key): return self.underlying_dataset[self.keys[key]] def __len__(self): return len(self.keys)
def split_dataset(dataset, n, seed=0):
    """Randomly split ``dataset`` into two disjoint parts.

    The indices of ``dataset`` are shuffled with ``np.random.RandomState(seed)``.
    The first ``n`` shuffled indices form the first dataset and the remaining
    ``len(dataset) - n`` form the second, so the split is reproducible for a
    fixed ``seed``.

    :param dataset: indexable dataset supporting ``len()``.
    :param n: number of datapoints in the first split; must satisfy
        ``0 <= n <= len(dataset)``.
    :param seed: seed for the index shuffle.
    :returns: pair ``(first, second)`` of ``_SplitDataset`` views over ``dataset``.
    :raises ValueError: if ``n`` is outside ``[0, len(dataset)]``.
    """
    size = len(dataset)
    # Explicit check instead of ``assert`` so it survives ``python -O``; a
    # negative ``n`` would otherwise silently slice from the end of ``keys``.
    if not 0 <= n <= size:
        raise ValueError(f"n must be between 0 and len(dataset)={size}, got {n}")
    keys = list(range(size))
    np.random.RandomState(seed).shuffle(keys)
    return _SplitDataset(dataset, keys[:n]), _SplitDataset(dataset, keys[n:])


def seed_hash(*args):
    """Derive a deterministic integer seed in ``[0, 2**31)`` from ``args``.

    The argument tuple is rendered with ``str()`` and hashed with MD5, so
    argument tuples with the same string representation always yield the
    same seed.
    """
    digest = hashlib.md5(str(args).encode("utf-8")).digest()
    # Reading the raw digest as a big-endian integer equals int(hexdigest, 16).
    return int.from_bytes(digest, "big") % (2**31)