"""Source code for ``dodiscover.ci.oracle``: an oracle conditional independence test."""
from typing import Optional, Set
import networkx as nx
import numpy as np
import pandas as pd
from dodiscover.typing import Column
from .._protocol import Graph
from .base import BaseConditionalIndependenceTest
class Oracle(BaseConditionalIndependenceTest):
    """Oracle conditional independence testing.

    Used for unit testing and checking intuition. Rather than estimating
    independence from data, the oracle answers each query with a graphical
    separation check on the ground-truth graph:

    - d-separation when ``graph`` is an ``nx.DiGraph``;
    - m-separation (via ``pywhy_graphs``) for any other graph type.

    Parameters
    ----------
    graph : nx.DiGraph | Graph
        The ground-truth causal graph.
    included_nodes : set of node, optional
        Nodes that are always added to the conditioning set of every query.
        Any of these nodes that appear in the queried ``x_vars`` or
        ``y_vars`` are dropped for that query. By default None, meaning no
        nodes are forcibly conditioned on.
    """

    _allow_multivariate_input: bool = True

    def __init__(self, graph: Graph, included_nodes: Optional[Set[Column]] = None) -> None:
        self.graph = graph
        self.included_nodes = included_nodes

    def test(
        self,
        df: pd.DataFrame,
        x_vars: Set[Column],
        y_vars: Set[Column],
        z_covariates: Optional[Set[Column]] = None,
    ):
        """Conditional independence test given an oracle.

        Checks conditional independence between ``x_vars`` and ``y_vars``
        given ``z_covariates``, using the causal graph as an oracle. The
        oracle uses d-separation (or m-separation) statements on the graph
        to query conditional independences. This is known as the Markov
        property for graphs :footcite:`Pearl_causality_2009,Spirtes1993`.

        The conditioning set is ``z_covariates`` together with the
        ``included_nodes`` given at construction. Any included node that is
        itself in ``x_vars`` or ``y_vars`` is left out.

        Parameters
        ----------
        df : pd.DataFrame of shape (n_samples, n_variables)
            The data matrix. Passed in for API consistency, but only used
            for input validation.
        x_vars : set of node
            The first set of nodes in the graph.
        y_vars : set of node
            The second set of nodes in the graph.
        z_covariates : set of node, optional
            The set of variables to check that separates ``x_vars`` and
            ``y_vars``. By default None, meaning only the constructor's
            ``included_nodes`` (if any) are conditioned on.

        Returns
        -------
        statistic : float
            ``0.0`` if ``x_vars`` and ``y_vars`` are separated given the
            conditioning set, ``np.inf`` otherwise.
        pvalue : float
            ``1.0`` if ``x_vars`` and ``y_vars`` are separated, i.e.
            conditionally independent, given the conditioning set.
            ``0.0`` if they are not.

        References
        ----------
        .. footbibliography::
        """
        self._check_test_input(df, x_vars, y_vars, z_covariates)

        # Nodes that are always conditioned on. Any that are part of the
        # query itself are dropped, since a node cannot be both tested and
        # conditioned on.
        included_nodes = set()
        if self.included_nodes is not None:
            included_nodes = set(self.included_nodes) - set(x_vars) - set(y_vars)

        z_covariates_ = set(included_nodes)
        if z_covariates is not None:
            z_covariates_.update(z_covariates)

        # Check graphical separation between x and y given the conditioning set.
        if isinstance(self.graph, nx.DiGraph):
            # networkx >= 3.3 renamed ``d_separated`` to ``is_d_separator``.
            # The old name was later removed, so resolve whichever one exists.
            d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated
            is_sep = d_separated(self.graph, x_vars, y_vars, z_covariates_)
        else:
            # Imported lazily so pywhy_graphs is only required for non-DAG graphs.
            import pywhy_graphs.networkx as pywhy_nx

            is_sep = pywhy_nx.m_separated(self.graph, x_vars, y_vars, z_covariates_)

        # Separated => independent: report a perfect p-value and null statistic.
        if is_sep:
            pvalue = 1.0
            test_stat = 0.0
        else:
            pvalue = 0.0
            test_stat = np.inf
        return test_stat, pvalue