""" Module containing the main model class for the dowhy package.
"""
import logging
from sympy import init_printing
import dowhy.causal_estimators as causal_estimators
import dowhy.causal_refuters as causal_refuters
import dowhy.utils.cli_helpers as cli
from dowhy.causal_estimator import CausalEstimate
from dowhy.causal_graph import CausalGraph
from dowhy.causal_identifier import CausalIdentifier
from dowhy.utils.api import parse_state
# Import-time side effect: enable sympy's pretty-printing so that symbolic
# math is displayed with math symbols.
init_printing()
class CausalModel:
    """Main class for storing the causal model state.

    Holds the data, the treatment/outcome names and the causal graph, and
    exposes the identify / estimate / do / refute steps as methods.
    """

    def __init__(self, data, treatment, outcome, graph=None,
                 common_causes=None, instruments=None, estimand_type="ate",
                 proceed_when_unidentifiable=False,
                 **kwargs):
        """Initialize data and create a causal graph instance.

        Assigns treatment and outcome variables.
        Also checks and finds the common causes and instruments for treatment
        and outcome.

        At least one of graph, common_causes or instruments must be provided.

        :param data: a pandas dataframe containing treatment, outcome and other
            variables.
        :param treatment: name of the treatment variable
        :param outcome: name of the outcome variable
        :param graph: path to DOT file containing a DAG or a string containing
            a DAG specification in DOT format
        :param common_causes: names of common causes of treatment and outcome
        :param instruments: names of instrumental variables for the effect of
            treatment on outcome
        :param estimand_type: type of estimand, forwarded to CausalIdentifier
            by identify_effect (default "ate")
        :param proceed_when_unidentifiable: default value forwarded to
            CausalIdentifier when identify_effect is called without an
            explicit value
        :param kwargs: extra options. ``logging_level`` (if present) is passed
            to ``logging.basicConfig``; the whole dict is kept in
            ``self._other_variables``.
        :returns: an instance of CausalModel class
        """
        self._data = data
        self._treatment = parse_state(treatment)
        self._outcome = parse_state(outcome)
        self._estimand_type = estimand_type
        self._proceed_when_unidentifiable = proceed_when_unidentifiable

        # TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
        logging.basicConfig(level=kwargs.get('logging_level', logging.INFO))
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
            self._common_causes = parse_state(common_causes)
            self._instruments = parse_state(instruments)

            # Only forward the node lists the caller actually supplied, so
            # CausalGraph sees exactly the keyword arguments it would have
            # received from the explicit per-case calls.
            graph_inputs = {}
            if common_causes is not None:
                graph_inputs["common_cause_names"] = self._common_causes
            if instruments is not None:
                graph_inputs["instrument_names"] = self._instruments

            if graph_inputs:
                self._graph = CausalGraph(
                    self._treatment,
                    self._outcome,
                    observed_node_names=self._data.columns.tolist(),
                    **graph_inputs
                )
            else:
                # NOTE(review): the answer is discarded and self._graph is
                # never set in this branch, so later calls that need the
                # graph will fail with AttributeError -- confirm the intended
                # behavior before changing it.
                cli.query_yes_no(
                    "WARN: Are you sure that there are no common causes of treatment and outcome?",
                    default=None
                )
        else:
            self._graph = CausalGraph(
                self._treatment,
                self._outcome,
                graph,
                observed_node_names=self._data.columns.tolist()
            )
            # With an explicit graph, common causes and instruments are read
            # from the graph rather than from the constructor arguments.
            self._common_causes = self._graph.get_common_causes(self._treatment, self._outcome)
            self._instruments = self._graph.get_instruments(self._treatment,
                                                            self._outcome)

        self._other_variables = kwargs
        self.summary()

    @staticmethod
    def _split_method_name(method_name):
        """Split ``"<identifier>.<estimator>"`` into its two leading parts.

        :param method_name: dotted method name, e.g. ``"a.b"``
        :returns: tuple ``(identifier_name, estimator_name)``
        :raises ValueError: if method_name is None or contains no ``.``
        """
        if method_name is None:
            raise ValueError(
                "method_name must be provided: automatic selection of an "
                "estimation method is not implemented."
            )
        parts = method_name.split(".")
        if len(parts) < 2:
            raise ValueError(
                "method_name must have the form "
                "'<identifier>.<estimator>', got {0!r}".format(method_name)
            )
        # Any components after the second one are ignored.
        return parts[0], parts[1]

    @staticmethod
    def _select_estimator(identified_estimand, method_name):
        """Resolve method_name to an identifier name and an estimator class.

        Also records the chosen identifier on identified_estimand.

        :returns: tuple ``(identifier_name, causal_estimator_class)``
        :raises ValueError: if method_name is None or malformed
        """
        identifier_name, estimator_name = CausalModel._split_method_name(method_name)
        identified_estimand.set_identifier_method(identifier_name)
        causal_estimator_class = causal_estimators.get_class_object(estimator_name + "_estimator")
        return identifier_name, causal_estimator_class

    def identify_effect(self, proceed_when_unidentifiable=None):
        """Identify the causal effect to be estimated, using properties of the causal graph.

        :param proceed_when_unidentifiable: forwarded to CausalIdentifier.
            When None, the value given to the constructor is used.
        :returns: the result of CausalIdentifier.identify_effect() (a
            probability expression for the causal effect if identified)
        """
        if proceed_when_unidentifiable is None:
            # Fall back to the model-level default chosen at construction.
            proceed_when_unidentifiable = self._proceed_when_unidentifiable

        self.identifier = CausalIdentifier(self._graph,
                                           self._estimand_type,
                                           proceed_when_unidentifiable=proceed_when_unidentifiable)
        return self.identifier.identify_effect()

    def estimate_effect(self, identified_estimand, method_name=None,
                        test_significance=None, method_params=None):
        """Estimate the identified causal effect.

        :param identified_estimand: a probability expression
            that represents the effect to be estimated. Output of
            CausalModel.identify_effect method
        :param method_name: name of the estimation method to be used, in the
            form ``"<identifier>.<estimator>"``. Required: automatic method
            selection is not implemented.
        :param test_significance: forwarded to the estimator class
        :param method_params: forwarded to the estimator class as ``params``
        :returns: an instance of the CausalEstimate class, containing the causal effect estimate
            and other method-dependent information
        :raises ValueError: if method_name is None or malformed
        """
        identifier_name, causal_estimator_class = self._select_estimator(
            identified_estimand, method_name)

        # Check if estimator's target estimand is identified
        if identified_estimand.estimands[identifier_name] is None:
            # The message text is kept as-is; it is emitted for any
            # identifier, not only for instrumental variables.
            self.logger.warning("No valid identified estimand for using instrumental variables method")
            return CausalEstimate(None, None, None)

        causal_estimator = causal_estimator_class(
            self._data,
            identified_estimand,
            self._treatment, self._outcome,
            test_significance=test_significance,
            params=method_params
        )
        estimate = causal_estimator.estimate_effect()
        estimate.add_params(
            estimand_type=identified_estimand.estimand_type,
            estimator_class=causal_estimator_class
        )
        return estimate

    def do(self, x, identified_estimand, method_name=None, method_params=None):
        """Apply the estimator's ``do`` operation for the value ``x``.

        :param x: value passed to the estimator's ``do`` method
        :param identified_estimand: a probability expression
            that represents the effect to be estimated. Output of
            CausalModel.identify_effect method
        :param method_name: name of the estimation method to be used, in the
            form ``"<identifier>.<estimator>"``. Required: automatic method
            selection is not implemented.
        :param method_params: forwarded to the estimator class as ``params``
        :returns: the value returned by the estimator's ``do`` method, or
            ``CausalEstimate(None, None, None)`` when the selected estimand
            is not identified
        :raises ValueError: if method_name is None or malformed
        :raises NotImplementedError: if the estimator does not support ``do``
        """
        identifier_name, causal_estimator_class = self._select_estimator(
            identified_estimand, method_name)

        # Check if estimator's target estimand is identified
        if identified_estimand.estimands[identifier_name] is None:
            # The message text is kept as-is; it is emitted for any
            # identifier, not only for instrumental variables.
            self.logger.warning("No valid identified estimand for using instrumental variables method")
            return CausalEstimate(None, None, None)

        causal_estimator = causal_estimator_class(
            self._data,
            identified_estimand,
            self._treatment, self._outcome,
            test_significance=False,
            params=method_params
        )
        try:
            return causal_estimator.do(x)
        except NotImplementedError:
            self.logger.error('Do Operation not implemented or not supported for this estimator.')
            # Bare raise keeps the estimator's own message and traceback.
            raise

    def refute_estimate(self, estimand, estimate, method_name=None, **kwargs):
        """Refute an estimated causal effect.

        :param estimand: the identified estimand, passed to the refuter as
            ``identified_estimand``
        :param estimate: an instance of the CausalEstimate class.
        :param method_name: name of the refutation method to be used.
            Required: automatic method selection is not implemented.
        :param kwargs: extra keyword arguments forwarded to the refuter class
        :returns: an instance of the RefuteResult class
        :raises ValueError: if method_name is None
        """
        if method_name is None:
            raise ValueError(
                "method_name must be provided: automatic selection of a "
                "refutation method is not implemented."
            )

        refuter_class = causal_refuters.get_class_object(method_name)
        refuter = refuter_class(
            self._data,
            identified_estimand=estimand,
            estimate=estimate,
            **kwargs
        )
        return refuter.refute_estimate()

    def view_model(self, layout="dot"):
        """View the causal DAG.

        Delegates to the graph's ``view_graph`` method.

        :param layout: layout name passed to ``view_graph`` (default "dot")
        :returns: None
        """
        self._graph.view_graph(layout)

    def summary(self):
        """Log a text summary of the model at INFO level.

        :returns: None
        """
        self.logger.info(
            "Model to find the causal effect of treatment %s on outcome %s",
            self._treatment, self._outcome)