Source code for dowhy.gcm.causal_models

"""This module defines the fundamental classes for graphical causal models (GCMs).

Classes in this module should be considered experimental, meaning there might be breaking API changes in the future.
"""

from typing import Any, Callable, Optional, Union

import networkx as nx

from dowhy.gcm.causal_mechanisms import (
from dowhy.graph import (

# Key under which a node's causal mechanism is stored in (and read from) the node attributes of the causal graph,
# i.e. graph.nodes[node][CAUSAL_MECHANISM].
CAUSAL_MECHANISM = "causal_mechanism"
# Key under which the parents of a node are stored during fitting. The stored value is compared afterwards with the
# node's current ordered predecessors to validate that the mechanism still matches the graph structure.
PARENTS_DURING_FIT = "parents_during_fit"

class ProbabilisticCausalModel:
    """Represents a probabilistic graphical causal model, i.e. it combines a graphical representation of causal
    relationships and a corresponding causal mechanism for each node describing the data generation process. The
    causal mechanisms can be any general stochastic models."""

    def __init__(
        self,
        graph: Optional[DirectedGraph] = None,
        graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph,
    ):
        """
        :param graph: Optional graph object to be used as causal graph. If omitted, a new empty networkx.DiGraph
                      is created.
        :param graph_copier: Optional function that can copy a causal graph; it is used by :meth:`clone`.
                             Defaults to a networkx.DiGraph constructor.
        """
        if graph is None:
            graph = nx.DiGraph()
        self.graph = graph
        self.graph_copier = graph_copier

    def set_causal_mechanism(self, node: Any, mechanism: Union[StochasticModel, ConditionalStochasticModel]) -> None:
        """Assigns the generative causal model of node in the causal graph.

        :param node: Target node whose causal model is to be assigned.
        :param mechanism: Causal mechanism to be assigned. A root node must be a
                          :class:`~dowhy.gcm.graph.StochasticModel`, whereas a non-root node must be a
                          :class:`~dowhy.gcm.graph.ConditionalStochasticModel`.
        :raises ValueError: If the node is not part of the graph.
        """
        if node not in self.graph.nodes:
            raise ValueError("Node %s can not be found in the given graph!" % node)
        # The mechanism is stored as a node attribute of the graph itself.
        self.graph.nodes[node][CAUSAL_MECHANISM] = mechanism

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, ConditionalStochasticModel]:
        """Returns the generative causal model of node in the causal graph.

        This is a plain attribute lookup on the graph; no validation of the node or of the assignment is done here.

        :param node: Target node whose causal model is to be returned.
        :returns: The causal mechanism for this node. A root node is of type
                  :class:`~dowhy.gcm.graph.StochasticModel`, whereas a non-root node is of type
                  :class:`~dowhy.gcm.graph.ConditionalStochasticModel`.
        """
        return self.graph.nodes[node][CAUSAL_MECHANISM]

    def clone(self):
        """Clones the causal model, but keeps causal mechanisms untrained.

        The graph is copied with ``graph_copier`` and every assigned mechanism in the copy is replaced by the result
        of its own ``clone()`` (see ``clone_causal_models``).

        :returns: A new instance of the same class operating on the copied graph and using the same graph copier.
        """
        graph_copy = self.graph_copier(self.graph)
        clone_causal_models(self.graph, graph_copy)
        cloned_model = self.__class__(graph_copy)
        # Carry the copier over explicitly; otherwise the clone silently falls back to the default
        # networkx.DiGraph copier and a subsequent clone() would no longer preserve a custom graph type.
        cloned_model.graph_copier = self.graph_copier
        return cloned_model
class StructuralCausalModel(ProbabilisticCausalModel):
    """Represents a structural causal model (SCM), as required e.g. by
    :func:`~dowhy.gcm.whatif.counterfactual_samples`. As compared to a
    :class:`~dowhy.gcm.cms.ProbabilisticCausalModel`, an SCM describes the data generation process in non-root nodes
    by functional causal models.
    """

    def set_causal_mechanism(self, node: Any, mechanism: Union[StochasticModel, FunctionalCausalModel]) -> None:
        """Assigns the causal mechanism of the given node.

        Same behavior as the parent implementation; only the type hint of ``mechanism`` is narrowed.

        :param node: Target node whose causal mechanism is to be assigned.
        :param mechanism: Causal mechanism to be assigned.
        """
        super().set_causal_mechanism(node, mechanism)

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, FunctionalCausalModel]:
        """Returns the causal mechanism of the given node.

        Same behavior as the parent implementation; only the return type hint is narrowed.

        :param node: Target node whose causal mechanism is to be returned.
        :returns: The causal mechanism stored for this node.
        """
        mechanism = super().causal_mechanism(node)
        return mechanism
class InvertibleStructuralCausalModel(StructuralCausalModel):
    """Represents an invertible structural graphical causal model, as required e.g. by
    :func:`~dowhy.gcm.whatif.counterfactual_samples`. This is a subclass of
    :class:`~dowhy.gcm.cms.StructuralCausalModel` and has further restrictions on the class of causal mechanisms.
    Here, the mechanisms of non-root nodes need to be invertible with respect to the noise, such as
    :class:`~dowhy.gcm.causal_mechanisms.PostNonlinearModel`.
    """

    def set_causal_mechanism(
        self,
        target_node: Any,
        mechanism: Union[StochasticModel, InvertibleFunctionalCausalModel],
    ) -> None:
        """Assigns the causal mechanism of the given node.

        Same behavior as the parent implementation; only the type hint of ``mechanism`` is narrowed.

        :param target_node: Target node whose causal mechanism is to be assigned.
        :param mechanism: Causal mechanism to be assigned.
        """
        # NOTE: the parameter is named ``target_node`` here while the parent classes call it ``node``; the name is
        # kept because it is part of the keyword interface.
        super().set_causal_mechanism(target_node, mechanism)

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, InvertibleFunctionalCausalModel]:
        """Returns the causal mechanism of the given node.

        Same behavior as the parent implementation; only the return type hint is narrowed.

        :param node: Target node whose causal mechanism is to be returned.
        :returns: The causal mechanism stored for this node.
        """
        mechanism = super().causal_mechanism(node)
        return mechanism
def validate_causal_dag(causal_graph: DirectedGraph) -> None:
    """Validates the given causal graph as a DAG with properly assigned causal mechanisms.

    Runs ``validate_acyclic`` first and ``validate_causal_graph`` second; any error raised by either check
    propagates to the caller.

    :param causal_graph: The causal graph to validate.
    """
    # Structural check first, then the per-node checks.
    validate_acyclic(causal_graph)
    validate_causal_graph(causal_graph)
def validate_causal_graph(causal_graph: DirectedGraph) -> None:
    """Validates every node of the given causal graph via ``validate_node``.

    Stops at the first node whose validation raises.

    :param causal_graph: The causal graph whose nodes are validated.
    """
    for current_node in causal_graph.nodes:
        validate_node(causal_graph, current_node)
def validate_node(causal_graph: DirectedGraph, node: Any) -> None:
    """Validates a single node of the causal graph.

    Checks the causal mechanism assignment first (``validate_causal_model_assignment``) and then whether the
    mechanism matches the local graph structure (``validate_local_structure``).

    :param causal_graph: The causal graph containing the node.
    :param node: The node to validate.
    """
    # Assignment check before the structure check.
    validate_causal_model_assignment(causal_graph, node)
    validate_local_structure(causal_graph, node)
def validate_causal_model_assignment(causal_graph: DirectedGraph, target_node: Any) -> None:
    """Validates that the mechanism assigned to a node has the type required by its position in the graph.

    A root node requires a StochasticModel, every other node requires a ConditionalStochasticModel.

    :param causal_graph: The causal graph containing the node.
    :param target_node: The node whose assigned causal mechanism is checked.
    :raises RuntimeError: If the assigned mechanism has the wrong type for the node.
    """
    validate_node_has_causal_model(causal_graph, target_node)

    mechanism = causal_graph.nodes[target_node][CAUSAL_MECHANISM]

    if is_root_node(causal_graph, target_node):
        if isinstance(mechanism, StochasticModel):
            return
        raise RuntimeError(
            "Node %s is a root node and, thus, requires a StochasticModel, "
            "but a %s was found!" % (target_node, mechanism)
        )

    if isinstance(mechanism, ConditionalStochasticModel):
        return
    raise RuntimeError(
        "Node %s has parents and, thus, requires a ConditionalStochasticModel, "
        "but a %s was found!" % (target_node, mechanism)
    )
def validate_local_structure(causal_graph: DirectedGraph, node: Any) -> None:
    """Validates that the causal mechanism of a node was fitted with respect to the node's current parents.

    The parents persisted under ``PARENTS_DURING_FIT`` in the node attributes are compared with the node's current
    ordered predecessors in the graph.

    :param causal_graph: The causal graph containing the node.
    :param node: The node whose local structure is checked.
    :raises RuntimeError: If no parents were persisted for the node or if they differ from the current ordered
                          predecessors.
    """
    node_attributes = causal_graph.nodes[node]
    if PARENTS_DURING_FIT not in node_attributes or node_attributes[PARENTS_DURING_FIT] != get_ordered_predecessors(
        causal_graph, node
    ):
        # Adjacent string literals are concatenated without a separator, so each fragment carries its own
        # trailing space.
        raise RuntimeError(
            "The causal mechanism of node %s is not fitted to the graphical structure! Fit all "
            "causal models in the graph first. If the mechanism is already fitted based on the causal "
            "parents, consider to update the persisted parents for that node manually." % node
        )
def validate_node_has_causal_model(causal_graph: HasNodes, node: Any) -> None:
    """Validates that the node is part of the graph and has a causal mechanism assigned.

    :param causal_graph: The graph containing the node.
    :param node: The node to check.
    :raises ValueError: If no causal mechanism is stored in the node's attributes.
    """
    validate_node_in_graph(causal_graph, node)

    node_attributes = causal_graph.nodes[node]
    if CAUSAL_MECHANISM in node_attributes:
        return
    raise ValueError("Node %s has no assigned causal mechanism!" % node)
def clone_causal_models(source: HasNodes, destination: HasNodes):
    """Copies the causal mechanisms of ``source`` into ``destination`` as clones.

    For every node of ``destination`` that has a causal mechanism in ``source``, the result of that mechanism's
    ``clone()`` is stored in the node attributes of ``destination``. Nodes without a mechanism in ``source`` are
    left untouched.

    :param source: Graph providing the causal mechanisms.
    :param destination: Graph whose nodes receive the cloned mechanisms.
    """
    for current_node in destination.nodes:
        source_attributes = source.nodes[current_node]
        if CAUSAL_MECHANISM not in source_attributes:
            continue
        destination.nodes[current_node][CAUSAL_MECHANISM] = source_attributes[CAUSAL_MECHANISM].clone()