fromtypingimportList,Set,Unionimportnetworkxasnxfrompywhy_graphsimportADMG,CPDAG,PAG,StationaryTimeSeriesCPDAG,StationaryTimeSeriesPAGfrom..configimportEdgeTypefrom..typingimportNode__all__=["single_source_shortest_mixed_path","is_node_common_cause","set_nodes_as_latent_confounders","is_valid_mec_graph","inducing_path",]defis_node_common_cause(G:nx.DiGraph,node:Node,exclude_nodes:List[Node]=None)->bool:"""Check if a node is a common cause within the graph. Parameters ---------- G : DiGraph A networkx DiGraph. node : node A node in the graph. exclude_nodes : list, optional Set of nodes to exclude from consideration, by default None. Returns ------- is_common_cause : bool Whether or not the node is a common cause or not. """ifexclude_nodesisNone:exclude_nodes=[]successors=G.successors(node)count=0forsuccinsuccessors:ifsuccnotinexclude_nodes:count+=1ifcount>=2:returnTruereturnFalsedefset_nodes_as_latent_confounders(G:Union[nx.DiGraph,ADMG],nodes:List[Node])->ADMG:"""Set nodes as latent unobserved confounders. Note that this only works if the original node is a common cause of some variables in the graph. Parameters ---------- G : DiGraph A networkx DiGraph. nodes : list A list of nodes to set. They must all be common causes of variables within the graph. Returns ------- graph : ADMG The mixed-edge causal graph that results. """bidirected_edges=[]new_parent_ch_edges=[]fornodeinnodes:# check if the node is a common causeifnotis_node_common_cause(G,node,exclude_nodes=nodes):raiseRuntimeError(f"{node} is not a common cause within the graph "f"given excluding variables. This function will only convert common "f"causes to latent confounders.")# keep track of which nodes to form c-components oversuccessor_nodes=G.successors(node)foridx,succinenumerate(successor_nodes):# TODO: do we want this?; add parent -> successor edges# if there are parents to this node, they must now point to all the successorsforparentinG.predecessors(node):new_parent_ch_edges.append((parent,succ))# form a c-component among the successorsifidx==0:prev_succ=succcontinuebidirected_edges.append((prev_succ,succ))prev_succ=succ# create the graph with nodes excluding those that are converted to latent confoundersifisinstance(G,ADMG):new_graph=G.copy()elifisinstance(G,nx.DiGraph):new_graph=ADMG(G.copy())new_graph.remove_nodes_from(nodes)# create the c-component structuresnew_graph.add_edges_from(bidirected_edges,new_graph.bidirected_edge_name)# add additional edges that need to be accounted fornew_graph.add_edges_from(new_parent_ch_edges,new_graph.directed_edge_name)returnnew_graph
[docs]defis_valid_mec_graph(G,on_error:str="raise")->bool:"""Check G is a valid PAG. A valid CPDAG/PAG is one where each pair of nodes have at most one edge between them and the internal graph of directed edges do not form cycles. Parameters ---------- G : PAG | CPDAG The PAG or CPDAG. on_error : str Whether to raise an error if the graph is non-compliant. Default is 'raise'. Other options are 'ignore'. Returns ------- bool Whether G is a valid PAG or CPDAG. Notes ----- This function does not check whether or not the PAG, or CPDAG are valid in the sense that the construction of the PAG/CPDAG was constrained to not contain any directed edge cycles. """ifisinstance(G,CPDAG)orisinstance(G,StationaryTimeSeriesCPDAG):check_func=_check_adding_cpdag_edgeelifisinstance(G,PAG)orisinstance(G,StationaryTimeSeriesPAG):check_func=_check_adding_pag_edgeforedge_type,edgeviewinG.edges().items():foru,vinedgeview:check_func(G,u,v,edge_type)returnTrue
def_check_adding_cpdag_edge(graph:CPDAG,u_of_edge:Node,v_of_edge:Node,edge_type:EdgeType):"""Check compatibility among internal graphs when adding an edge of a certain type. Parameters ---------- graph : CPDAG The CPDAG we are adding edges to. u_of_edge : node The start node. v_of_edge : node The end node. edge_type : EdgeType The edge type that is being added. """raise_error=Falseifedge_type==EdgeType.DIRECTED.value:# there should not be a undirected edge, or a bidirected edgeifgraph.has_edge(u_of_edge,v_of_edge,graph.undirected_edge_name):raise_error=Trueifgraph.has_edge(v_of_edge,u_of_edge,graph.directed_edge_name):raiseRuntimeError(f"There is an existing {v_of_edge} -> {u_of_edge}. You are "f"trying to add a directed edge from {u_of_edge} -> {v_of_edge}. ")elifedge_type==EdgeType.UNDIRECTED.value:# there should not be any type of edge between the twoifgraph.has_edge(u_of_edge,v_of_edge,graph.directed_edge_name)orgraph.has_edge(v_of_edge,u_of_edge,graph.directed_edge_name):raise_error=Trueifraise_error:raiseRuntimeError(f"There is already an existing edge between {u_of_edge} and {v_of_edge}. "f"Adding a {edge_type} edge is not possible. Please remove the existing "f"edge first.")def_check_adding_pag_edge(graph:PAG,u_of_edge:Node,v_of_edge:Node,edge_type:EdgeType):"""Check compatibility among internal graphs when adding an edge of a certain type. Parameters ---------- u_of_edge : node The start node. v_of_edge : node The end node. edge_type : EdgeType The edge type that is being added. """raise_error=Falseifedge_type==EdgeType.ALL.value:ifgraph.has_edge(u_of_edge,v_of_edge):raise_error=Trueelifedge_type==EdgeType.CIRCLE.value:# there should not be an existing arrow# nor a bidirected arrowifgraph.has_edge(u_of_edge,v_of_edge,graph.directed_edge_name)orgraph.has_edge(u_of_edge,v_of_edge,graph.bidirected_edge_name):raise_error=Trueelifedge_type==EdgeType.DIRECTED.value:# there should not be a circle edge in the same direction, or a bidirected edgeifgraph.has_edge(u_of_edge,v_of_edge,graph.circle_edge_name)orgraph.has_edge(u_of_edge,v_of_edge,graph.bidirected_edge_name):raise_error=Trueifgraph.has_edge(v_of_edge,u_of_edge,graph.directed_edge_name):raiseRuntimeError(f"There is an existing {v_of_edge} -> {u_of_edge}. You are "f"trying to add a directed edge from {u_of_edge} -> {v_of_edge}. "f"If your intention is to create a bidirected edge, first remove the "f"edge and then explicitly add the bidirected edge.")elifedge_type==EdgeType.BIDIRECTED.value:# there should not be any type of edge between the twoif(graph.has_edge(u_of_edge,v_of_edge,graph.directed_edge_name)orgraph.has_edge(u_of_edge,v_of_edge,graph.circle_edge_name)orgraph.has_edge(v_of_edge,u_of_edge,graph.directed_edge_name)orgraph.has_edge(v_of_edge,u_of_edge,graph.circle_edge_name)):raise_error=Trueifraise_error:raiseRuntimeError(f"There is already an existing edge between {u_of_edge} and {v_of_edge}. "f"Adding a {edge_type} edge is not possible. Please remove the existing "f"edge first. {graph.edges()}")defsingle_source_shortest_mixed_path(G,source,cutoff=None,valid_path=None):"""Compute shortest mixed-edge path between source and all other nodes. Parameters ---------- G : NetworkX graph source : node label Starting node for path cutoff : integer, optional Depth to stop the search. Only paths of length <= cutoff are returned. valid_path : function Function to determine i Returns ------- lengths : dictionary Dictionary, keyed by target, of shortest paths. Examples -------- >>> G = nx.path_graph(5) >>> path = nx.single_source_shortest_path(G, 0) >>> path[4] [0, 1, 2, 3, 4] Notes ----- The shortest path is not necessarily unique. So there can be multiple paths between the source and each target node, all of which have the same 'shortest' length. For each target node, this function returns only one of those paths. See Also -------- shortest_path """ifsourcenotinG:raisenx.NodeNotFound(f"Source {source} not in G")defjoin(p1,p2):returnp1+p2ifcutoffisNone:cutoff=float("inf")ifvalid_pathisNone:valid_path=lambda*_:Truenextlevel={source:1}# list of nodes to check at next levelpaths={source:[source]}# paths dictionary (paths to key from source)returndict(_single_shortest_path_early_stop(G,nextlevel,paths,cutoff,join,valid_path))def_single_shortest_path_early_stop(G,firstlevel,paths,cutoff,join,valid_path):"""Return shortest paths. Shortest Path helper function. Parameters ---------- G : Graph Graph firstlevel : dict starting nodes, e.g. {source: 1} or {target: 1} paths : dict paths for starting nodes, e.g. {source: [source]} cutoff : int or float level at which we stop the process join : function function to construct a path from two partial paths. Requires two list inputs `p1` and `p2`, and returns a list. Usually returns `p1 + p2` (forward from source) or `p2 + p1` (backward from target) valid_path : function function to determine if the current path is a valid path. Input of graph, current node, and the next node. Returns true if continuing along the path of 'current node' *-* 'next node' is valid. Returns false otherwise and the path will be cut short. Returns ------- paths : dict The updated paths for starting nodes. """level=0# the current levelnextlevel=firstlevelwhilenextlevelandcutoff>level:thislevel=nextlevelnextlevel={}forvinthislevel:forwinG.neighbors(v):ifwnotinpathsandvalid_path(G,v,w):paths[w]=join(paths[v],[w])nextlevel[w]=1level+=1returnpathsdef_directed_sub_graph_ancestors(G,node:Node):"""Finds the ancestors of a node in the directed subgraph. Parameters ---------- G : Graph The graph. node : Node The node for which we have to find the ancestors. Returns ------- out : set The parents of the provided node. """returnnx.ancestors(G.sub_directed_graph(),node)def_directed_sub_graph_parents(G,node:Node):"""Finds the parents of a node in the directed subgraph. Parameters ---------- G : Graph The graph. node : Node The node for which we have to find the parents. Returns ------- out : set The parents of the provided node. """returnset(G.sub_directed_graph().predecessors(node))def_bidirected_sub_graph_neighbors(G,node:Node):"""Finds the neighbors of a node in the bidirected subgraph. Parameters ---------- G : Graph The graph. node : Node The node for which we have to find the neighbors. Returns ------- out : set The parents of the provided node. """bidirected_parents=set()ifnotisinstance(G,CPDAG):bidirected_parents=set(G.sub_bidirected_graph().neighbors(node))returnbidirected_parentsdef_is_collider(G,prev_node:Node,cur_node:Node,next_node:Node):"""Checks if the given node is a collider or not. Parameters ---------- G : graph The graph. prev_node : node The previous node in the path. cur_node : node The node to be checked. next_node: Node The next node in the path. Returns ------- iscollider : bool Bool is set true if the node is a collider, false otherwise. """parents=_directed_sub_graph_parents(G,cur_node)parents=parents.union(_bidirected_sub_graph_neighbors(G,cur_node))ifprev_nodeinparentsandnext_nodeinparents:returnTruereturnFalsedef_shortest_valid_path(G,node_x:Node,node_y:Node,L:Set,S:Set,visited:Set,all_ancestors:Set,cur_node:Node,prev_node:Node,):"""Recursively explores a graph to find a path. Finds path that are compliant with the inducing path requirements. Parameters ---------- G : graph The graph. node_x : node The source node. node_y : node The destination node L : Set Set containing all the non-colliders. S : Set Set containing all the colliders. visited : Set Set containing all the nodes already visited. all_ancestors : Set Set containing all the ancestors a collider needs to be checked against. cur_node : node The current node. prev_node : node The previous node in the path. Returns ------- path : Tuple[bool, path] A tuple containing a bool and a path which is empty if the bool is false. """path_exists=Falsepath=[]visited.add(cur_node)neighbors=G.neighbors(cur_node)ifcur_nodeisnode_y:return(True,[node_y])foreleminneighbors:ifeleminvisited:continueelse:# If the current node is a collider, check that it is either an# ancestor of X, Y or any element of S or that it is# the destination node itself.if(_is_collider(G,prev_node,cur_node,elem)and(cur_nodenotinall_ancestors)and(cur_nodenotinS)and(cur_nodeisnotnode_y)):continue# If the current node is not a collider, check that it is# either in L or the destination node itself.elif(not_is_collider(G,prev_node,cur_node,elem)and(cur_nodenotinL)and(cur_nodeisnotnode_y)):continue# if it is a valid node and not the destination node,# check if it has a path to the destination nodepath_exists,temp_path=_shortest_valid_path(G,node_x,node_y,L,S,visited,all_ancestors,elem,cur_node)ifpath_exists:path.append(cur_node)path.extend(temp_path)breakreturn(path_exists,path)
[docs]definducing_path(G,node_x:Node,node_y:Node,L:Set=None,S:Set=None):"""Checks if an inducing path exists between two nodes. An inducing path is defined in :footcite:`Zhang2008`. Parameters ---------- G : Graph The graph. node_x : node The source node. node_y : node The destination node. L : Set Nodes that are ignored on the path. Defaults to an empty set. See Notes for details. S: Set Nodes that are always conditioned on. Defaults to an empty set. See Notes for details. Returns ------- path : Tuple[bool, path] A tuple containing a bool and a path if the bool is true, an empty list otherwise. Notes ----- An inducing path intuitively is a path between two non-adjacent nodes that cannot be d-separated. Therefore, the path is always "active" regardless of what variables we condition on. L contains all the non-colliders, these nodes are ignored along the path. S contains nodes that are always conditioned on (hence if the ancestors of colliders are in S, then those collider paths are always "active"). References ---------- .. footbibliography:: """ifLisNone:L=set()ifSisNone:S=set()nodes=set(G.nodes)ifnode_xnotinnodesornode_ynotinnodes:raiseValueError("The provided nodes are not in the graph.")ifnode_x==node_y:raiseValueError("The source and destination nodes are the same.")edges=G.edges()# XXX: fix this when graphs are refactored to only check for directed/bidirected edge typesforeleminedges.keys():ifelemnotin{"directed","bidirected"}:iflen(edges[elem])!=0:raiseValueError("Inducing Path is not defined for this graph.")path=[]# this will contain the path.x_ancestors=_directed_sub_graph_ancestors(G,node_x)y_ancestors=_directed_sub_graph_ancestors(G,node_y)xy_ancestors=x_ancestors.union(y_ancestors)s_ancestors:set[Node]=set()foreleminS:s_ancestors=s_ancestors.union(_directed_sub_graph_ancestors(G,elem))# ancestors of X, Y and all the elements of Sall_ancestors=xy_ancestors.union(s_ancestors)x_neighbors=G.neighbors(node_x)path_exists=Falseforeleminx_neighbors:visited={node_x}ifelemnotinvisited:path_exists,temp_path=_shortest_valid_path(G,node_x,node_y,L,S,visited,all_ancestors,elem,node_x)ifpath_exists:path.append(node_x)path.extend(temp_path)breakreturn(path_exists,path)