Source code for dodiscover.ci.oracle

from typing import Optional, Set

import networkx as nx
import numpy as np
import pandas as pd

from dodiscover.typing import Column

from .._protocol import Graph
from .base import BaseConditionalIndependenceTest


class Oracle(BaseConditionalIndependenceTest):
    """Oracle conditional independence testing.

    Used for unit testing and checking intuition. Instead of computing a
    statistic from data, independence is decided by querying a separation
    statement on the provided ground-truth graph: d-separation when the graph
    is an ``nx.DiGraph``, m-separation (via ``pywhy_graphs``) otherwise.

    Parameters
    ----------
    graph : nx.DiGraph | Graph
        The ground-truth causal graph.
    included_nodes : set of Column, optional
        Nodes that are always added to the conditioning set of every test,
        except those that appear in that test's ``x_vars`` or ``y_vars``.
        By default None, meaning no extra nodes are conditioned on.
    """

    _allow_multivariate_input: bool = True

    def __init__(self, graph: Graph, included_nodes: Optional[Set[Column]] = None) -> None:
        self.graph = graph
        self.included_nodes = included_nodes

    def test(
        self,
        df: pd.DataFrame,
        x_vars: Set[Column],
        y_vars: Set[Column],
        z_covariates: Optional[Set[Column]] = None,
    ):
        """Conditional independence test given an oracle.

        Checks conditional independence between 'x_vars' and 'y_vars'
        given 'z_covariates' of variables using the causal graph
        as an oracle. The oracle uses d-separation statements (m-separation
        for graphs that are not an ``nx.DiGraph``) given the graph to query
        conditional independences. This is known as the Markov property for
        graphs :footcite:`Pearl_causality_2009,Spirtes1993`.

        Parameters
        ----------
        df : pd.DataFrame of shape (n_samples, n_variables)
            The data matrix. Passed in for API consistency; it is only
            forwarded to ``_check_test_input`` and is not used to decide
            independence.
        x_vars : set of Column
            Nodes of the graph forming the first set of variables.
        y_vars : set of Column
            Nodes of the graph forming the second set of variables.
        z_covariates : set of Column, optional
            The set of variables to check that separates x_vars and y_vars.
            Nodes in ``included_nodes`` that are not in x_vars or y_vars
            are always added to it.

        Returns
        -------
        statistic : float
            ``0.0`` if x_vars and y_vars are separated given the conditioning
            set, ``np.inf`` otherwise.
        pvalue : float
            ``1.0`` if x_vars and y_vars are separated (conditionally
            independent) given the conditioning set, ``0.0`` if they are not.

        References
        ----------
        .. footbibliography::
        """
        self._check_test_input(df, x_vars, y_vars, z_covariates)

        # Build the conditioning set: the requested covariates plus the
        # always-included nodes. Included nodes that are themselves under test
        # are dropped, since a node cannot be both tested and conditioned on.
        conditioning_set = set() if z_covariates is None else set(z_covariates)
        if self.included_nodes is not None:
            conditioning_set |= set(self.included_nodes) - set(x_vars) - set(y_vars)

        if isinstance(self.graph, nx.DiGraph):
            # ``nx.d_separated`` was deprecated in networkx 3.3 in favor of
            # ``nx.is_d_separator`` (same arguments) and later removed;
            # prefer the new name and fall back for older networkx versions.
            d_separation = getattr(nx, "is_d_separator", None) or nx.d_separated
            is_sep = d_separation(self.graph, x_vars, y_vars, conditioning_set)
        else:
            # Imported lazily so pywhy_graphs is only needed for non-DAG graphs.
            import pywhy_graphs.networkx as pywhy_nx

            is_sep = pywhy_nx.m_separated(self.graph, x_vars, y_vars, conditioning_set)

        # Separation <=> conditional independence: report a "perfect" test.
        if is_sep:
            test_stat, pvalue = 0.0, 1.0
        else:
            test_stat, pvalue = np.inf, 0.0
        return test_stat, pvalue