class CDT(GraphLearner):
    """Causal discovery using the Causal Discovery Toolbox (CDT).

    Link: https://github.com/FenTechSolutions/CausalDiscoveryToolbox
    """

    def __init__(self, data, full_method_name, *args, **kwargs):
        """Instantiate the requested CDT algorithm.

        :param data: Data handed to the ``GraphLearner`` base class; it is
            later passed to the algorithm's ``predict`` method.
        :param full_method_name: Name resolved by ``get_library_class_object``
            to the CDT algorithm class to instantiate.
        :param args: Extra positional arguments, forwarded both to the base
            class and to the algorithm's constructor.
        :param kwargs: Extra keyword arguments, forwarded the same way.
        """
        super().__init__(data, full_method_name, *args, **kwargs)
        library_class = get_library_class_object(full_method_name)
        self._method = library_class(*args, **kwargs)

    def learn_graph(self, labels=None):
        """Discover the causal graph and return it in DOT format.

        :param labels: Optional node labels. When given, they replace
            ``self._labels`` (presumably initialised by ``GraphLearner`` --
            TODO confirm) and are used to name the nodes of the output graph.
        :returns: The result of ``str_to_dot`` applied to the graphviz source
            built from the learned adjacency matrix.
        """
        graph = self._method.predict(self._data)

        # ``nx.to_numpy_matrix`` was removed in NetworkX 3.0. ``to_numpy_array``
        # (available since NetworkX 2.0) returns a plain ndarray directly, so no
        # extra ``np.asarray`` conversion is needed.
        self._adjacency_matrix = nx.to_numpy_array(graph)
        # NOTE(review): rows/columns follow ``graph.nodes()`` order; confirm it
        # matches the order of ``self._labels``.

        # Only override the stored labels when the caller supplied some.
        if labels is not None:
            self._labels = labels

        graph_viz = adjacency_matrix_to_graph(self._adjacency_matrix, self._labels)
        # Round-trip through the graphviz source to obtain a valid DOT object.
        self._graph_dot = str_to_dot(graph_viz.source)
        return self._graph_dot