Source code for dodiscover.context

from dataclasses import dataclass, field
from typing import Any, Dict, FrozenSet, List, Set, Tuple
from warnings import warn

import networkx as nx

from ._protocol import Graph
from .base import BasePyWhy
from .typing import Column


# TODO: we should try to make the Context dataclass frozen (immutable).
# - This would require an easy way to copy a Context into a new Context
#   while resetting only selected attributes (e.g. just the ``init_graph``).
# - IDEA: perhaps add a function `new_context = copy_context(context, **kwargs)`,
#   where ``kwargs`` are the attributes to change.
@dataclass(
    eq=True,
    # frozen=True
)
class Context(BasePyWhy):
    """Context of assumptions, domain knowledge and data.

    This should NOT be instantiated directly. One should instead
    use `dodiscover.make_context` to build a Context data structure.

    Parameters
    ----------
    observed_variables : Set
        Set of observed variables. If neither the latent, nor the observed
        variables are set, then it is presumed that the observed variables
        consist of the columns of ``data`` and that the set of latent
        variables is empty.
    latent_variables : Set
        Set of latent "unobserved" variables. If neither the latent, nor the
        observed variables are set, then it is presumed that the observed
        variables consist of the columns of ``data`` and that the set of
        latent variables is empty.
    state_variables : Dict
        Name of intermediate state variables during the learning process,
        mapped to their values.
    init_graph : Graph
        The graph to start with.
    included_edges : nx.Graph
        Included edges without direction.
    excluded_edges : nx.Graph
        Excluded edges without direction.
    num_distributions : int
        The number of distributions we expect to have access to. Default is 1.
    obs_distribution : bool
        Whether or not the observational distribution is present.
        Default is True.
    intervention_targets : list of tuple
        List of intervention targets (known, or unknown), which correspond to
        the nodes in the graph (known), or indices of datasets that contain
        interventions (unknown).
    symmetric_diff_map : Dict
        (Optional) mapping of F-nodes to their symmetric difference
        intervention targets.
    sigma_map : Dict
        Sigma-map, mapping F-nodes to their distribution indices.
    f_nodes : List
        List of F-nodes.
    num_domains : int
        The number of domains we expect to have access to. Default is 1.
    domain_map : Dict
        Map of each augmented node to a tuple of domains
        (e.g. ``(0, 1)``, or ``(1,)``).
    s_nodes : List
        List of S-nodes.

    Raises
    ------
    ValueError
        The observed and latent variables, if both set, should contain the
        set of all columns in ``data``.
        NOTE(review): this check is not implemented in the class body itself;
        presumably it is enforced where the Context is built - confirm.

    Notes
    -----
    Context is a data structure for storing assumptions, domain knowledge,
    priors and other structured contexts alongside the datasets. This class
    is used in conjunction with a discovery algorithm.

    Setting the a priori explicit direction of an edge is not supported yet.

    **Testing for equality**

    Currently, testing for equality is done on all attributes that are not
    graphs (the graph fields are declared with ``compare=False``). Defining
    equality among graphs is ill-defined, and as such, we leave testing of
    the internal graphs to users. Some checks of equality for example can be
    :func:`networkx.algorithms.isomorphism.is_isomorphic` for checking
    isomorphism among two graphs.
    """

    observed_variables: Set[Column]
    latent_variables: Set[Column]
    state_variables: Dict[str, Any]

    # graphs are excluded from the generated ``__eq__`` (see class Notes)
    init_graph: Graph = field(compare=False)
    included_edges: nx.Graph = field(compare=False)
    excluded_edges: nx.Graph = field(compare=False)

    ########################################################
    # for interventional data
    ########################################################
    # the number of distributions we expect to have access to
    num_distributions: int = field(default=1)

    # whether or not observational distribution is present
    obs_distribution: bool = field(default=True)

    # (optional) known intervention targets, corresponding to nodes in the graph;
    # ``Tuple[Column, ...]`` because a target may consist of any number of nodes
    intervention_targets: List[Tuple[Column, ...]] = field(default_factory=list)

    # (optional) mapping F-nodes to their symmetric difference intervention targets
    symmetric_diff_map: Dict[Any, FrozenSet] = field(default_factory=dict)

    # sigma-map mapping F-nodes to their distribution indices
    sigma_map: Dict[Any, Tuple] = field(default_factory=dict)
    f_nodes: List = field(default_factory=list)

    ########################################################
    # for general multi-domain data
    ########################################################
    # the number of domains we expect to have access to
    num_domains: int = field(default=1)

    # map each augmented node to a tuple of domains (e.g. (0, 1), or (1,))
    domain_map: Dict[Any, Tuple] = field(default_factory=dict)
    s_nodes: List = field(default_factory=list)

    def add_state_variable(self, name: str, var: Any) -> "Context":
        """Add a state variable.

        Called by an algorithm to persist data objects that
        are used in intermediate steps. An existing entry with the
        same name is overwritten.

        Parameters
        ----------
        name : str
            The name of the state variable.
        var : any
            Any state variable.

        Returns
        -------
        self : Context
            This same context instance (modified in place).
        """
        self.state_variables[name] = var
        return self

    def state_variable(self, name: str, on_missing: str = "raise") -> Any:
        """Get a state variable.

        Parameters
        ----------
        name : str
            The name of the state variable.
        on_missing : {'raise', 'warn', 'ignore'}
            Behavior if ``name`` is not in the dictionary of state variables.
            If 'raise' (default) will raise a RuntimeError. If 'warn', will
            emit a UserWarning and return `None`. If 'ignore', will return
            `None`. Any other value is not validated and behaves like
            'ignore'.

        Returns
        -------
        state_var : Any
            The state variable, or `None` if it is missing and no error
            was raised.

        Raises
        ------
        RuntimeError
            If ``name`` is missing and ``on_missing`` is 'raise'.
        """
        if name not in self.state_variables and on_missing != "ignore":
            err_msg = f"{name} is not a state variable: {self.state_variables}"
            if on_missing == "raise":
                raise RuntimeError(err_msg)
            elif on_missing == "warn":
                # stacklevel=2 attributes the warning to the caller's line
                warn(err_msg, stacklevel=2)
        return self.state_variables.get(name)

    def copy(self) -> "Context":
        """Create a copy of the context.

        A new ``Context`` is constructed from the parameters returned by
        ``self.get_params(deep=True)``.

        NOTE(review): ``get_params`` is not defined in this class (presumably
        inherited from ``BasePyWhy``); whether the attribute values are
        deep-copied depends on its implementation - confirm.

        Returns
        -------
        context : Context
            The new context.
        """
        return Context(**self.get_params(deep=True))

    ###############################################################
    # Methods for augmented nodes (interventional / multi-domain data).
    ###############################################################
    def get_non_augmented_nodes(self) -> Set:
        """Get the set of nodes in ``init_graph`` that are neither f-nodes nor s-nodes."""
        augmented_nodes = self.get_augmented_nodes()
        return {node for node in self.init_graph.nodes if node not in augmented_nodes}

    def get_augmented_nodes(self) -> Set:
        """Get the set of augmented nodes, i.e. the union of f-nodes and s-nodes."""
        return set(self.f_nodes) | set(self.s_nodes)

    def reverse_sigma_map(self) -> Dict:
        """Get the reverse sigma-map.

        Inverts ``sigma_map`` so that each value of ``sigma_map`` maps back to
        its key. If several keys share the same value, the key that comes
        last in iteration order wins.
        """
        return {mapping: node for node, mapping in self.sigma_map.items()}