dowhy.causal_prediction.algorithms package

Submodules

dowhy.causal_prediction.algorithms.base_algorithm module

class dowhy.causal_prediction.algorithms.base_algorithm.PredictionAlgorithm(model, optimizer, lr, weight_decay, betas, momentum)[source]

Bases: LightningModule

This class implements the default methods for a PyTorch Lightning module (pl.LightningModule). These methods are called when fit() is invoked. To learn more about them, refer to https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html.

Parameters:
  • model – Neural network modules used for training

  • optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.

  • lr – Value of learning rate

  • weight_decay – Value of weight decay for optimizer

  • betas – Adam configuration parameters (beta1, beta2): exponential decay rates for the first- and second-moment estimates, respectively.

  • momentum – Value of momentum for SGD optimizer

configure_optimizers()[source]

Initialize the optimizer using the parameters passed when initializing the PredictionAlgorithm class.

test_step(batch, batch_idx, dataloader_idx=0)[source]

Activate the test loop for the pl.LightningModule.

The test loop is called only when test() is used.

training_step(train_batch, batch_idx)[source]

Activate the training loop for the pl.LightningModule.

Override this method for implementing a new algorithm.
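To implement a new algorithm, subclass PredictionAlgorithm and override training_step. A minimal sketch follows; the per-environment (x, y, attributes) batch structure and the self.model attribute are assumptions inferred from the constructor parameters above, not guaranteed API:

    import torch
    import torch.nn.functional as F

    from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm

    class MyAlgorithm(PredictionAlgorithm):
        """Hypothetical algorithm: plain cross-entropy pooled over environments."""

        def training_step(self, train_batch, batch_idx):
            # Assumption: train_batch is a list with one (x, y, ...) tuple per
            # training environment; pool all environments into one batch.
            x = torch.cat([batch[0] for batch in train_batch])
            y = torch.cat([batch[1] for batch in train_batch])
            # Assumption: the base class stores the network as self.model.
            loss = F.cross_entropy(self.model(x), y)
            self.log("train_loss", loss)
            return loss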

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Activate the validation loop for the pl.LightningModule.

dowhy.causal_prediction.algorithms.cacm module

class dowhy.causal_prediction.algorithms.cacm.CACM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9, kernel_type='gaussian', ci_test='mmd', attr_types=[], E_conditioned=True, E_eq_A=[], gamma=1e-06, lambda_causal=1.0, lambda_conf=1.0, lambda_ind=1.0, lambda_sel=1.0)[source]

Bases: PredictionAlgorithm

Class for the Causally Adaptive Constraint Minimization (CACM) algorithm.

@article{Kaur2022ModelingTD,
  title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization},
  author={Jivat Neet Kaur and Emre Kıcıman and Amit Sharma},
  journal={ArXiv},
  year={2022},
  volume={abs/2206.07837},
  url={https://arxiv.org/abs/2206.07837}
}

Parameters:
  • model – Networks used for training. The expected model type is torch.nn.Sequential(featurizer, classifier), where featurizer and classifier are of type torch.nn.Module.

  • optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.

  • lr – learning rate for CACM

  • weight_decay – Value of weight decay for optimizer

  • betas – Adam configuration parameters (beta1, beta2): exponential decay rates for the first- and second-moment estimates, respectively.

  • momentum – Value of momentum for SGD optimizer

  • kernel_type – Kernel type for MMD penalty. Currently supports “gaussian” (RBF). If None, the distance between means and second-order statistics (covariances) is used.

  • ci_test – Conditional independence metric used for regularization penalty. Currently, MMD is supported.

  • attr_types – List of attribute types (based on their relationship with label Y); should be ordered according to the attribute order in the loaded dataset. Currently, ‘causal’ (Causal), ‘conf’ (Confounded), ‘ind’ (Independent) and ‘sel’ (Selected) are supported. For single-shift datasets, use [‘causal’] or [‘ind’]; for multi-shift datasets, use [‘causal’, ‘ind’].

  • E_conditioned – Binary flag indicating whether E-conditioned regularization should be applied

  • E_eq_A – list indicating indices of attributes that coincide with environment (E) definition; default is empty.

  • gamma – Kernel bandwidth for MMD (due to the implementation, the kernel bandwidth is actually the reciprocal of gamma, i.e., gamma=1e-6 implies kernel bandwidth=1e6; see mmd_compute in utils.py)

  • lambda_causal – MMD penalty hyperparameter for Causal shift

  • lambda_conf – MMD penalty hyperparameter for Confounded shift

  • lambda_ind – MMD penalty hyperparameter for Independent shift

  • lambda_sel – MMD penalty hyperparameter for Selected shift

Returns:

An instance of the PredictionAlgorithm class.
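A minimal end-to-end sketch of training CACM follows; the toy networks and per-environment loaders are illustrative assumptions, and in practice loaders should be built for a real multi-attribute dataset:

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from dowhy.causal_prediction.algorithms.cacm import CACM

    # Placeholder featurizer/classifier; use architectures suited to your data.
    featurizer = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 128), torch.nn.ReLU())
    classifier = torch.nn.Linear(128, 10)
    model = torch.nn.Sequential(featurizer, classifier)

    # Toy per-environment loaders yielding (x, y, attribute) batches.
    def make_loader(n=64):
        x = torch.randn(n, 1, 28, 28)
        y = torch.randint(0, 10, (n,))
        a = torch.randint(0, 2, (n,))  # one binary 'causal' attribute
        return DataLoader(TensorDataset(x, y, a), batch_size=16)

    train_loaders = [make_loader(), make_loader()]  # two training environments
    val_loaders = [make_loader()]

    algorithm = CACM(
        model,
        optimizer="Adam",
        lr=1e-3,
        attr_types=["causal"],  # single-shift dataset, one causal attribute
        E_conditioned=True,
        lambda_causal=1.0,
        gamma=1e-6,  # effective kernel bandwidth = 1e6 (reciprocal of gamma)
    )

    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(algorithm, train_loaders, val_loaders)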

training_step(train_batch, batch_idx)[source]

Override training_step from PredictionAlgorithm class for CACM-specific training loop.

dowhy.causal_prediction.algorithms.erm module

class dowhy.causal_prediction.algorithms.erm.ERM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9)[source]

Bases: PredictionAlgorithm

Class for the Empirical Risk Minimization (ERM) baseline algorithm. This class implements the default methods of a PyTorch Lightning module (pl.LightningModule); these methods are called when fit() is invoked. To learn more about them, refer to https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html.

Parameters:
  • model – Neural network modules used for training

  • optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.

  • lr – Value of learning rate

  • weight_decay – Value of weight decay for optimizer

  • betas – Adam configuration parameters (beta1, beta2): exponential decay rates for the first- and second-moment estimates, respectively.

  • momentum – Value of momentum for SGD optimizer

training_step(train_batch, batch_idx)[source]

Override training_step from PredictionAlgorithm class for ERM-specific training loop.
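Usage mirrors the CACM sketch above, minus the regularization hyperparameters (model, train_loaders and val_loaders as defined there):

    from dowhy.causal_prediction.algorithms.erm import ERM

    # ERM baseline: standard supervised training with no causal penalty.
    erm = ERM(model, optimizer="Adam", lr=1e-3, weight_decay=0.0)
    # trainer.fit(erm, train_loaaders, val_loaders) trains it exactly as before.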

dowhy.causal_prediction.algorithms.regularization module

class dowhy.causal_prediction.algorithms.regularization.Regularizer(E_conditioned, ci_test, kernel_type, gamma)[source]

Bases: object

Implements methods for applying unconditional and conditional regularization.

Parameters:
  • E_conditioned – Binary flag indicating whether E-conditioned regularization should be applied

  • ci_test – Conditional independence metric used for regularization penalty. Currently, MMD is supported.

  • kernel_type – Kernel type for MMD penalty. Currently supports “gaussian” (RBF). If None, the distance between means and second-order statistics (covariances) is used.

  • gamma – Kernel bandwidth for MMD (due to the implementation, the kernel bandwidth is actually the reciprocal of gamma, i.e., gamma=1e-6 implies kernel bandwidth=1e6; see mmd_compute in utils.py)
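A usage sketch of the regularizer follows; the list-of-per-environment-tensors input format is an assumption for illustration (see how CACM invokes these methods for the authoritative usage):

    import torch

    from dowhy.causal_prediction.algorithms.regularization import Regularizer

    regularizer = Regularizer(E_conditioned=True, ci_test="mmd", kernel_type="gaussian", gamma=1e-6)

    # Assumed format: one tensor per environment (3 environments, 32 samples
    # each, 16-dim feature representations, one binary attribute).
    classifs = [torch.randn(32, 16) for _ in range(3)]
    attribute_labels = [torch.randint(0, 2, (32, 1)) for _ in range(3)]

    penalty = regularizer.unconditional_reg(classifs, attribute_labels, num_envs=3)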

conditional_reg(classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False)[source]

Implement conditional regularization φ(x) ⊥⊥ A_i | A_s

Parameters:
  • classifs – feature representations output from classifier layer (gφ(x))

  • attribute_labels – attribute labels loaded with the dataset for attribute A_i

  • conditioning_subset – list of subset of observed variables A_s (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset

  • num_envs – number of environments/domains

  • E_eq_A – Binary flag indicating whether the attribute (A_i) coincides with the environment (E) definition

Find group indices for conditional regularization based on the conditioning subset by taking all possible combinations of its values. E.g., for conditioning_subset = [A1, Y], where A1 is in {0, 1} and Y is in {0, 1, 2}, groups are assigned as follows:

A1 = 0, Y = 0 -> group 0
A1 = 1, Y = 0 -> group 1
A1 = 0, Y = 1 -> group 2
A1 = 1, Y = 1 -> group 3
A1 = 0, Y = 2 -> group 4
A1 = 1, Y = 2 -> group 5
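A sketch of this mixed-radix encoding (not the library's exact code; it assumes the per-variable cardinalities are known):

    import torch

    def group_indices(conditioning_values, cardinalities):
        """Map each row of conditioning-variable values to a unique group index.

        conditioning_values: LongTensor of shape (n_samples, n_vars),
            e.g. columns [A1, Y] from the table above.
        cardinalities: number of possible values per variable, e.g. [2, 3].
        """
        groups = torch.zeros(conditioning_values.shape[0], dtype=torch.long)
        multiplier = 1
        for j, cardinality in enumerate(cardinalities):
            groups += conditioning_values[:, j] * multiplier
            multiplier *= cardinality
        return groups

    # Reproduces the table above, e.g. (A1=1, Y=2) -> 1*1 + 2*2 = group 5.
    values = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1], [0, 2], [1, 2]])
    print(group_indices(values, [2, 3]))  # tensor([0, 1, 2, 3, 4, 5])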

The library's own code for computing group indices is adapted from WILDS: https://github.com/p-lambda/wilds

@inproceedings{wilds2021,
  title = {{WILDS}: A Benchmark of in-the-Wild Distribution Shifts},
  author = {Pang Wei Koh and Shiori Sagawa and Henrik Marklund and Sang Michael Xie and Marvin Zhang and Akshay Balsubramani and Weihua Hu and Michihiro Yasunaga and Richard Lanas Phillips and Irena Gao and Tony Lee and Etienne David and Ian Stavness and Wei Guo and Berton A. Earnshaw and Imran S. Haque and Sara Beery and Jure Leskovec and Anshul Kundaje and Emma Pierson and Sergey Levine and Chelsea Finn and Percy Liang},
  booktitle = {International Conference on Machine Learning (ICML)},
  year = {2021}
}

mmd(x, y)[source]

Compute MMD penalty between x and y.
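For reference, with kernel k the (biased) empirical MMD² between samples X = {x_1, …, x_m} and Y = {y_1, …, y_n} is

    \widehat{\mathrm{MMD}}^2(X, Y)
        = \frac{1}{m^2} \sum_{i, i'} k(x_i, x_{i'})
        + \frac{1}{n^2} \sum_{j, j'} k(y_j, y_{j'})
        - \frac{2}{mn} \sum_{i, j} k(x_i, y_j)

When kernel_type is None, the penalty instead compares the means and covariances of x and y directly, as noted for the kernel_type parameter above.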

unconditional_reg(classifs, attribute_labels, num_envs, E_eq_A=False)[source]

Implement unconditional regularization φ(x) ⊥⊥ A_i

Parameters:
  • classifs – feature representations output from classifier layer (gφ(x))

  • attribute_labels – attribute labels loaded with the dataset for attribute A_i

  • num_envs – number of environments/domains

  • E_eq_A – Binary flag indicating whether the attribute (A_i) coincides with the environment (E) definition

dowhy.causal_prediction.algorithms.utils module

The functions in this file are borrowed from DomainBed: https://github.com/facebookresearch/DomainBed

@inproceedings{gulrajani2021in,
  title={In Search of Lost Domain Generalization},
  author={Ishaan Gulrajani and David Lopez-Paz},
  booktitle={International Conference on Learning Representations},
  year={2021}
}

dowhy.causal_prediction.algorithms.utils.gaussian_kernel(x, y, gamma)[source]
dowhy.causal_prediction.algorithms.utils.mmd_compute(x, y, kernel_type, gamma)[source]
dowhy.causal_prediction.algorithms.utils.my_cdist(x1, x2)[source]
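These helpers compose into the MMD penalty used by the algorithms above. A usage sketch (the shapes are arbitrary illustrations, and passing kernel_type=None for the mean/covariance variant is an assumption based on the kernel_type documentation above):

    import torch

    from dowhy.causal_prediction.algorithms.utils import mmd_compute

    x = torch.randn(64, 16)  # 64 samples of 16-dim feature representations
    y = torch.randn(48, 16)

    # my_cdist computes pairwise squared Euclidean distances, gaussian_kernel
    # turns them into an RBF Gram matrix, and mmd_compute combines the within-
    # and cross-sample kernel means into the MMD penalty.
    penalty = mmd_compute(x, y, kernel_type="gaussian", gamma=1e-6)

    # Mean/covariance comparison instead of a kernel MMD.
    penalty_mean_cov = mmd_compute(x, y, kernel_type=None, gamma=1e-6)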

Module contents