dowhy.causal_prediction.datasets package

Submodules

dowhy.causal_prediction.datasets.base_dataset module

The MultipleDomainDataset class in this file is borrowed from DomainBed: https://github.com/facebookresearch/DomainBed
@inproceedings{gulrajani2021in,
  title={In Search of Lost Domain Generalization},
  author={Ishaan Gulrajani and David Lopez-Paz},
  booktitle={International Conference on Learning Representations},
  year={2021},
}

class dowhy.causal_prediction.datasets.base_dataset.MultipleDomainDataset

Bases: object

CHECKPOINT_FREQ = 100
ENVIRONMENTS = None
INPUT_SHAPE = None
N_STEPS = 5001
N_WORKERS = 8
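A minimal sketch of how these class-level defaults are meant to be used. Because MultipleDomainDataset is borrowed from DomainBed, we assume it follows DomainBed's pattern: subclasses override only the constants they need, and the instance stores one dataset per environment in `self.datasets`, exposed through `__getitem__` and `__len__`. Neither `MultipleDomainDatasetSketch` nor `ToyDataset` is a DoWhy class. Both are hypothetical stand-ins.

```python
class MultipleDomainDatasetSketch:
    """Hypothetical stand-in carrying the documented defaults of MultipleDomainDataset."""

    N_STEPS = 5001
    CHECKPOINT_FREQ = 100
    N_WORKERS = 8
    ENVIRONMENTS = None  # subclasses list their environment names here
    INPUT_SHAPE = None  # subclasses give the per-example input shape here

    def __init__(self, datasets):
        # Assumption (DomainBed convention): one dataset per environment.
        self.datasets = datasets

    def __getitem__(self, index):
        # Indexing selects a whole environment, not a single example.
        return self.datasets[index]

    def __len__(self):
        # Number of environments.
        return len(self.datasets)


class ToyDataset(MultipleDomainDatasetSketch):
    """Hypothetical subclass that overrides only what differs from the defaults."""

    CHECKPOINT_FREQ = 500
    ENVIRONMENTS = ["env_a", "env_b"]
    INPUT_SHAPE = (1, 2, 2)


toy = ToyDataset([["a1", "a2"], ["b1"]])
print(len(toy), toy[1], toy.CHECKPOINT_FREQ, toy.N_WORKERS)  # 2 ['b1'] 500 8
```

Constants that a subclass does not override, such as N_WORKERS here, fall back to the base-class values through ordinary Python attribute lookup.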

dowhy.causal_prediction.datasets.mnist module

class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTCausalAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']
INPUT_SHAPE = (2, 14, 14)
N_STEPS = 5001
color_dataset(images, labels, environment)

Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between label Y and color.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • environment – Value that sets the strength of the correlation between color and label

Returns:

TensorDataset containing transformed images, labels, and attributes (color)
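A pure-Python sketch of what color_dataset does, assuming it follows DomainBed's ColoredMNIST recipe. In that recipe, images are subsampled 2x and the binary label is `digit < 5` with 25% label noise. The color then disagrees with the label with probability `environment`, so under this assumption the '+90%' environment corresponds to `environment=0.1`. The name `color_dataset_sketch` and the nested lists standing in for tensors are illustrative only. The real method works on torch tensors and returns a TensorDataset.

```python
import random


def color_dataset_sketch(images, digits, environment, rng):
    """Pure-Python sketch of the coloring step; returns (x, y, [color]) triples."""
    out = []
    for img, digit in zip(images, digits):
        # Subsample 2x (28x28 -> 14x14) and scale pixels to [0, 1].
        small = [[v / 255.0 for v in row[::2]] for row in img[::2]]
        # Binary label from the digit, flipped with probability 0.25 (label noise).
        y = float(digit < 5)
        y = abs(y - float(rng.random() < 0.25))
        # Color copies the label but is flipped with probability `environment`.
        color = abs(y - float(rng.random() < environment))
        # Two channels: the channel whose index equals `color` keeps the digit.
        blank = [[0.0] * len(small[0]) for _ in small]
        x = [blank, small] if color == 1.0 else [small, blank]
        out.append((x, int(y), [color]))
    return out


rng = random.Random(0)
images = [[[255] * 28 for _ in range(28)] for _ in range(4)]
data = color_dataset_sketch(images, [0, 3, 7, 9], environment=0.0, rng=rng)
x, y, a = data[0]
print(len(x), len(x[0]), len(x[0][0]))  # 2 14 14
```

With `environment=0.0` the color always equals the (noisy) label, so a classifier could rely on color alone. Raising `environment` weakens or reverses that shortcut, which is what makes the environments differ.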

torch_bernoulli_(p, size)
torch_xor_(a, b)
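torch_bernoulli_ and torch_xor_ are undocumented helpers. Assuming they match DomainBed's versions, the first draws Bernoulli samples by thresholding uniform noise and the second computes XOR of 0/1 values as an absolute difference. Their behavior can then be sketched without torch. The `_sketch` functions below are hypothetical stand-ins, not the DoWhy methods.

```python
import random


def bernoulli_sketch(p, size, rng=random):
    """Return `size` floats, each 1.0 with probability p and 0.0 otherwise."""
    # Same idea as thresholding uniform noise: (torch.rand(size) < p).float()
    return [float(rng.random() < p) for _ in range(size)]


def xor_sketch(a, b):
    """Elementwise XOR for 0.0/1.0 values, computed as |a - b|."""
    return [abs(x - y) for x, y in zip(a, b)]


labels = [0.0, 1.0, 1.0, 0.0]
# Flip each label with probability 0.25, as a label-noise step would.
noisy = xor_sketch(labels, bernoulli_sketch(0.25, len(labels), random.Random(1)))
```

XOR-ing with a Bernoulli mask is a compact way to flip each entry independently with probability p. The same pattern can be used both for label noise and for deciding when color disagrees with the label.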
class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTCausalIndAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']
INPUT_SHAPE = (2, 14, 14)
N_STEPS = 5001
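Each ENVIRONMENTS entry appears to pair a color-label correlation with a rotation angle, matching the `environment` and `angle` arguments of color_rot_dataset. The hypothetical helper below reads the names under that interpretation: the sign gives the direction of the correlation, and the number after the comma is the angle in degrees. The class itself does not document or expose such a parser.

```python
def parse_environment(name):
    """Split e.g. '+90%, 15' into (signed color-label correlation, angle in degrees)."""
    corr_part, angle_part = [part.strip() for part in name.split(",")]
    # '+90%' -> 0.9, '-90%' -> -0.9
    correlation = float(corr_part.rstrip("%")) / 100.0
    return correlation, int(angle_part)


for name in ["+90%, 15", "+80%, 16", "-90%, 90", "-90%, 90"]:
    print(name, "->", parse_environment(name))
```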
color_dataset(images, labels, environment)

Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between label Y and color.

Parameters:
  • images – rotated MNIST images

  • labels – original MNIST labels

  • environment – Value that sets the strength of the correlation between color and label

Returns:

transformed images, labels, and attributes (color)

color_rot_dataset(images, labels, environment, env_id, angle)

Transform the MNIST dataset by (i) rotating the images, then (ii) introducing a correlation between the attribute (color) and the label. The rotation-angle attribute is independent of label Y, whereas there is a direct causal relationship between label Y and color.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • environment – Value that sets the strength of the correlation between color and label

  • angle – Value of rotation angle used for transforming the image

Returns:

TensorDataset containing transformed images, labels, and attributes (color, angle)
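The two-step order described above (rotate, then color) can be sketched in plain Python. The sketch assumes that labels and colors follow DomainBed's ColoredMNIST recipe: a binary label `digit < 5` with 25% label noise, and a color that disagrees with the label with probability `environment`. It also omits subsampling and pixel scaling, and takes the rotation as a callable. All names are hypothetical. In particular, `rotate_quarter_turns` is a toy that only handles multiples of 90 degrees, whereas the real method accepts angles such as 15.

```python
import random


def rotate_quarter_turns(img, angle):
    """Toy counter-clockwise rotation, restricted to multiples of 90 degrees."""
    if angle % 90 != 0:
        raise ValueError("this toy rotation only handles multiples of 90 degrees")
    for _ in range((angle // 90) % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def color_rot_sketch(images, digits, environment, angle, rotate, rng):
    """Rotate every image by `angle`, then color it; attributes are (color, angle)."""
    out = []
    for img, digit in zip(images, digits):
        # (i) Same angle for every image, so the angle carries no label information.
        rotated = rotate(img, angle)
        y = float(digit < 5)
        y = abs(y - float(rng.random() < 0.25))  # label noise (assumed, as in DomainBed)
        # (ii) Color is derived from y, flipped with probability `environment`.
        color = abs(y - float(rng.random() < environment))
        blank = [[0] * len(rotated[0]) for _ in rotated]
        x = [blank, rotated] if color == 1.0 else [rotated, blank]
        out.append((x, int(y), (color, angle)))
    return out


rng = random.Random(0)
images = [[[1, 2], [3, 4]] for _ in range(3)]
data = color_rot_sketch(images, [1, 6, 8], environment=0.0, angle=90,
                        rotate=rotate_quarter_turns, rng=rng)
```

Because `angle` is fixed for the whole call, it is constant within an environment and cannot depend on Y, while `color` is computed from `y`. That asymmetry is exactly the contrast the docstring draws between the two attributes.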

rotate_dataset(images, angle)

Transform the MNIST dataset by rotating the images. The attribute (rotation angle) is independent of label Y.

Parameters:
  • images – original MNIST images

  • angle – Value of rotation angle used for transforming the image

Returns:

transformed images
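A dependency-free sketch of rotating every image by one angle. It assumes counter-clockwise rotation about the image center, nearest-neighbour sampling, and zero fill for pixels that fall outside the original frame. The real method works on tensors and may use a different interpolation. `rotate_image` and `rotate_dataset_sketch` are hypothetical names.

```python
import math


def rotate_image(img, angle_deg, fill=0):
    """Rotate a 2-D list counter-clockwise about its center (nearest-neighbour sampling)."""
    h, w = len(img), len(img[0])
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = [[fill] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            # Output pixel in centered coordinates (y axis pointing up).
            x, y = c - cx, cy - r
            # Inverse-rotate to find which input pixel lands here.
            x_in = x * cos_t + y * sin_t
            y_in = -x * sin_t + y * cos_t
            c_in, r_in = round(x_in + cx), round(cy - y_in)
            if 0 <= r_in < h and 0 <= c_in < w:
                out[r][c] = img[r_in][c_in]
    return out


def rotate_dataset_sketch(images, angle):
    # The same angle is applied to every image, regardless of its label.
    return [rotate_image(img, angle) for img in images]


grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(rotate_image(grid, 90))  # [[3, 6, 9], [2, 5, 8], [1, 4, 7]]
```

Mapping each output pixel back to an input pixel (inverse mapping) avoids the holes that forward-mapping each input pixel would leave for angles like 15 degrees.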

torch_bernoulli_(p, size)
torch_xor_(a, b)
class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTIndAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['15', '60', '90', '90']
INPUT_SHAPE = (1, 14, 14)
N_STEPS = 5001
rotate_dataset(images, labels, env_id, angle)

Transform the MNIST dataset by rotating the images. The attribute (rotation angle) is independent of label Y.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • angle – Value of rotation angle used for transforming the image

Returns:

TensorDataset containing transformed images, labels, and attributes (angle)
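The independence claim can be checked empirically on a sketch. We assume labels are built as in DomainBed's colored MNIST variant, with a binary label `digit < 5` and 25% label noise; this is not confirmed for this class. We also assume every sample in an environment receives that environment's angle as its attribute. Under those assumptions, the rate of positive labels should not change across the angles listed in ENVIRONMENTS. The helper `ind_attribute_records` is hypothetical and skips the image rotation itself.

```python
import random


def ind_attribute_records(digits, angle, rng):
    """(y, angle) pairs: a noisy binary label plus the environment's fixed angle."""
    records = []
    for digit in digits:
        y = float(digit < 5)
        y = abs(y - float(rng.random() < 0.25))  # label noise (assumed, as in DomainBed)
        records.append((int(y), angle))  # the attribute is chosen without looking at y
    return records


rng = random.Random(42)
digits = [rng.randrange(10) for _ in range(20000)]
positive_rate = {}
for angle in (15, 60, 90):
    recs = ind_attribute_records(digits, angle, rng)
    positive_rate[angle] = sum(y for y, _ in recs) / len(recs)
print(positive_rate)
```

Each rate should come out close to 0.5: knowing the angle tells you nothing about Y. Running the same check on color in a color_dataset-style construction would instead show the rate tracking the color.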

torch_bernoulli_(p, size)
torch_xor_(a, b)

Module contents