dowhy.causal_prediction.datasets package
Submodules
dowhy.causal_prediction.datasets.base_dataset module
- The MultipleDomainDataset class in this module is borrowed from DomainBed: https://github.com/facebookresearch/DomainBed
 - @inproceedings{gulrajani2021in,
 title={In Search of Lost Domain Generalization},
 author={Ishaan Gulrajani and David Lopez-Paz},
 booktitle={International Conference on Learning Representations},
 year={2021}
}
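For orientation, the container below is a simplified stand-in modeled on DomainBed's `MultipleDomainDataset`. It assumes the borrowed class behaves the same way: it holds one dataset per environment in a `datasets` list, indexes them by position, and relies on subclasses to set class attributes such as `ENVIRONMENTS` and `INPUT_SHAPE`. `MultipleDomainDatasetSketch` and `ToyDomains` are hypothetical names used only for illustration.

```python
class MultipleDomainDatasetSketch:
    """Hypothetical stand-in modeled on DomainBed's MultipleDomainDataset."""

    # Subclasses are expected to override these (assumed, as in DomainBed)
    ENVIRONMENTS = None
    INPUT_SHAPE = None

    def __init__(self, datasets):
        # One dataset per environment, in the same order as ENVIRONMENTS
        self.datasets = list(datasets)

    def __getitem__(self, index):
        # ds[i] returns the data of the i-th environment
        return self.datasets[index]

    def __len__(self):
        # Number of environments, not number of samples
        return len(self.datasets)


class ToyDomains(MultipleDomainDatasetSketch):
    ENVIRONMENTS = ["+90%", "+80%", "-90%", "-90%"]
    INPUT_SHAPE = (2, 14, 14)


toy = ToyDomains([["a", "b"], ["c"], ["d", "e", "f"], ["g"]])
for name, env in zip(toy.ENVIRONMENTS, toy):
    print(name, len(env))
```

Iterating with `zip` works even without an `__iter__` method, because Python falls back to calling `__getitem__` with 0, 1, 2, ... until an `IndexError` is raised.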
dowhy.causal_prediction.datasets.mnist module
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)
 Bases: MultipleDomainDataset
 Class for the MNISTCausalAttribute dataset.
- Parameters:
 root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Boolean flag indicating whether the data should be downloaded
- Returns:
 an instance of the MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
 
- ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']
 
- INPUT_SHAPE = (2, 14, 14)
 
- N_STEPS = 5001
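These two constants are presumably consumed by a training loop rather than by the dataset itself. Assuming a DomainBed-style loop that runs for `N_STEPS` steps and evaluates whenever `step % CHECKPOINT_FREQ == 0` or at the final step, the schedule can be computed as below; `checkpoint_steps` is a hypothetical helper, not part of dowhy.

```python
N_STEPS = 5001
CHECKPOINT_FREQ = 500


def checkpoint_steps(n_steps, checkpoint_freq):
    """Steps at which an assumed DomainBed-style loop would evaluate or save."""
    return [
        step
        for step in range(n_steps)  # steps 0 .. n_steps - 1
        if step % checkpoint_freq == 0 or step == n_steps - 1
    ]


steps = checkpoint_steps(N_STEPS, CHECKPOINT_FREQ)
print(len(steps), steps[-1])
```

Under that assumption, using 5001 rather than 5000 makes the final step (5000) coincide with a regular checkpoint, giving checkpoints at 0, 500, ..., 5000.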
 
- color_dataset(images, labels, environment)
 Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between the label Y and color.
- Parameters:
 images – original MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
- Returns:
 TensorDataset containing transformed images, labels, and attributes (color)
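The description above does not spell out the transform. The sketch below follows the Colored-MNIST recipe from DomainBed, which this dataset appears to be modeled on, so treat every detail as an assumption: digits are binarized as `digit < 5`, 25% label noise is added, the color copies the noisy label but is flipped with probability `flip_prob`, and the unused color channel is zeroed. Under this reading, an environment named `'+90%'` would correspond to `flip_prob = 0.1`. It uses NumPy instead of PyTorch, and `color_dataset_sketch` is a hypothetical name.

```python
import numpy as np


def color_dataset_sketch(images, labels, flip_prob, rng):
    """Hypothetical NumPy re-creation of a Colored-MNIST transform.

    images: uint8 array (N, 28, 28); labels: digits 0-9.
    Returns x (N, 2, 14, 14), binary labels y (N,), colors a (N, 1).
    """
    # Subsample 2x: 28x28 -> 14x14, matching INPUT_SHAPE (2, 14, 14)
    images = np.asarray(images).reshape(-1, 28, 28)[:, ::2, ::2]
    # Binary label (digit < 5), then flip it with probability 0.25 (label noise)
    y = (np.asarray(labels) < 5).astype(np.int64)
    y ^= (rng.random(len(y)) < 0.25).astype(np.int64)
    # Color is caused by the label: equal to y, flipped with probability flip_prob
    color = y ^ (rng.random(len(y)) < flip_prob).astype(np.int64)
    # Copy the grey image into two channels, then zero the channel not selected
    x = np.stack([images, images], axis=1).astype(np.float32) / 255.0
    x[np.arange(len(x)), 1 - color] = 0.0
    return x, y, color[:, None]


rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(2000, 28, 28), dtype=np.uint8)
digits = rng.integers(0, 10, size=2000)
x, y, a = color_dataset_sketch(imgs, digits, flip_prob=0.1, rng=rng)
print(x.shape, np.mean(a[:, 0] == y))
```

Because the color is generated from the label, changing Y changes the color, which is the direct causal relationship described above; only the strength of that link (`flip_prob`) would vary across environments.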
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)
 Bases: MultipleDomainDataset
 Class for the MNISTCausalIndAttribute dataset.
- Parameters:
 root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Boolean flag indicating whether the data should be downloaded
- Returns:
 an instance of the MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
 
- ENVIRONMENTS = ['+90%, 15', '+80%, 60', '-90%, 90', '-90%, 90']
 
- INPUT_SHAPE = (2, 14, 14)
 
- N_STEPS = 5001
 
- color_dataset(images, labels, environment)
 Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between the label Y and color.
- Parameters:
 images – rotated MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
- Returns:
 transformed images, labels, and attributes (color)
- color_rot_dataset(images, labels, environment, env_id, angle)
 Transform the MNIST dataset by (i) rotating the images, then (ii) introducing a correlation between the attribute (color) and the label. The rotation-angle attribute is independent of the label Y, whereas there is a direct causal relationship between Y and color.
- Parameters:
 images – original MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
angle – Value of rotation angle used for transforming the image
- Returns:
 TensorDataset containing transformed images, labels, and attributes (color, angle)
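The two-stage transform can be sketched in NumPy under the same Colored-MNIST assumptions (binary noisy label, color flipped with probability `flip_prob`). Rotation here uses nearest-neighbour sampling with zero fill for simplicity; a library rotation such as torchvision's would typically interpolate instead. The `env_id` argument is omitted because its role is not documented here, and `rotate_nearest` and `color_rot_dataset_sketch` are hypothetical names.

```python
import numpy as np


def rotate_nearest(image, angle_deg):
    """Rotate a 2-D image counterclockwise about its centre (nearest neighbour, zero fill)."""
    h, w = image.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    t = np.deg2rad(angle_deg)
    rows, cols = np.indices((h, w))
    xd, yd = cols - cx, cy - rows  # output coordinates, y axis pointing up
    # Inverse rotation: source location of every output pixel
    src_c = np.rint(xd * np.cos(t) + yd * np.sin(t) + cx).astype(int)
    src_r = np.rint(cy - (-xd * np.sin(t) + yd * np.cos(t))).astype(int)
    inside = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.zeros_like(image)
    out[inside] = image[src_r[inside], src_c[inside]]
    return out


def color_rot_dataset_sketch(images, labels, flip_prob, angle_deg, rng):
    """Hypothetical NumPy version: (i) rotate, then (ii) color by the noisy label."""
    images = np.asarray(images).reshape(-1, 28, 28)
    # (i) Rotate every digit by the same angle, then subsample 2x to 14x14
    rotated = np.stack([rotate_nearest(img, angle_deg) for img in images])[:, ::2, ::2]
    # (ii) Binary label (digit < 5) with 25% label noise; color copies the label
    #      but is flipped with probability flip_prob (assumed Colored-MNIST recipe)
    y = (np.asarray(labels) < 5).astype(np.int64)
    y ^= (rng.random(len(y)) < 0.25).astype(np.int64)
    color = y ^ (rng.random(len(y)) < flip_prob).astype(np.int64)
    x = np.stack([rotated, rotated], axis=1).astype(np.float32) / 255.0
    x[np.arange(len(x)), 1 - color] = 0.0  # keep only the channel given by color
    # Attributes: color depends on y; the angle is the same for every sample
    angle = np.full(len(y), float(angle_deg))
    a = np.stack([color.astype(float), angle], axis=1)
    return x, y, a


rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=1000)
imgs = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)
x, y, a = color_rot_dataset_sketch(imgs, digits, flip_prob=0.2, angle_deg=60, rng=rng)
print(x.shape, a[:3])
```

Since the angle is fixed per environment and drawn without looking at Y, it carries no information about the label, while the color does; this is the contrast between the two attributes in the description above.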
- class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)
 Bases: MultipleDomainDataset
 Class for the MNISTIndAttribute dataset.
- Parameters:
 root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Boolean flag indicating whether the data should be downloaded
- Returns:
 an instance of the MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
 
- ENVIRONMENTS = ['15', '60', '90', '90']
 
- INPUT_SHAPE = (1, 14, 14)
 
- N_STEPS = 5001
 
- rotate_dataset(images, labels, env_id, angle)
 Transform the MNIST dataset by rotating the images. The attribute (rotation angle) is independent of the label Y.
- Parameters:
 images – original MNIST images
labels – original MNIST labels
angle – Value of rotation angle used for transforming the image
- Returns:
 TensorDataset containing transformed images, labels, and attributes (angle)
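A minimal NumPy sketch of this transform, assuming every image in an environment is rotated by the same angle and the angle itself is stored as the attribute. Rotation uses nearest-neighbour sampling with zero fill rather than the interpolated rotation a library would likely use; labels are left untouched, `env_id` is omitted because its role is not documented here, and `rotate_nearest` and `rotate_dataset_sketch` are hypothetical names.

```python
import numpy as np


def rotate_nearest(image, angle_deg):
    """Rotate a 2-D image counterclockwise about its centre (nearest neighbour, zero fill)."""
    h, w = image.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    t = np.deg2rad(angle_deg)
    rows, cols = np.indices((h, w))
    xd, yd = cols - cx, cy - rows  # output coordinates, y axis pointing up
    # Inverse rotation: source location of every output pixel
    src_c = np.rint(xd * np.cos(t) + yd * np.sin(t) + cx).astype(int)
    src_r = np.rint(cy - (-xd * np.sin(t) + yd * np.cos(t))).astype(int)
    inside = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.zeros_like(image)
    out[inside] = image[src_r[inside], src_c[inside]]
    return out


def rotate_dataset_sketch(images, labels, angle_deg):
    """Hypothetical NumPy version of rotate_dataset (env_id omitted)."""
    images = np.asarray(images).reshape(-1, 28, 28)
    # Rotate every digit by the same angle, then subsample 2x to match INPUT_SHAPE (1, 14, 14)
    rotated = np.stack([rotate_nearest(img, angle_deg) for img in images])[:, ::2, ::2]
    x = rotated[:, None].astype(np.float32) / 255.0
    y = np.asarray(labels)  # labels left untouched in this sketch
    # The attribute is the angle: identical for every sample, hence independent of y
    a = np.full((len(y), 1), float(angle_deg))
    return x, y, a


m = np.arange(196).reshape(14, 14)
print(np.array_equal(rotate_nearest(m, 90), np.rot90(m)))
```

As a sanity check, a 90-degree rotation with this routine matches `np.rot90` on a square image. With `ENVIRONMENTS = ['15', '60', '90', '90']`, one would presumably call it once per environment with the corresponding angle.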