Source code for dowhy.utils.graphviz_plotting
import os
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Union
import networkx as nx
import numpy as np
import pygraphviz
def plot_causal_graph_graphviz(
    causal_graph: nx.Graph,
    layout_prog: Optional[str] = None,
    display_causal_strengths: bool = True,
    causal_strengths: Optional[Dict[Tuple[Any, Any], float]] = None,
    colors: Optional[Dict[Union[Any, Tuple[Any, Any]], str]] = None,
    filename: Optional[str] = None,
    display_plot: bool = True,
    figure_size: Optional[Tuple[int, int]] = None,
) -> None:
    """Plot a causal graph with graphviz (through pygraphviz).

    Every edge gets a strength: the value given in ``causal_strengths`` if the
    edge is listed there, otherwise the edge's ``CAUSAL_STRENGTH`` attribute in
    ``causal_graph`` (``None`` if the attribute is absent). Edges whose strength
    is not ``None`` are drawn with a pen width obtained from
    ``dowhy.utils.plotting._calc_arrow_width`` and can be labelled with the
    strength truncated to two decimals. Infinite strengths are labelled
    ``Inf``/``-Inf`` and drawn as if their strength were 10000.

    The ``causal_strengths`` and ``colors`` arguments are deep-copied, so the
    caller's dictionaries are never modified.

    :param causal_graph: Graph to plot. It is rendered as a directed graph if
        it is an instance of ``nx.DiGraph``, otherwise as an undirected graph.
    :param layout_prog: Name of the graphviz layout program. Defaults to
        ``"dot"`` when ``None``.
    :param display_causal_strengths: If True, edges that have a strength are
        labelled with it. If False, no edge label is set.
    :param causal_strengths: Optional mapping ``(source, target) -> strength``
        that takes precedence over the ``CAUSAL_STRENGTH`` edge attribute.
    :param colors: Optional mapping from a node, or from a ``(source, target)``
        edge tuple, to a color. Edges without an entry are drawn in black;
        nodes without an entry are added without color attributes.
    :param filename: If given, the drawing is written to this path. A path
        without a file extension gets ``.pdf`` appended.
    :param display_plot: If True, the laid-out graph is handed to
        ``dowhy.utils.plotting._plot_as_pyplot_figure``.
    :param figure_size: Forwarded unchanged to ``_plot_as_pyplot_figure``.
    """
    # Work on private copies because both dictionaries are filled in below.
    causal_strengths = {} if causal_strengths is None else deepcopy(causal_strengths)
    colors = {} if colors is None else deepcopy(colors)

    if layout_prog is None:
        layout_prog = "dot"

    # Complete the strength/color tables for every edge of the graph and track
    # the largest absolute strength, which is later passed to _calc_arrow_width.
    max_strength = 0.0
    for source, target, strength in causal_graph.edges(data="CAUSAL_STRENGTH", default=None):
        edge = (source, target)
        if edge not in causal_strengths:
            causal_strengths[edge] = strength
        if causal_strengths[edge] is not None:
            max_strength = max(max_strength, abs(causal_strengths[edge]))
        if edge not in colors:
            colors[edge] = "black"

    pygraphviz_graph = pygraphviz.AGraph(directed=isinstance(causal_graph, nx.DiGraph))

    for node in causal_graph.nodes:
        if node in colors:
            pygraphviz_graph.add_node(node, color=colors[node], fontcolor=colors[node])
        else:
            pygraphviz_graph.add_node(node)

    for source, target in causal_graph.edges():
        causal_strength = causal_strengths[(source, target)]
        edge_attributes = {"color": colors[(source, target)]}

        if causal_strength is not None:
            if np.isinf(causal_strength):
                # Keep the sign in the label; the original labelled -inf as "Inf".
                tmp_label = "-Inf" if causal_strength < 0 else "Inf"
                # Finite stand-in so a pen width can still be computed.
                causal_strength = 10000
            else:
                # Truncate (not round) to two decimals, with a leading space.
                tmp_label = f" {int(causal_strength * 100) / 100}"

            # Imported lazily at the same point as before, so the module is only
            # loaded when an edge actually carries a strength.
            from dowhy.utils.plotting import _calc_arrow_width

            edge_attributes["penwidth"] = str(_calc_arrow_width(causal_strength, max_strength))
            if display_causal_strengths:
                # Only set the label when it should be shown. pygraphviz turns
                # non-string attribute values into strings, so passing
                # label=None may render the literal text "None" on the edge.
                edge_attributes["label"] = tmp_label

        pygraphviz_graph.add_edge(str(source), str(target), **edge_attributes)

    pygraphviz_graph.layout(prog=layout_prog)

    if filename is not None:
        filename, file_extension = os.path.splitext(filename)
        if file_extension == "":
            file_extension = ".pdf"
        pygraphviz_graph.draw(filename + file_extension)

    if display_plot:
        from dowhy.utils.plotting import _plot_as_pyplot_figure

        _plot_as_pyplot_figure(pygraphviz_graph, figure_size)