Source code for dowhy.causal_graph

import logging
import re
import networkx as nx
from dowhy.utils.api import parse_state
import itertools


class CausalGraph:
    """Directed causal graph over treatment, outcome and related variables.

    The graph is held in ``self._graph`` as a networkx directed graph. Every
    node carries an ``"observed"`` attribute whose value is the string
    ``"yes"`` or ``"no"``.
    """

    def __init__(self, treatment_name, outcome_name, graph=None,
                 common_cause_names=None, instrument_names=None,
                 observed_node_names=None):
        """Build the graph from variable names, or load it from ``graph``.

        :param treatment_name: treatment variable name(s); passed through
            ``parse_state``.
        :param outcome_name: outcome variable name(s); passed through
            ``parse_state``.
        :param graph: ``None`` to build the graph from the name arguments,
            otherwise a string that is one of: a path containing ``.dot``,
            a path containing ``.gml``, an inline DOT string
            (``...graph {...}``) or an inline GML string (``...graph [...]``).
            The checks are applied in that order.
        :param common_cause_names: names of common causes of treatment and
            outcome (only used when ``graph`` is ``None``).
        :param instrument_names: instrument names, or ``(instrument,
            treatment)`` tuples (only used when ``graph`` is ``None``).
        :param observed_node_names: names of observed nodes; see
            :meth:`add_node_attributes` for how ``None`` is handled.
        :raises ValueError: if ``graph`` matches none of the known formats.
        """
        # NOTE(review): parse_state is defined in dowhy.utils.api and is not
        # visible here; the code below relies on it returning lists.
        self.treatment_name = parse_state(treatment_name)
        self.outcome_name = parse_state(outcome_name)
        instrument_names = parse_state(instrument_names)
        common_cause_names = parse_state(common_cause_names)
        # str() keeps the join working for non-string entries such as the
        # (instrument, treatment) tuples accepted by build_graph; for plain
        # string names the result is unchanged.
        self.fullname = "_".join(
            str(name)
            for name in (self.treatment_name + self.outcome_name
                         + common_cause_names + instrument_names)
        )
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self._graph = nx.DiGraph()
            self._graph = self.build_graph(common_cause_names, instrument_names)
        elif re.match(r".*\.dot", graph):
            # Load a DOT file, preferring pygraphviz and falling back to pydot.
            try:
                import pygraphviz  # noqa: F401 -- availability probe
                self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
            except Exception as e:
                self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
                try:
                    import pydot  # noqa: F401 -- availability probe
                    self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
                except Exception as e:
                    self.logger.error("Error: Pydot cannot be loaded. " + str(e))
                    raise
        elif re.match(r".*\.gml", graph):
            self._graph = nx.DiGraph(nx.read_gml(graph))
        # re.DOTALL lets the inline-string patterns span line breaks, so
        # multi-line DOT/GML strings are recognised as well.
        elif re.match(r".*graph\s*\{.*\}\s*", graph, flags=re.DOTALL):
            # Inline DOT string, preferring pygraphviz and falling back to pydot.
            try:
                import pygraphviz as pgv
                self._graph = pgv.AGraph(graph, strict=True, directed=True)
                self._graph = nx.drawing.nx_agraph.from_agraph(self._graph)
            except Exception as e:
                self.logger.error("Error: Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot ...")
                try:
                    import pydot
                    P_list = pydot.graph_from_dot_data(graph)
                    self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
                except Exception as e:
                    self.logger.error("Error: Pydot cannot be loaded. " + str(e))
                    raise
        elif re.match(r".*graph\s*\[.*\]\s*", graph, flags=re.DOTALL):
            # Inline GML string.
            self._graph = nx.DiGraph(nx.parse_gml(graph))
        else:
            self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.")
            self.logger.error("Error: Incorrect graph format")
            raise ValueError("Incorrect graph format: expected a dot or gml file or string.")

        self._graph = self.add_node_attributes(observed_node_names)
        self._graph = self.add_unobserved_common_cause(observed_node_names)

    def view_graph(self, layout="dot"):
        """Render the graph to ``causal_model.png`` in the working directory.

        pygraphviz is tried first, using ``layout`` as the layout program. If
        that fails for any reason, the graph is drawn with matplotlib using a
        shell layout instead.

        :param layout: layout program name handed to pygraphviz ``draw``.
        """
        out_filename = "causal_model.png"
        try:
            import pygraphviz  # noqa: F401 -- availability probe
            agraph = nx.drawing.nx_agraph.to_agraph(self._graph)
            agraph.draw(out_filename, format="png", prog=layout)
        except Exception:
            # Exception rather than a bare except, so that KeyboardInterrupt
            # and SystemExit are not swallowed by the fallback.
            self.logger.warning("Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.")
            self.logger.info("Using Matplotlib for plotting")
            import matplotlib.pyplot as plt
            plt.clf()
            nx.draw_networkx(self._graph, pos=nx.shell_layout(self._graph))
            plt.axis('off')
            plt.savefig(out_filename)
            plt.draw()

    def build_graph(self, common_cause_names, instrument_names):
        """Populate ``self._graph`` from variable names and return it.

        Adds every treatment and outcome as an observed node with an edge from
        each treatment to each outcome, then an edge from each common cause to
        every treatment and outcome, then the instrument edges.

        :param common_cause_names: iterable of common-cause names, or ``None``.
        :param instrument_names: list of instrument names (each is connected
            to every treatment), or list of ``(instrument, treatment)`` tuples
            (each instrument is connected to the named treatment only). The
            style is decided by the type of the first element.
        :returns: ``self._graph``.
        """
        for treatment in self.treatment_name:
            self._graph.add_node(treatment, observed="yes")
        for outcome in self.outcome_name:
            self._graph.add_node(outcome, observed="yes")
        for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
            self._graph.add_edge(treatment, outcome)

        # Adding common causes
        if common_cause_names is not None:
            for node_name in common_cause_names:
                for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
                    self._graph.add_node(node_name, observed="yes")
                    self._graph.add_edge(node_name, treatment)
                    self._graph.add_edge(node_name, outcome)

        # Adding instruments
        if instrument_names:
            if not isinstance(instrument_names[0], tuple):
                if len(self.treatment_name) > 1:
                    self.logger.info("Assuming Instrument points to all treatments! Use tuples for more granularity.")
                for instrument, treatment in itertools.product(instrument_names, self.treatment_name):
                    self._graph.add_node(instrument, observed="yes")
                    self._graph.add_edge(instrument, treatment)
            else:
                # Each entry is already an (instrument, treatment) pair.
                # Wrapping the list in itertools.product() would yield
                # 1-tuples and break the two-name unpacking.
                for instrument, treatment in instrument_names:
                    self._graph.add_node(instrument, observed="yes")
                    self._graph.add_edge(instrument, treatment)
        return self._graph

    def add_node_attributes(self, observed_node_names):
        """Set the ``"observed"`` attribute of every node and return the graph.

        :param observed_node_names: container of observed node names. Nodes in
            it are marked ``"yes"``, all other nodes ``"no"``. When ``None``
            (no observation information supplied), nodes keep an existing
            ``"observed"`` attribute and nodes without one are marked
            ``"yes"``.
        :returns: ``self._graph``.
        """
        if observed_node_names is None:
            # ``in None`` would raise TypeError; None is the constructor
            # default, so it must be handled here.
            for node_name in self._graph:
                self._graph.nodes[node_name].setdefault("observed", "yes")
            return self._graph

        for node_name in self._graph:
            if node_name in observed_node_names:
                self._graph.nodes[node_name]["observed"] = "yes"
            else:
                self._graph.nodes[node_name]["observed"] = "no"
        return self._graph

    def add_unobserved_common_cause(self, observed_node_names):
        """Add node ``'U'`` unless an unobserved common cause already exists.

        If none of the current common causes of treatment and outcome is
        marked unobserved, a node ``'U'`` (label "Unobserved Confounders",
        ``observed="no"``) is added with an edge to every treatment and
        outcome.

        :param observed_node_names: accepted but not used by this method.
        :returns: ``self._graph``.
        """
        current_common_causes = self.get_common_causes(self.treatment_name, self.outcome_name)
        has_unobserved_common_cause = any(
            self._graph.nodes[node_name]["observed"] == "no"
            for node_name in current_common_causes
        )
        if not has_unobserved_common_cause:
            uc_label = "Unobserved Confounders"
            self._graph.add_node('U', label=uc_label, observed="no")
            for node in self.treatment_name + self.outcome_name:
                self._graph.add_edge('U', node)
        return self._graph

    def get_unconfounded_observed_subgraph(self):
        """Return the subgraph induced by the nodes marked ``observed == "yes"``."""
        observed_nodes = [
            node for node in self._graph.nodes()
            if self._graph.nodes[node]["observed"] == "yes"
        ]
        return self._graph.subgraph(observed_nodes)

    def do_surgery(self, node_names, remove_outgoing_edges=False, remove_incoming_edges=False):
        """Return a copy of the graph with edges around ``node_names`` removed.

        ``self._graph`` itself is not modified.

        :param node_names: iterable of nodes to operate on.
        :param remove_outgoing_edges: drop every edge leaving each node.
        :param remove_incoming_edges: drop every edge entering each node.
        :returns: the modified copy.
        """
        new_graph = self._graph.copy()
        for node_name in node_names:
            if remove_outgoing_edges:
                # Materialise the edge list before removing anything so the
                # graph is not mutated while its adjacency is being iterated.
                outgoing = [(node_name, child) for child in new_graph.successors(node_name)]
                new_graph.remove_edges_from(outgoing)
            if remove_incoming_edges:
                incoming = [(parent, node_name) for parent in new_graph.predecessors(node_name)]
                new_graph.remove_edges_from(incoming)
        return new_graph

    def get_causes(self, nodes, remove_edges=None):
        """Return the set of ancestors of ``nodes``.

        :param nodes: node name(s); passed through ``parse_state``.
        :param remove_edges: optional dict with keys ``"sources"`` and
            ``"targets"``. Every source -> target edge is removed from a copy
            of the graph before ancestors are computed. Pairs that have no
            direct edge are ignored.
        :returns: set of ancestor node names.
        """
        nodes = parse_state(nodes)
        new_graph = None
        if remove_edges is not None:
            new_graph = self._graph.copy()  # caution: shallow copy of the attributes
            sources = parse_state(remove_edges["sources"])
            targets = parse_state(remove_edges["targets"])
            # remove_edges_from ignores absent edges, whereas remove_edge
            # raises when a source/target pair has no direct edge.
            new_graph.remove_edges_from(itertools.product(sources, targets))
        causes = set()
        for v in nodes:
            causes |= self.get_ancestors(v, new_graph=new_graph)
        return causes

    def get_common_causes(self, nodes1, nodes2):
        """Return a list of nodes that are ancestors of both node groups.

        A node qualifies if it is an ancestor of at least one node in
        ``nodes1`` and of at least one node in ``nodes2``.
        """
        nodes1 = parse_state(nodes1)
        nodes2 = parse_state(nodes2)
        causes_1 = set()
        causes_2 = set()
        for node in nodes1:
            causes_1 |= self.get_ancestors(node)
        for node in nodes2:
            causes_2 |= self.get_ancestors(node)
        return list(causes_1 & causes_2)

    def get_parents(self, node_name):
        """Return the set of direct predecessors of ``node_name``."""
        return set(self._graph.predecessors(node_name))

    def get_ancestors(self, node_name, new_graph=None):
        """Return the set of ancestors of ``node_name``.

        :param new_graph: graph to search instead of ``self._graph``.
        """
        graph = self._graph if new_graph is None else new_graph
        return set(nx.ancestors(graph, node_name))

    def get_descendants(self, node_name):
        """Return the set of descendants of ``node_name`` in ``self._graph``."""
        return set(nx.descendants(self._graph, node_name))

    def all_observed(self, node_names):
        """Return ``True`` if every node in ``node_names`` is marked observed."""
        return all(
            self._graph.nodes[node_name]["observed"] == "yes"
            for node_name in node_names
        )

    def filter_unobserved_variables(self, node_names):
        """Return the nodes of ``node_names`` marked observed, in input order."""
        return [
            node_name for node_name in node_names
            if self._graph.nodes[node_name]["observed"] == "yes"
        ]

    def get_instruments(self, treatment_nodes, outcome_nodes):
        """Return a list of candidate instrumental variables.

        Working on a copy of the graph with all edges into the treatments
        removed, a parent of a treatment is returned when it is neither an
        ancestor of an outcome (exclusion) nor a descendant of any ancestor
        of an outcome (as-if-random).

        :param treatment_nodes: treatment name(s); passed through ``parse_state``.
        :param outcome_nodes: outcome name(s); passed through ``parse_state``.
        :returns: list of instrument node names.
        """
        treatment_nodes = parse_state(treatment_nodes)
        outcome_nodes = parse_state(outcome_nodes)
        parents_treatment = set()
        for node in treatment_nodes:
            parents_treatment |= self.get_parents(node)
        g_no_parents_treatment = self.do_surgery(treatment_nodes, remove_incoming_edges=True)
        ancestors_outcome = set()
        for node in outcome_nodes:
            ancestors_outcome |= nx.ancestors(g_no_parents_treatment, node)

        # [TODO: double check these work with multivariate implementation:]
        # Exclusion
        candidate_instruments = parents_treatment.difference(ancestors_outcome)
        self.logger.debug("Candidate instruments after exclusion %s", candidate_instruments)
        # As-if-random setup
        children_causes_outcome = {
            descendant
            for ancestor in ancestors_outcome
            for descendant in nx.descendants(g_no_parents_treatment, ancestor)
        }
        # As-if-random
        instruments = candidate_instruments.difference(children_causes_outcome)
        return list(instruments)