Adjusting the width of edges in a python graphviz graph

Question

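I am trying to use the python interface to graphviz to visualize the transition probability matrix of a finite Markov chain. I want the states of the Markov chain to be the nodes of the graph, and I want the edges to have widths proportional to the conditional probability of transitioning between states. That is, I want thick edges for transitions with large weights and thin edges for transitions with small weights.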

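The discussion in (Directed, weighted graph from pandas dataframe) is similar to what I want, but it renders the transition probabilities as text labels rather than edge widths, which would result in a useless and hard-to-read plot.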

I am happy to consider tools other than graphviz for this task.

Here is the class I am trying to build:

import graphviz
import matplotlib.pyplot as plt
import numpy as np


class MarkovViz:
    """
    Visualize the transition probability matrix of a Markov chain as a directed
    graph, where the width of an edge is proportional to the transition
    probability between two states.
    """

    def __init__(self, transition_probability_matrix=None):
        self._graph = None
        if transition_probability_matrix is not None:
            self.build_from_matrix(transition_probability_matrix)

    def build_from_matrix(self, trans, labels=None):
        """
        Args:
          trans: A pd.DataFrame or 2D np.array.  A square matrix containing the
            conditional probability of a transition from the level
            represented by the row to the level represented by the column.
            Each row sums to 1.
          labels: A list-like sequence of labels to use for the rows and
            columns of 'trans'.  If trans is a pd.DataFrame or similar then
            this entry can be None and labels will be taken from the column
            names of 'trans'.

        Effects:
          self._graph is created as a directed graph, and populated with nodes
            and edges, with edge weights taken from 'trans'.
        """

        if labels is None and hasattr(trans, "columns"):
            labels = list(trans.columns)
            index = list(trans.index)
            if labels != index:
                raise Exception("Mismatch between index and columns of "
                                "the transition probability matrix.")
            trans = trans.values

        trans = np.array(trans)
        self._graph = graphviz.Digraph("MyGraph")

        dim = trans.shape[0]
        if trans.shape[1] != dim:
            raise Exception("Matrix must be square")

        for i in range(dim):
            for j in range(dim):
                if trans[i, j] > 0:
                    self._graph.edge(labels[i], labels[j], weight=trans[i, j])

    def plot(self, ax: plt.Axes):
        self._graph.view()

I would initialize an example object with a data frame that looks like this:

     foo  bar  baz
foo  0.5  0.5    0
bar  0.0  0.0    1
baz  1.0  0.0    0

I get the following error:

  File "<stdin>", line 1, in <module>
  File "/.../markov/markovviz.py", line 16, in __init__
    self.build_from_matrix(transition_probability_matrix)
  File "/.../markov/markovviz.py", line 53, in build_from_matrix
    self._graph.edge(labels[i], labels[j], weight=trans[i, j])
  File "/.../graphviz/dot.py", line 153, in edge
    attr_list = self._attr_list(label, attrs, _attributes)
  File "/.../graphviz/lang.py", line 139, in attr_list
    content = a_list(label, kwargs, attributes)
  File "/.../graphviz/lang.py", line 112, in a_list
    for k, v in tools.mapping_items(kwargs) if v is not None]
  File "/.../graphviz/lang.py", line 112, in <listcomp>
    for k, v in tools.mapping_items(kwargs) if v is not None]
  File ".../graphviz/lang.py", line 73, in quote
    if is_html_string(identifier) and not isinstance(identifier, NoHtml):
TypeError: cannot use a string pattern on a bytes-like object

This tells me that the only allowed attribute values for edges are strings or bytes. My questions:

Tags: python, graphviz, pygraphviz

Answer


Your problem stems from the following line:

    self._graph.edge(labels[i], labels[j], weight=trans[i, j])

The problem here is that DOT attributes can only be string values, and judging from the rest of the code, trans[i, j] probably returns a float.

The simplest solution is probably just to call str():

    self._graph.edge(labels[i], labels[j], weight=str(trans[i, j]))

Here is a test that reproduces both the problem and the fix:

>>> import graphviz
>>> g = graphviz.Digraph()
>>> g.edge('a', 'b', weight=1.5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/lars/.local/share/virtualenvs/python-LD_ZK5QN/lib/python3.9/site-packages/graphviz/dot.py", line 153, in edge
    attr_list = self._attr_list(label, attrs, _attributes)
  File "/home/lars/.local/share/virtualenvs/python-LD_ZK5QN/lib/python3.9/site-packages/graphviz/lang.py", line 139, in attr_list
    content = a_list(label, kwargs, attributes)
  File "/home/lars/.local/share/virtualenvs/python-LD_ZK5QN/lib/python3.9/site-packages/graphviz/lang.py", line 111, in a_list
    items = [f'{quote(k)}={quote(v)}'
  File "/home/lars/.local/share/virtualenvs/python-LD_ZK5QN/lib/python3.9/site-packages/graphviz/lang.py", line 111, in <listcomp>
    items = [f'{quote(k)}={quote(v)}'
  File "/home/lars/.local/share/virtualenvs/python-LD_ZK5QN/lib/python3.9/site-packages/graphviz/lang.py", line 73, in quote
    if is_html_string(identifier) and not isinstance(identifier, NoHtml):
TypeError: expected string or bytes-like object
>>> g.edge('a', 'b', weight=str(1.5))
>>> print(g)
digraph {
        a -> b [weight=1.5]
}
>>>
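One caveat about the goal in the question: in Graphviz, the weight attribute influences layout (how short and straight an edge is drawn), not line thickness. Stroke width is set by the penwidth attribute. The following is a minimal sketch of scaling penwidth by transition probability. The helper name add_weighted_edges and the max_penwidth parameter are made up for illustration, and trans is assumed to be a nested list or 2D NumPy array (for a DataFrame, pass trans.values).

```python
def add_weighted_edges(graph, labels, trans, max_penwidth=5.0):
    """Add one edge per nonzero transition, with line thickness scaled by
    the transition probability.

    `graph` is expected to be a graphviz.Digraph; any object with a
    compatible .edge(tail, head, **attrs) method works.
    `max_penwidth` (hypothetical parameter) is the width used for p == 1.
    """
    for i, src in enumerate(labels):
        for j, dst in enumerate(labels):
            p = float(trans[i][j])
            if p > 0:
                # Attribute values must be strings, so format explicitly.
                width = f"{p * max_penwidth:.2f}"
                graph.edge(str(src), str(dst), penwidth=width)
```

With the example matrix from the question, calling add_weighted_edges(graphviz.Digraph(), ["foo", "bar", "baz"], matrix) should produce four edges, with the probability-1 edges drawn twice as thick as the probability-0.5 edges.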

"How do I plot the graph once weights have been attached to the edges?"

Take a look at the render and view methods:

>>> help(g.render)
render(filename=None, directory=None, view=False, cleanup=False, format=None, renderer=None, formatter=None, quiet=False, quiet_view=False) method of graphviz.dot.Digraph instance
    Save the source to file and render with the Graphviz engine.
[...]
>>> help(g.view)
view(filename=None, directory=None, cleanup=False, quiet=False, quiet_view=False) method of graphviz.dot.Digraph instance
    Save the source to file, open the rendered result in a viewer.
[...]
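As a rough sketch of how render might be used to write an image file without opening a viewer (the wrapper name render_to_file and its defaults are my own, not part of the library; the keyword arguments are the ones listed in the help output above):

```python
def render_to_file(graph, basename, fmt="png"):
    """Render `graph` to an image file and return the path reported by render().

    `graph` is expected to be a graphviz.Digraph. Rendering shells out to
    the Graphviz `dot` executable, which must be installed separately from
    the Python package.
    """
    # cleanup=True deletes the intermediate DOT source file after rendering;
    # view=False writes the output without opening a viewer.
    return graph.render(filename=basename, format=fmt, cleanup=True, view=False)
```

Calling graph.view() instead does the same rendering step and then opens the result in the default viewer for that file type.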
