import logging
import re
import networkx as nx
from dowhy.utils.api import parse_state
import itertools
class CausalGraph:
    """Class for creating and modifying the causal graph.

    Accepts a graph string (or a text file) in gml format (preferred) and dot
    format. Graphviz-like attributes can be set for edges and nodes. E.g.
    style="dashed" as an edge attribute ensures that the edge is drawn with a
    dashed line.

    If a graph string is not given, names of treatment, outcome, and
    confounders, instruments and effect modifiers (if any) can be provided to
    create the graph.
    """

    def __init__(self,
                 treatment_name, outcome_name,
                 graph=None,
                 common_cause_names=None,
                 instrument_names=None,
                 effect_modifier_names=None,
                 observed_node_names=None,
                 missing_nodes_as_confounders=False):
        """Build or load the causal graph and annotate its nodes.

        :param treatment_name: treatment variable name(s); passed through
            ``parse_state``.
        :param outcome_name: outcome variable name(s); passed through
            ``parse_state``.
        :param graph: ``None`` to build the graph from the name arguments, a
            path containing ``.dot`` or ``.gml``, or an inline dot
            (``graph {...}``) or gml (``graph [...]``) string.
        :param common_cause_names: names of common causes (used only when
            ``graph`` is None).
        :param instrument_names: names of instruments, or
            ``(instrument, treatment)`` tuples (used only when ``graph`` is
            None).
        :param effect_modifier_names: names of effect modifiers (used only
            when ``graph`` is None).
        :param observed_node_names: names of nodes that are observed; every
            other node is marked unobserved. ``None`` is treated as empty.
        :param missing_nodes_as_confounders: if True, observed names that are
            not in the graph are added as causes of every treatment and
            outcome.
        :raises ValueError: if ``graph`` matches none of the supported formats.
        """
        self.treatment_name = parse_state(treatment_name)
        self.outcome_name = parse_state(outcome_name)
        instrument_names = parse_state(instrument_names)
        common_cause_names = parse_state(common_cause_names)
        self.logger = logging.getLogger(__name__)

        if graph is None:
            self._graph = nx.DiGraph()
            self._graph = self.build_graph(common_cause_names,
                                           instrument_names,
                                           effect_modifier_names)
        elif re.match(r".*\.dot", graph):
            self._graph = self._load_dot_file(graph)
        elif re.match(r".*\.gml", graph):
            self._graph = nx.DiGraph(nx.read_gml(graph))
        elif re.match(r".*graph\s*\{.*\}\s*", graph):
            self._graph = self._parse_dot_string(graph)
        elif re.match(r".*graph\s*\[.*\]\s*", graph):
            self._graph = nx.DiGraph(nx.parse_gml(graph))
        else:
            self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.")
            self.logger.error("Error: Incorrect graph format")
            raise ValueError(
                "Incorrect graph format: please provide the graph "
                "(as string or text file) in dot or gml format.")

        if missing_nodes_as_confounders:
            self._graph = self.add_missing_nodes_as_common_causes(observed_node_names)
        # Adding node attributes
        self._graph = self.add_node_attributes(observed_node_names)
        # TODO do not add it here. CausalIdentifier should call causal_graph
        # to add an unobserved common cause if needed. This also ensures that
        # we do not need get_common_causes in this class.
        self._graph = self.add_unobserved_common_cause(observed_node_names)

    def _load_dot_file(self, path):
        """Read a dot file into a DiGraph, trying pygraphviz first, then pydot."""
        try:
            import pygraphviz as pgv  # noqa: F401 -- fail fast if unavailable
            return nx.DiGraph(nx.drawing.nx_agraph.read_dot(path))
        except Exception as e:
            self.logger.error("Pygraphviz cannot be loaded. %s\nTrying pydot...", e)
            try:
                import pydot  # noqa: F401 -- fail fast if unavailable
                return nx.DiGraph(nx.drawing.nx_pydot.read_dot(path))
            except Exception as e:
                self.logger.error("Error: Pydot cannot be loaded. %s", e)
                raise

    def _parse_dot_string(self, graph):
        """Parse an inline dot string, trying pygraphviz first, then pydot."""
        try:
            import pygraphviz as pgv
            agraph = pgv.AGraph(graph, strict=True, directed=True)
            return nx.drawing.nx_agraph.from_agraph(agraph)
        except Exception as e:
            self.logger.error("Error: Pygraphviz cannot be loaded. %s\nTrying pydot ...", e)
            try:
                import pydot
                P_list = pydot.graph_from_dot_data(graph)
                return nx.drawing.nx_pydot.from_pydot(P_list[0])
            except Exception as e:
                self.logger.error("Error: Pydot cannot be loaded. %s", e)
                raise

    def view_graph(self, layout="dot"):
        """Render the graph to ``causal_model.png``.

        Uses pygraphviz with the given ``layout`` program when possible and
        falls back to a matplotlib drawing otherwise.
        """
        out_filename = "causal_model.png"
        try:
            import pygraphviz as pgv  # noqa: F401 -- fail fast if unavailable
            agraph = nx.drawing.nx_agraph.to_agraph(self._graph)
            agraph.draw(out_filename, format="png", prog=layout)
        except Exception:
            self.logger.warning("Warning: Pygraphviz cannot be loaded. Check that graphviz and pygraphviz are installed.")
            self.logger.info("Using Matplotlib for plotting")
            import matplotlib.pyplot as plt

            # NOTE(review): edges whose style is neither absent nor "dashed"
            # (e.g. the "dotted" effect-modifier edges) are not drawn here --
            # confirm whether that is intended.
            solid_edges = [(n1, n2) for n1, n2, e in self._graph.edges(data=True)
                           if 'style' not in e]
            dashed_edges = [(n1, n2) for n1, n2, e in self._graph.edges(data=True)
                            if 'style' in e and e['style'] == "dashed"]
            plt.clf()

            pos = nx.layout.shell_layout(self._graph)
            nx.draw_networkx_nodes(self._graph, pos, node_color='yellow', node_size=400)
            nx.draw_networkx_edges(
                self._graph,
                pos,
                edgelist=solid_edges,
                arrowstyle="-|>",
                arrowsize=12)
            nx.draw_networkx_edges(
                self._graph,
                pos,
                edgelist=dashed_edges,
                arrowstyle="-|>",
                style="dashed",
                arrowsize=12)
            nx.draw_networkx_labels(self._graph, pos)

            plt.axis('off')
            plt.savefig(out_filename)
            plt.draw()

    def build_graph(self, common_cause_names, instrument_names, effect_modifier_names):
        """Creates nodes and edges based on variable names and their semantics.

        Currently only considers the graphical representation of "direct"
        effect modifiers. Thus, all effect modifiers are assumed to be
        "direct" unless otherwise expressed using a graph. Based on the
        taxonomy of effect modifiers by VanderWheele and Robins: "Four types
        of effect modification: A classification based on directed acyclic
        graphs. Epidemiology. 2007."

        :param common_cause_names: iterable of common-cause names, or None.
        :param instrument_names: sequence of instrument names (each is linked
            to every treatment) or of ``(instrument, treatment)`` tuples
            (each is linked only to the named treatment); may be empty/None.
        :param effect_modifier_names: iterable of effect-modifier names, or
            None. Names that are also common causes are skipped.
        :returns: the graph stored on this object, with the nodes and edges
            added.
        """
        for treatment in self.treatment_name:
            self._graph.add_node(treatment, observed="yes")
        for outcome in self.outcome_name:
            self._graph.add_node(outcome, observed="yes")
        for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
            self._graph.add_edge(treatment, outcome)

        # Adding common causes
        if common_cause_names is not None:
            for node_name in common_cause_names:
                for treatment, outcome in itertools.product(self.treatment_name, self.outcome_name):
                    self._graph.add_node(node_name, observed="yes")
                    self._graph.add_edge(node_name, treatment)
                    self._graph.add_edge(node_name, outcome)

        # Adding instruments
        if instrument_names:
            if not isinstance(instrument_names[0], tuple):
                if len(self.treatment_name) > 1:
                    self.logger.info("Assuming Instrument points to all treatments! Use tuples for more granularity.")
                for instrument, treatment in itertools.product(instrument_names, self.treatment_name):
                    self._graph.add_node(instrument, observed="yes")
                    self._graph.add_edge(instrument, treatment)
            else:
                # Each entry already is an (instrument, treatment) pair.
                for instrument, treatment in instrument_names:
                    self._graph.add_node(instrument, observed="yes")
                    self._graph.add_edge(instrument, treatment)

        # Adding effect modifiers
        if effect_modifier_names is not None:
            # Treat a missing common-cause list as empty so the membership
            # test below cannot fail on None.
            known_common_causes = common_cause_names if common_cause_names is not None else ()
            for node_name in effect_modifier_names:
                if node_name in known_common_causes:
                    continue
                for outcome in self.outcome_name:
                    self._graph.add_node(node_name, observed="yes")
                    self._graph.add_edge(node_name, outcome, style="dotted", headport="s", tailport="n")
                    # TODO make the ports more general so that they apply not
                    # just to top-bottom node configurations
                    self._graph.add_edge(outcome, node_name, style="dotted", headport="n", tailport="s")
        return self._graph

    def add_node_attributes(self, observed_node_names):
        """Set the ``observed`` attribute ("yes"/"no") on every node.

        :param observed_node_names: names of observed nodes; ``None`` is
            treated as empty, i.e. every node is marked "no".
        :returns: the graph stored on this object.
        """
        observed = observed_node_names if observed_node_names is not None else ()
        for node_name in self._graph:
            self._graph.nodes[node_name]["observed"] = "yes" if node_name in observed else "no"
        return self._graph

    def add_missing_nodes_as_common_causes(self, observed_node_names):
        """Add observed names absent from the graph as causes of all
        treatments and outcomes.

        :param observed_node_names: names to check; ``None`` is treated as
            empty.
        :returns: the graph stored on this object.
        """
        # Adding columns in the dataframe as confounders that were not in the graph
        if observed_node_names is None:
            return self._graph
        for node_name in observed_node_names:
            if node_name not in self._graph:
                self._graph.add_node(node_name, observed="yes")
                for treatment_outcome_node in self.treatment_name + self.outcome_name:
                    self._graph.add_edge(node_name, treatment_outcome_node)
        return self._graph

    def add_unobserved_common_cause(self, observed_node_names):
        """Add a node ``'U'`` pointing at every treatment and outcome unless
        an unobserved common cause already exists.

        :param observed_node_names: not used by this method.
        :returns: the graph stored on this object.
        """
        # Adding unobserved confounders
        current_common_causes = self.get_common_causes(self.treatment_name,
                                                       self.outcome_name)
        has_unobserved_common_cause = any(
            self._graph.nodes[node_name]["observed"] == "no"
            for node_name in current_common_causes)

        if not has_unobserved_common_cause:
            uc_label = "Unobserved Confounders"
            self._graph.add_node('U', label=uc_label, observed="no")
            for node in self.treatment_name + self.outcome_name:
                self._graph.add_edge('U', node)
            self.logger.info('If this is observed data (not from a randomized experiment), there might always be missing confounders. Adding a node named "Unobserved Confounders" to reflect this.')
        return self._graph

    def get_unconfounded_observed_subgraph(self):
        """Return the subgraph induced by nodes whose ``observed`` is "yes"."""
        observed_nodes = [node for node in self._graph.nodes()
                          if self._graph.nodes[node]["observed"] == "yes"]
        return self._graph.subgraph(observed_nodes)

    def do_surgery(self, node_names, remove_outgoing_edges=False,
                   remove_incoming_edges=False):
        """Return a copy of the graph with edges around ``node_names`` removed.

        :param node_names: nodes whose edges are to be cut.
        :param remove_outgoing_edges: drop edges from each node to its
            successors.
        :param remove_incoming_edges: drop edges from each node's
            predecessors to it.
        :returns: a new graph; the stored graph is not modified.
        """
        new_graph = self._graph.copy()
        for node_name in node_names:
            if remove_outgoing_edges:
                # Materialise the edge list before mutating the graph.
                edges_bunch = [(node_name, child)
                               for child in new_graph.successors(node_name)]
                new_graph.remove_edges_from(edges_bunch)
            if remove_incoming_edges:
                edges_bunch = [(parent, node_name)
                               for parent in new_graph.predecessors(node_name)]
                new_graph.remove_edges_from(edges_bunch)
        return new_graph

    def get_causes(self, nodes, remove_edges=None):
        """Return the union of ancestors of ``nodes``.

        :param nodes: node name(s); passed through ``parse_state``.
        :param remove_edges: optional dict with keys ``"sources"`` and
            ``"targets"``; every source->target edge is removed from a copy
            of the graph before ancestors are computed.
        :returns: set of ancestor node names.
        """
        nodes = parse_state(nodes)
        new_graph = None
        if remove_edges is not None:
            new_graph = self._graph.copy()  # caution: shallow copy of the attributes
            sources = parse_state(remove_edges["sources"])
            targets = parse_state(remove_edges["targets"])
            for s in sources:
                for t in targets:
                    new_graph.remove_edge(s, t)
        causes = set()
        for v in nodes:
            causes = causes.union(self.get_ancestors(v, new_graph=new_graph))
        return causes

    def get_common_causes(self, nodes1, nodes2):
        """Return the nodes that are causes of both ``nodes1`` and ``nodes2``.

        Assume that nodes1 causes nodes2 (e.g., nodes1 are the treatments and
        nodes2 are the outcomes).

        :returns: list of common-cause node names.
        """
        # TODO Refactor to remove this from here and only implement this logic
        # in causalIdentifier. Unnecessary assumption of nodes1 to be causing
        # nodes2.
        nodes1 = parse_state(nodes1)
        nodes2 = parse_state(nodes2)
        causes_1 = set()
        causes_2 = set()
        for node in nodes1:
            causes_1 = causes_1.union(self.get_ancestors(node))
        for node in nodes2:
            # Cannot simply compute ancestors, since that will also include
            # nodes1 and its parents (e.g. instruments)
            for parent in self.get_parents(node):
                if parent not in nodes1:
                    causes_2.add(parent)
                    causes_2 = causes_2.union(self.get_ancestors(parent))
        return list(causes_1.intersection(causes_2))

    def get_effect_modifiers(self, nodes1, nodes2):
        """Return ancestors of ``nodes2`` that are neither in ``nodes1`` nor
        ancestors of any node in ``nodes1``.

        :returns: list of node names.
        """
        modifiers = set()
        for node in nodes2:
            modifiers = modifiers.union(self.get_ancestors(node))
        modifiers = modifiers.difference(nodes1)
        for node in nodes1:
            modifiers = modifiers.difference(self.get_ancestors(node))
        return list(modifiers)

    def get_parents(self, node_name):
        """Return the set of direct predecessors of ``node_name``."""
        return set(self._graph.predecessors(node_name))

    def get_ancestors(self, node_name, new_graph=None):
        """Return the set of ancestors of ``node_name``.

        :param new_graph: graph to search instead of the stored one.
        """
        graph = self._graph if new_graph is None else new_graph
        return set(nx.ancestors(graph, node_name))

    def get_descendants(self, node_name):
        """Return the set of descendants of ``node_name``."""
        return set(nx.descendants(self._graph, node_name))

    def all_observed(self, node_names):
        """Return True if every node in ``node_names`` has ``observed`` "yes"."""
        return all(self._graph.nodes[node_name]["observed"] == "yes"
                   for node_name in node_names)

    def filter_unobserved_variables(self, node_names):
        """Return the names in ``node_names`` whose ``observed`` is "yes",
        in their original order."""
        return [node_name for node_name in node_names
                if self._graph.nodes[node_name]["observed"] == "yes"]

    def get_instruments(self, treatment_nodes, outcome_nodes):
        """Return parents of the treatments that qualify as instruments.

        A candidate is a parent of a treatment that is not an ancestor of an
        outcome once edges into the treatments are removed (exclusion), and
        that is not a descendant of any such ancestor (as-if-random).

        :returns: list of instrument node names.
        """
        treatment_nodes = parse_state(treatment_nodes)
        outcome_nodes = parse_state(outcome_nodes)
        parents_treatment = set()
        for node in treatment_nodes:
            parents_treatment = parents_treatment.union(self.get_parents(node))
        g_no_parents_treatment = self.do_surgery(treatment_nodes,
                                                 remove_incoming_edges=True)
        ancestors_outcome = set()
        for node in outcome_nodes:
            ancestors_outcome = ancestors_outcome.union(
                nx.ancestors(g_no_parents_treatment, node))
        # [TODO: double check these work with multivariate implementation:]
        # Exclusion
        candidate_instruments = parents_treatment.difference(ancestors_outcome)
        self.logger.debug("Candidate instruments after exclusion %s",
                          candidate_instruments)
        # As-if-random setup
        children_causes_outcome = set()
        for v in ancestors_outcome:
            children_causes_outcome.update(nx.descendants(g_no_parents_treatment, v))
        # As-if-random
        instruments = candidate_instruments.difference(children_causes_outcome)
        return list(instruments)