Source code for dowhy.causal_identifier.backdoor

import networkx as nx

from dowhy.utils.graph_operations import adjacency_matrix_to_adjacency_list


class NodePair:
    """
    Container for the backdoor bookkeeping of one (node1, node2) pair.

    It accumulates, path by path, whether every recorded path is blocked and
    which sets of variables were recorded for the paths between the two nodes.
    """

    def __init__(self, node1, node2):
        self._node1 = node1
        self._node2 = node2
        # None until the first Path is recorded; afterwards True only while
        # every recorded Path reported itself as blocked.
        self._is_blocked = None
        # One entry (a set of variables) per recorded path.
        self._condition_vars = []
        # Flipped by set_complete() once the pair has been fully explored.
        self._complete = False

    def update(self, path, condition_vars=None):
        """
        Record one more path for this node pair.

        :param path: An object exposing ``is_blocked()`` and ``get_condition_vars()`` when
            ``condition_vars`` is None; otherwise a list of nodes.
        :param condition_vars: Optional iterable of extra variables. When given, the recorded
            entry is the union of ``path[1:]`` and these variables.
        """
        if condition_vars is not None:
            # `path` is a plain list of nodes here.
            extra_vars = list(condition_vars)
            self._condition_vars.append({*path[1:], *extra_vars})
            return

        # `path` is a Path-like object here.
        path_blocked = path.is_blocked()
        if self._is_blocked is None:
            self._is_blocked = path_blocked
        else:
            self._is_blocked = self._is_blocked and path_blocked

        if not path_blocked:
            self._condition_vars.append(path.get_condition_vars())

    def get_condition_vars(self):
        """Return the list of recorded variable sets (the internal list, not a copy)."""
        return self._condition_vars

    def set_complete(self):
        """Flag the node pair as completely explored."""
        self._complete = True

    def is_complete(self):
        """Return True once ``set_complete()`` has been called."""
        return self._complete

    def __str__(self):
        pieces = ["Blocked: " + str(self._is_blocked) + "\n"]
        if not self._is_blocked:
            rendered = ",".join(str(entry) for entry in self._condition_vars)
            pieces.append("To block path, condition on: " + rendered + "\n")
        return "".join(pieces)
class Path:
    """
    Data structure to store a particular path between 2 nodes.
    """

    def __init__(self):
        self._is_blocked = None  # To store if path is blocked; None until update() is called
        self._condition_vars = set()  # To store variables needed to block the path

    def update(self, path, is_blocked):
        """
        Record the blocked state of the path and, if it is open, its interior nodes.

        :param path: List of nodes comprising the path (end points included).
        :param is_blocked: True if the path is already blocked, else False.
        """
        self._is_blocked = is_blocked
        if not is_blocked:
            # Only the interior nodes (end points excluded) are candidates to condition on.
            self._condition_vars = self._condition_vars.union(set(path[1:-1]))

    def is_blocked(self):
        """Return the blocked state given to the last ``update()`` call (None before any call)."""
        return self._is_blocked

    def get_condition_vars(self):
        """Return the set of interior nodes collected from unblocked updates."""
        return self._condition_vars

    def __str__(self):
        string = "Blocked: " + str(self._is_blocked) + "\n"
        if not self._is_blocked:
            # Node labels are not necessarily strings (e.g. ints), so convert before joining;
            # sorting makes the rendering independent of set iteration order.
            names = sorted(str(var) for var in self._condition_vars)
            string += "To block path, condition on: " + ",".join(names) + "\n"
        return string
class Backdoor:
    """
    Class for optimized implementation of Backdoor variable search between the source nodes and the target nodes.
    """

    def __init__(self, graph, nodes1, nodes2):
        """
        :param graph: Directed graph; it is used through ``to_undirected()``, ``has_edge(u, v)`` and
            ``nx.to_numpy_array`` (presumably a ``networkx.DiGraph`` - confirm against callers).
        :param nodes1: Iterable of source nodes.
        :param nodes2: Iterable of target nodes.
        """
        self._graph = graph
        self._nodes1 = nodes1
        self._nodes2 = nodes2
        self._nodes12 = set(self._nodes1).union(self._nodes2)  # Total set of nodes
        self._colliders = set()  # Nodes seen as colliders on some explored path

    def get_backdoor_vars(self):
        """
        Obtains sets of backdoor variable to condition on for each node pair.

        :returns: List of dicts, one per node pair for which backdoor paths were recorded, with keys
            ``"backdoor_set"`` (tuple of variables) and ``"num_paths_blocked_by_observed_nodes"``.
        """
        undirected_graph = self._graph.to_undirected()

        # Get adjacency list.
        # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is the supported
        # replacement and yields the same entries in the same node order.
        adjlist = adjacency_matrix_to_adjacency_list(
            nx.to_numpy_array(undirected_graph), labels=list(undirected_graph.nodes)
        )
        path_dict = {}
        backdoor_sets = []  # Put in backdoor sets format

        for node1 in self._nodes1:
            for node2 in self._nodes2:
                if (node1, node2) in path_dict:
                    continue
                self._path_search(adjlist, node1, node2, path_dict)
                # Only pairs for which the search recorded something have an entry. Testing the
                # pair itself (rather than "path_dict is non-empty") avoids a KeyError when an
                # earlier pair populated path_dict but this pair has no backdoor path.
                if (node1, node2) in path_dict:
                    obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)
                    backdoor_sets.append(
                        {
                            "backdoor_set": tuple(obj.find_set()),
                            "num_paths_blocked_by_observed_nodes": obj.num_sets(),
                        }
                    )

        return backdoor_sets

    def is_backdoor(self, path):
        """
        Check if path is a backdoor path, i.e. its first edge points into the first node.

        :param path: List of nodes comprising the path.
        :returns: True if the graph has the edge ``path[1] -> path[0]``; False otherwise or if the
            path has fewer than 2 nodes.
        """
        if len(path) < 2:
            return False
        return bool(self._graph.has_edge(path[1], path[0]))

    def _path_search_util(self, graph, node1, node2, vis, path, path_dict, is_blocked=False, prev_arrow=None):
        """
        Recursive DFS step that records backdoor paths ending at ``node2`` into ``path_dict``.

        :param graph: Adjacency list of the graph under consideration.
        :param node1: Current node being considered.
        :param node2: Target node.
        :param vis: Set of already visited nodes.
        :param path: List of nodes comprising the path upto node1.
        :param path_dict: Dictionary of node pairs.
        :param is_blocked: True is path is blocked by a collider, else False.
        :param prev_arrow: Described state of previous arrow. True if arrow incoming, False if arrow outgoing.
        """
        if is_blocked:
            return

        # If node pair has been fully explored
        if ((node1, node2) in path_dict) and (path_dict[(node1, node2)].is_complete()):
            # The hitting set of the completed pair does not change inside the loop (node1 is
            # never one of the nodes in `path`), so it is computed once.
            obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)
            # Add node1 to backdoor set of node_pair
            s = set([node1]).union(obj.find_set())
            for i in range(len(path)):
                if (path[i], node2) not in path_dict:
                    path_dict[(path[i], node2)] = NodePair(path[i], node2)
                path_dict[(path[i], node2)].update(path[i:], s)
        else:
            path.append(node1)
            vis.add(node1)

            if node1 == node2:
                # Check if path is backdoor and does not have nodes1\node1 or nodes2\node2 as intermediate nodes
                if self.is_backdoor(path) and self._nodes12.isdisjoint(path[1:-1]):
                    for i in range(len(path) - 1):
                        if (path[i], node2) not in path_dict:
                            path_dict[(path[i], node2)] = NodePair(path[i], node2)
                        path_var = Path()
                        path_var.update(path[i:].copy(), is_blocked)
                        path_dict[(path[i], node2)].update(path_var)
            else:
                for neighbour in graph[node1]:
                    if neighbour in vis:
                        continue
                    # True if arrow incoming, False if arrow outgoing
                    next_arrow = not self._graph.has_edge(node1, neighbour)
                    # node1 is a collider only for this particular continuation. The flag is
                    # computed per neighbour so that one collider continuation does not mark
                    # the remaining neighbours of node1 as blocked as well.
                    blocked_here = next_arrow and bool(prev_arrow)
                    if blocked_here:
                        self._colliders.add(node1)
                    self._path_search_util(
                        graph, neighbour, node2, vis, path, path_dict, blocked_here, not next_arrow
                    )  # Since incoming for current node is outgoing for the next

            path.pop()
            vis.remove(node1)

            # Mark pair (node1, node2) complete
            if (node1, node2) in path_dict:
                path_dict[(node1, node2)].set_complete()

    def _path_search(self, graph, node1, node2, path_dict):
        """
        Path search using DFS.

        :param graph: Adjacency list of the graph under consideration.
        :param node1: Current node being considered.
        :param node2: Target node.
        :param path_dict: Dictionary of node pairs.
        """
        vis = set()
        self._path_search_util(graph, node1, node2, vis, [], path_dict, is_blocked=False)
class HittingSetAlgorithm:
    """
    Class for the Hitting Set Algorithm to obtain a approximate minimal set of backdoor variables
    to condition on for each node pair.
    """

    def __init__(self, list_of_sets, colliders=None):
        """
        :param list_of_sets: List of sets such that each set comprises nodes representing a single backdoor path
            between a source node and a target node.
        :param colliders: Optional set of nodes that must never be picked; defaults to no colliders.
        """
        self._list_of_sets = list_of_sets
        # A fresh set per instance instead of a shared mutable default argument.
        self._colliders = set() if colliders is None else colliders
        self._var_count = self._count_vars()

    def num_sets(self):
        """
        Obtain number of backdoor paths between a node pair.
        """
        return len(self._list_of_sets)

    def find_set(self):
        """
        Find approximate minimal set of nodes such that there is atleast one node from each set in list_of_sets.

        Greedy: repeatedly pick the non-collider node that occurs in the most still-uncovered sets.
        Sets that contain no non-collider node cannot be covered and are left out.

        :returns: Approximate minimal set of nodes.
        """
        # The counts are consumed by the loop below, so rebuild them on every call; otherwise
        # a second call on the same object would start from exhausted counts.
        self._var_count = self._count_vars()

        var_set = set()
        num_indices = len(self._list_of_sets)
        indices_covered = set()
        all_set_indices = set(range(num_indices))

        while not self._is_covered(indices_covered, num_indices):
            set_index = all_set_indices - indices_covered
            max_el = self._max_occurence_var(var_dict=self._var_count)
            if max_el is None:
                # No remaining candidate has a positive count.
                break
            var_set.add(max_el)
            # Modify variable count and indices covered
            covered_present = self._indices_covered(el=max_el, set_index=set_index)
            self._modify_count(covered_present)
            indices_covered = indices_covered.union(covered_present)

        return var_set

    def _count_vars(self, set_index=None):
        """
        Obtain count of number of sets each particular node belongs to.

        :param set_index: Set of indices to consider for calculating the number of sets "hit" by a variable.
            Defaults to all indices.
        :returns: Dict mapping each non-collider node to the number of considered sets containing it.
        """
        var_dict = {}

        if set_index is None:
            set_index = set(range(len(self._list_of_sets)))

        for idx in set_index:
            for el in self._list_of_sets[idx]:
                if el not in self._colliders:
                    var_dict[el] = var_dict.get(el, 0) + 1

        return var_dict

    def _modify_count(self, indices_covered):
        """
        Modify count of number of sets each particular node belongs to based on nodes already covered in
        the previous iteration of the algorithm.

        :param indices_covered: Indices of the sets that have just been covered.
        """
        for idx in indices_covered:
            for el in self._list_of_sets[idx]:
                if el not in self._colliders:
                    self._var_count[el] -= 1

    def _max_occurence_var(self, var_dict):
        """
        Find the node contained in most number of sets.

        :returns: The first node (in dict order) with the highest positive count, or None if no
            node has a positive count.
        """
        max_el = None
        max_count = 0
        for key, val in var_dict.items():
            if val > max_count:
                max_count = val
                max_el = key
        return max_el

    def _indices_covered(self, el, set_index=None):
        """
        Obtain indices covered in a particular iteration of the algorithm.

        :param el: Node picked in this iteration.
        :param set_index: Indices to inspect; defaults to all indices.
        :returns: Set of inspected indices whose set contains ``el``.
        """
        if set_index is None:
            set_index = set(range(len(self._list_of_sets)))

        return {idx for idx in set_index if el in self._list_of_sets[idx]}

    def _is_covered(self, indices_covered, num_indices):
        """
        List of sets is covered by the variable set.

        :returns: True if every index in ``range(num_indices)`` is in ``indices_covered``.
        """
        return all(idx in indices_covered for idx in range(num_indices))