from typing import Optional
import numpy as np
import pywhy_graphs.networkx as pywhy_nx
from pywhy_graphs.typing import TsNode
from .base import BaseTimeSeriesGraph, tsdict
from .digraph import StationaryTimeSeriesDiGraph, TimeSeriesDiGraph
from .graph import StationaryTimeSeriesGraph, TimeSeriesGraph
class TimeSeriesMixedEdgeGraph(BaseTimeSeriesGraph, pywhy_nx.MixedEdgeGraph):
    """A class to imbue mixed-edge graph with time-series structure.

    This should not be used directly.

    Parameters
    ----------
    graphs : list of graphs, optional
        The internal single-edge-type graphs. Every graph must be an instance
        of one of ``graph_types`` and expose a ``max_lag`` attribute.
    edge_types : list of str, optional
        Names for the edge types; forwarded unchanged to the parent constructor.
    max_lag : int, optional
        The maximum lag, by default 1. If ``None`` and ``graphs`` is given, the
        value is inferred from the sub-graphs (which must all agree). If
        ``None`` and no graphs are given, it falls back to 1.
    attr : keyword arguments, optional
        Attributes forwarded to the parent constructor.
    """

    # whether or not the graph should be assumed to be stationary
    stationary: bool = False

    # overloaded factory dictionary types to hold time-series nodes
    node_dict_factory = tsdict
    node_attr_dict_factory = tsdict

    # supported graph types
    graph_types = (TimeSeriesGraph, TimeSeriesDiGraph)

    def __init__(self, graphs=None, edge_types=None, max_lag=1, **attr):
        """Validate the sub-graphs and initialize the mixed-edge graph.

        Raises
        ------
        ValueError
            If ``max_lag`` disagrees with the max lag of any sub-graph, or if
            ``max_lag`` is ``None`` and the sub-graphs do not share one max lag.
        RuntimeError
            If any sub-graph is not an instance of ``graph_types``.
        """
        if max_lag is not None:
            if graphs is not None and not all(max_lag == graph.max_lag for graph in graphs):
                raise ValueError(
                    f"Passing in max lag of {max_lag} to time-series mixed-edge graph, but "
                    f"sub-graphs have max-lag of {[graph.max_lag for graph in graphs]}."
                )
        elif graphs is not None:
            # infer max lag
            max_lags = [graph.max_lag for graph in graphs]
            if len(np.unique(max_lags)) != 1:
                raise ValueError(f"All max lags in passed in graphs must be equal: {max_lags}.")
            # Actually adopt the inferred value; previously the lags were only
            # validated and ``max_lag`` was left as ``None``.
            max_lag = max_lags[0]
        else:
            max_lag = 1

        if graphs is not None and not all(isinstance(graph, self.graph_types) for graph in graphs):
            raise RuntimeError(
                "All graphs for timeseries mixed-edge graph "
                f"must be an instance of one of {self.graph_types}."
            )

        attr["max_lag"] = max_lag

        # Set the graph attribute dict with the max lag before delegating to
        # the parent constructor (same ordering as the original code).
        self.graph = dict()
        self.graph["max_lag"] = max_lag
        super().__init__(graphs, edge_types, **attr)

    def copy(self):
        """Returns a copy of the graph.

        Exactly the same as :meth:`pywhy_graphs.networkx.MixedEdgeGraph.copy`,
        except this preserves the max lag graph attribute.

        Returns
        -------
        G : Graph
            A copy of the graph.

        See Also
        --------
        :meth:`pywhy_graphs.networkx.MixedEdgeGraph.to_directed`: return a
            directed copy of the graph.
        """
        G = self.__class__(max_lag=self.max_lag)
        G.graph.update(self.graph)
        graph_attr = G.graph

        # add all internal graphs to the copy
        for edge_type in self.edge_types:
            graph_func = self._internal_graph_nx_type(edge_type=edge_type)
            if edge_type not in G.edge_types:
                G.add_edge_type(graph_func(**graph_attr), edge_type)

        # Add all nodes (with copies of their data) once, up front. The original
        # code additionally re-added a subset of these nodes on every iteration
        # of the innermost edge loop, which only repeated work already done here.
        G.add_nodes_from((n, d.copy()) for n, d in self.nodes.items())

        # NOTE(review): only edges whose target node has ``v[1] == 0`` are added
        # explicitly. Presumably the remaining edges are recreated as homologous
        # edges by ``add_edge`` on stationary graphs -- confirm this is also the
        # intended behaviour for non-stationary graphs.
        for edge_type, adj in self.adj.items():
            for u, nbrs in adj.items():
                for v, datadict in nbrs.items():
                    if v[1] == 0:
                        G.add_edge(u, v, edge_type, **datadict.copy())
        return G

    def add_edge(self, u_of_edge: TsNode, v_of_edge: TsNode, edge_type: str = "all", **attr):
        """Add an edge by delegating to the parent class ``add_edge``.

        Parameters
        ----------
        u_of_edge : TsNode
            The from node.
        v_of_edge : TsNode
            The to node.
        edge_type : str, optional
            The edge type passed to the parent implementation, by default 'all'.
        attr : keyword arguments, optional
            Edge attributes forwarded to the parent implementation.
        """
        super().add_edge(u_of_edge, v_of_edge, edge_type=edge_type, **attr)

    def add_edges_from(self, ebunch, edge_type="all", **attr):
        """Add edges by delegating to the parent class ``add_edges_from``.

        Parameters
        ----------
        ebunch : iterable
            The edges to add; forwarded unchanged to the parent implementation.
        edge_type : str, optional
            The edge type passed to the parent implementation, by default 'all'.
        attr : keyword arguments, optional
            Edge attributes forwarded to the parent implementation.
        """
        super().add_edges_from(ebunch, edge_type=edge_type, **attr)

    def remove_edge(self, u_of_edge, v_of_edge, edge_type="all"):
        """Remove an edge by delegating to the parent class ``remove_edge``.

        Parameters
        ----------
        u_of_edge : TsNode
            The from node.
        v_of_edge : TsNode
            The to node.
        edge_type : str, optional
            The edge type passed to the parent implementation, by default 'all'.
        """
        super().remove_edge(u_of_edge, v_of_edge, edge_type)  # type: ignore

    def remove_edges_from(self, ebunch, edge_type="all"):
        """Remove every edge in ``ebunch`` via :meth:`remove_edge`.

        Parameters
        ----------
        ebunch : iterable
            Iterable of edges; each edge is unpacked into the positional
            arguments of :meth:`remove_edge`.
        edge_type : str, optional
            The edge type passed to :meth:`remove_edge`, by default 'all'.
        """
        for edge in ebunch:
            self.remove_edge(*edge, edge_type)

    def add_homologous_edges(
        self, u_of_edge: TsNode, v_of_edge: TsNode, direction="both", edge_type="all", **attr
    ):
        """Add homologous edges.

        Assumes the edge that we consider is ``(u_of_edge, v_of_edge)``, that is 'u' points to 'v'.

        Parameters
        ----------
        u_of_edge : TsNode
            The from node.
        v_of_edge : TsNode
            The to node. The absolute value of the time lag should be less than or equal to
            the from node's time lag.
        direction : str, optional
            Which direction to add homologous edges to, by default 'both', corresponding
            to making the edge stationary over all time.
        edge_type : str, optional
            The internal graph to add the edges to. By default 'all', in which case
            the edges are added to every internal graph.
        attr : keyword arguments, optional
            Edge attributes forwarded to the internal graph(s).
        """
        if edge_type == "all":
            # iterate values only, so the ``edge_type`` parameter is not shadowed
            for graph in self.get_graphs().values():
                graph.add_homologous_edges(u_of_edge, v_of_edge, direction=direction, **attr)
        else:
            graph = self.get_graphs(edge_type=edge_type)
            graph.add_homologous_edges(u_of_edge, v_of_edge, direction=direction, **attr)

    def remove_homologous_edges(
        self, u_of_edge: TsNode, v_of_edge: TsNode, edge_type: str = "all", direction="both"
    ):
        """Remove homologous edges.

        Assumes the edge that we consider is ``(u_of_edge, v_of_edge)``, that is 'u' points to 'v'.

        Parameters
        ----------
        u_of_edge : TsNode
            The from node.
        v_of_edge : TsNode
            The to node. The absolute value of the time lag should be less than or equal to
            the from node's time lag.
        edge_type : str, optional
            The internal graph to remove the edges from. By default 'all', in which
            case the edges are removed from every internal graph.
        direction : str, optional
            Which direction to remove homologous edges from, by default 'both'.
        """
        if edge_type == "all":
            # iterate values only, so the ``edge_type`` parameter is not shadowed
            for graph in self.get_graphs().values():
                graph.remove_homologous_edges(u_of_edge, v_of_edge, direction=direction)
        else:
            graph = self.get_graphs(edge_type=edge_type)
            graph.remove_homologous_edges(u_of_edge, v_of_edge, direction=direction)
 
class StationaryTimeSeriesMixedEdgeGraph(TimeSeriesMixedEdgeGraph):
    """Mixed-edge causal graph whose time-series structure is stationary.

    Parameters
    ----------
    graphs : List of Graph | DiGraph
        The single-edge-type graphs making up this mixed-edge graph. Each one
        must be an instance of one of ``graph_types``.
    edge_types : List of str
        The name of each edge type.
    max_lag : int, optional
        The maximum lag, by default None.
    attr : keyword arguments, optional (default= no attributes)
        Graph attributes given as key=value pairs.
    """

    # this graph type is assumed to be stationary by default
    stationary: bool = True

    # sub-graph classes accepted by the constructor
    graph_types = (StationaryTimeSeriesGraph, StationaryTimeSeriesDiGraph)

    def __init__(self, graphs=None, edge_types=None, max_lag: Optional[int] = None, **attr):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(graphs=graphs, edge_types=edge_types, max_lag=max_lag, **attr)

    def set_stationarity(self, stationary: bool):
        """Set the ``stationary`` flag on this graph and on each internal graph.

        Parameters
        ----------
        stationary : bool
            The value assigned to ``stationary`` on this graph and on every
            graph returned by ``get_graphs()``.
        """
        self.stationary = stationary
        for sub_graph in self.get_graphs().values():
            sub_graph.stationary = stationary