dowhy.causal_prediction.models package

Submodules

dowhy.causal_prediction.models.networks module

The MNIST_MLP architecture is borrowed from OoD-Bench:
@inproceedings{ye2022ood,
  title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization},
  author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun},
  booktitle={CVPR},
  year={2022}
}

dowhy.causal_prediction.models.networks.Classifier(in_features, out_features, is_nonlinear=False)
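The signature suggests a factory that returns either a linear head or a small nonlinear head, depending on `is_nonlinear`. The sketch below illustrates that idea under stated assumptions. The function name `make_classifier` is hypothetical, and so are the hidden sizes (`in_features // 2`, `in_features // 4`); neither is taken from this page. It is not the dowhy source.

```python
import torch
from torch import nn


def make_classifier(in_features: int, out_features: int, is_nonlinear: bool = False) -> nn.Module:
    """Hypothetical stand-in for a Classifier-style factory (not the dowhy source)."""
    if is_nonlinear:
        # Assumed hidden sizes: halve, then quarter the input width, with ReLUs in between.
        return nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.ReLU(),
            nn.Linear(in_features // 2, in_features // 4),
            nn.ReLU(),
            nn.Linear(in_features // 4, out_features),
        )
    # Default: a single linear layer mapping features to logits.
    return nn.Linear(in_features, out_features)


# Usage: map a batch of 128-dimensional feature vectors to 10 class logits.
features = torch.randn(4, 128)
linear_head = make_classifier(128, 10)
mlp_head = make_classifier(128, 10, is_nonlinear=True)
print(linear_head(features).shape)  # torch.Size([4, 10])
print(mlp_head(features).shape)     # torch.Size([4, 10])
```

In this sketch, `in_features` would typically be set to a featurizer's output width, such as the `n_outputs` attribute exposed by the networks below.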
class dowhy.causal_prediction.models.networks.ContextNet(input_shape)

Bases: Module

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, you should call the Module instance itself rather than calling forward() directly. Calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
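The difference is easy to observe with any module. This sketch uses a plain `torch.nn.Linear` rather than a class from this package. It registers a forward hook and counts how often the hook fires.

```python
import torch
from torch import nn

layer = nn.Linear(3, 2)
hook_calls = []

# A forward hook that records every time it fires.
handle = layer.register_forward_hook(lambda module, inputs, output: hook_calls.append("fired"))

x = torch.randn(1, 3)
layer(x)          # Calling the instance goes through __call__, which runs the hook.
layer.forward(x)  # Calling forward() directly bypasses the hook machinery.

print(len(hook_calls))  # 1
handle.remove()
```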

training: bool
class dowhy.causal_prediction.models.networks.Identity

Bases: Module

An identity layer
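An identity layer returns its input unchanged. PyTorch ships the same behavior as `torch.nn.Identity`. A common use is to replace a model's final layer so the model emits features instead of logits. The sketch below shows that pattern. The class name `IdentitySketch` and the toy `nn.Sequential` model are illustrative only.

```python
import torch
from torch import nn


class IdentitySketch(nn.Module):
    """Illustrative identity layer: forward returns its input unchanged."""

    def forward(self, x):
        return x


# Swap the final layer of a toy model for an identity so it outputs hidden features.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 3),  # stand-in "head" that we want to remove
)
model[2] = IdentitySketch()

x = torch.randn(5, 8)
features = model(x)
print(features.shape)  # torch.Size([5, 16])
```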


forward(x)


training: bool
class dowhy.causal_prediction.models.networks.MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout)

Bases: Module

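A fully connected multilayer perceptron (MLP) that maps n_inputs features to n_outputs. Its hidden width is mlp_width, its depth is mlp_depth, and mlp_dropout sets the dropout probability.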
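The sketch below is a minimal MLP that takes the documented constructor arguments. It is not the dowhy source. It assumes that `mlp_depth` counts every linear layer, including the input and output layers, so the network has `mlp_depth - 2` hidden-to-hidden layers. It also assumes that each non-output layer applies linear, then dropout, then ReLU. The class name `MLPSketch` is hypothetical.

```python
import torch
from torch import nn
import torch.nn.functional as F


class MLPSketch(nn.Module):
    """Hypothetical MLP mirroring the documented constructor arguments."""

    def __init__(self, n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout):
        super().__init__()
        # Assumption: mlp_depth counts all linear layers, so mlp_depth - 2 are hidden-to-hidden.
        self.input = nn.Linear(n_inputs, mlp_width)
        self.dropout = nn.Dropout(mlp_dropout)
        self.hiddens = nn.ModuleList(
            [nn.Linear(mlp_width, mlp_width) for _ in range(max(mlp_depth - 2, 0))]
        )
        self.output = nn.Linear(mlp_width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        # Each non-output layer applies linear -> dropout -> ReLU.
        x = F.relu(self.dropout(self.input(x)))
        for hidden in self.hiddens:
            x = F.relu(self.dropout(hidden(x)))
        return self.output(x)


mlp = MLPSketch(n_inputs=20, n_outputs=4, mlp_width=64, mlp_depth=3, mlp_dropout=0.1)
print(mlp(torch.randn(8, 20)).shape)  # torch.Size([8, 4])
```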


forward(x)


training: bool
class dowhy.causal_prediction.models.networks.MNIST_CNN(input_shape)

Bases: Module

Hand-tuned architecture for MNIST. Weirdness I’ve noticed so far with this architecture:

- adding a linear layer after the mean-pool in features hurts RotatedMNIST-100 generalization severely.
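The class attribute n_outputs = 128 and the remark about the mean-pool point to a convolutional featurizer that ends in a spatial average over 128 feature maps, with no linear layer after it. The sketch below illustrates that shape flow. The layer count, kernel sizes, strides, and the two-channel 28×28 input are assumptions. Only the 128-dimensional output and the final mean-pool come from the text above. The class name `MeanPoolCNNSketch` is hypothetical.

```python
import torch
from torch import nn


class MeanPoolCNNSketch(nn.Module):
    """Hypothetical CNN featurizer: convolutions, then a spatial mean-pool to 128 features."""

    n_outputs = 128

    def __init__(self, input_shape):
        super().__init__()
        in_channels = input_shape[0]
        # Assumed layer layout; only the 128-channel output and the mean-pool follow the docs.
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, self.n_outputs, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # Global average pooling collapses each feature map to a single value.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.pool(self.convs(x))
        # No linear layer after the pool; just flatten to (batch, 128).
        return torch.flatten(x, start_dim=1)


cnn = MeanPoolCNNSketch(input_shape=(2, 28, 28))
print(cnn(torch.randn(4, 2, 28, 28)).shape)  # torch.Size([4, 128])
```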


forward(x)


n_outputs = 128
training: bool
class dowhy.causal_prediction.models.networks.MNIST_MLP(input_shape)

Bases: Module


forward(x)


training: bool
class dowhy.causal_prediction.models.networks.ResNet(input_shape, resnet18=True, resnet_dropout=0.0)

Bases: Module

ResNet with the softmax chopped off and the batchnorm frozen


forward(x)

Encode x into a feature vector of size n_outputs.

freeze_bn()
train(mode=True)

Override the default train() to freeze the BN parameters
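One common way to keep BatchNorm frozen is to override train(). The override first lets nn.Module toggle every submodule, then switches the BatchNorm layers back to eval mode. The sketch below shows that pattern on a tiny hypothetical backbone (`FrozenBNBackboneSketch`). It is not the dowhy ResNet, and it does not load any pretrained weights.

```python
import torch
from torch import nn


class FrozenBNBackboneSketch(nn.Module):
    """Hypothetical backbone that keeps its BatchNorm layers in eval mode during training."""

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )

    def freeze_bn(self):
        # eval() makes BatchNorm use its stored running statistics and stop updating them.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

    def train(self, mode=True):
        # Let nn.Module set the mode on every submodule, then force BatchNorm back to eval.
        super().train(mode)
        self.freeze_bn()
        return self

    def forward(self, x):
        return self.network(x)


model = FrozenBNBackboneSketch()
model.train()
bn = model.network[1]
stats_before = bn.running_mean.clone()
model(torch.randn(4, 3, 16, 16))
print(model.training, bn.training)                  # True False
print(torch.equal(stats_before, bn.running_mean))  # True
```

Putting BatchNorm in eval mode freezes only its running statistics. The affine weight and bias still receive gradients unless their requires_grad is also set to False.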

training: bool

Module contents