Networkx 自带的绘图工具很适合做探索性数据分析,但出于种种原因(这里不展开讨论),它并不适合用来制作达到出版质量的图表。因此,我从头重写了代码库中的这一部分,做成了一个名为 netgraph 的独立绘图模块(和原来的绘图模块一样,完全基于 matplotlib),可以在此处找到。它的 API 与原模块非常相似,而且文档齐全,因此按您的需求进行改造应该不会太难。
在此基础上,我得到以下结果:
我选择用颜色来表示边的强度,因为这样可以:
1)表示负值;
2)更好地区分较小的值。
不过,您也可以改为向 netgraph 传入边的宽度(参见 netgraph.draw_edges())。
分支顺序之所以不同,是由您的数据结构(字典)造成的:它并不表示任何固有的顺序。要解决这个问题,您需要修改数据结构以及下面的 _parse_input() 函数。
代码:
import itertools
import string

import matplotlib.pyplot as plt
import netgraph  # NOTE: the former `reload(netgraph)` was Python-2-only (NameError on Python 3)
import numpy as np


def plot_layered_network(weight_matrices,
                         distance_between_layers=2,
                         distance_between_nodes=1,
                         layer_labels=None,
                         **kwargs):
    """
    Convenience function to plot a layered network.

    Arguments:
    ----------
    weight_matrices: [w1, w2, ..., wn]
        list of weight matrices defining the connectivity between layers;
        each weight matrix is a 2-D ndarray with rows indexing sources and
        columns indexing targets; the number of sources of each matrix has
        to match the number of targets of the preceding matrix

    distance_between_layers: int
        spacing between consecutive layers (layers are laid out along the
        horizontal axis)

    distance_between_nodes: int
        spacing between consecutive nodes within one layer

    layer_labels: [str1, str2, ..., strn+1]
        labels of layers, shown as x-tick labels; one label per layer,
        i.e. len(weight_matrices) + 1 labels

    **kwargs:
        passed to netgraph.draw()

    Returns:
    --------
    ax: matplotlib axis instance (as returned by netgraph.draw())

    Raises:
    -------
    ValueError: if weight_matrices is empty or consecutive matrices do not
        chain (sources of matrix k != targets of matrix k-1)
    """
    nodes_per_layer = _get_nodes_per_layer(weight_matrices)
    node_positions = _get_node_positions(nodes_per_layer,
                                         distance_between_layers,
                                         distance_between_nodes)
    w = _combine_weight_matrices(weight_matrices, nodes_per_layer)

    ax = netgraph.draw(w, node_positions, **kwargs)

    if layer_labels is not None:
        # one tick per layer, placed at that layer's horizontal coordinate
        ax.set_xticks(distance_between_layers * np.arange(len(nodes_per_layer)))
        ax.set_xticklabels(layer_labels)
        ax.xaxis.set_ticks_position('bottom')

    return ax


def _get_nodes_per_layer(weight_matrices):
    """Return the node count of each of the len(weight_matrices) + 1 layers.

    The first layer has as many nodes as the first matrix has rows; every
    following layer has as many nodes as the corresponding matrix has columns.

    Raises ValueError if no matrices are given or if the row count of a matrix
    does not match the column count of the previous one.
    """
    if len(weight_matrices) == 0:
        raise ValueError("weight_matrices must contain at least one matrix")

    nodes_per_layer = [weight_matrices[0].shape[0]]
    for ii, w in enumerate(weight_matrices):
        sources, targets = w.shape
        if sources != nodes_per_layer[-1]:
            raise ValueError(
                f"weight matrix {ii} has {sources} sources, but the previous "
                f"layer has {nodes_per_layer[-1]} nodes")
        nodes_per_layer.append(targets)
    return nodes_per_layer


def _get_node_positions(nodes_per_layer, distance_between_layers, distance_between_nodes):
    """Return an (N, 2) array of node positions, ordered layer by layer.

    Column 0 is the layer coordinate (layer index * distance_between_layers),
    column 1 is the within-layer coordinate (node index * distance_between_nodes).
    """
    layer_coords = []
    node_coords = []
    for layer_idx, n in enumerate(nodes_per_layer):
        layer_coords.append(layer_idx * distance_between_layers * np.ones(n))
        node_coords.append(distance_between_nodes * np.arange(0., n))
    return np.c_[np.concatenate(layer_coords), np.concatenate(node_coords)]


def _combine_weight_matrices(weight_matrices, nodes_per_layer):
    """Embed the per-layer matrices into one (N, N) matrix over all nodes.

    Matrix k is placed at rows of layer k and columns of layer k+1; every
    other entry is NaN (no edge).
    """
    total_nodes = sum(nodes_per_layer)
    w = np.full((total_nodes, total_nodes), np.nan, float)  # np.float was removed in NumPy 1.24

    row_start = 0                   # first node index of the source layer
    col_start = nodes_per_layer[0]  # first node index of the target layer
    for ii, ww in enumerate(weight_matrices):
        w[row_start:row_start + ww.shape[0], col_start:col_start + ww.shape[1]] = ww
        row_start += nodes_per_layer[ii]
        col_start += nodes_per_layer[ii + 1]
    return w


def test():
    """Plot a random 4-5-6-3 layered network with letter node labels."""
    w1 = np.random.rand(4, 5)
    w2 = np.random.rand(5, 6)
    w3 = np.random.rand(6, 3)

    node_labels = dict(zip(range(18), string.ascii_lowercase))

    fig, ax = plt.subplots(1, 1)
    plot_layered_network([w1, w2, w3],
                         layer_labels=['start', 'step 1', 'step 2', 'finish'],
                         ax=ax,
                         node_size=20,
                         node_edge_width=2,
                         node_labels=node_labels,
                         edge_width=5,
                         )
    plt.show()


def test_example(input_dict):
    """Plot the layered network described by the nested dict `input_dict`
    (format documented in `_parse_input`)."""
    weight_matrices, node_labels = _parse_input(input_dict)

    # one label per layer: the unlabelled input layer, then '1', '2', ...
    layer_labels = [''] + [str(ii + 1) for ii in range(len(input_dict))]

    fig, ax = plt.subplots(1, 1)
    plot_layered_network(weight_matrices,
                         layer_labels=layer_labels,
                         distance_between_layers=10,
                         distance_between_nodes=8,
                         ax=ax,
                         node_size=300,
                         node_edge_width=10,
                         node_labels=node_labels,
                         edge_width=50,
                         )
    plt.show()


def _parse_input(input_dict):
    """Convert the nested input dict into weight matrices and node labels.

    `input_dict[k][target][source]` is the weight of the edge from `source`
    (a node of layer k-1) to `target` (a node of layer k). The keys of
    `input_dict` are expected to be 1, 2, ..., len(input_dict). Edges missing
    from the dict are encoded as NaN.

    Returns
    -------
    weight_matrices : list of 2-D ndarrays, rows = sources, columns = targets
    node_labels : dict mapping global node index -> label, layer by layer
    """
    weight_matrices = []
    layers = []

    # Initial sources in first-seen order. The previous `set` gave an
    # arbitrary, run-dependent order (str hashing is randomized), which
    # scrambled the order of the branches in the plot.
    sources = list(dict.fromkeys(
        s for incoming in input_dict[1].values() for s in incoming))

    for layer in range(1, len(input_dict) + 1):
        inner_dict = input_dict[layer]
        targets = list(inner_dict)

        w = np.full((len(sources), len(targets)), np.nan, float)
        for jj, t in enumerate(targets):
            incoming = inner_dict[t]
            for ii, s in enumerate(sources):
                w[ii, jj] = incoming.get(s, np.nan)  # NaN = no edge

        weight_matrices.append(w)
        layers.append(sources)
        sources = targets

    layers.append(sources)  # labels of the last layer

    node_labels = dict(enumerate(itertools.chain.from_iterable(layers)))
    return weight_matrices, node_labels


# --------------------------------------------------------------------------------
# script
# --------------------------------------------------------------------------------

if __name__ == "__main__":

    # test()  # random-weight demo

    input_dict = {
        1: {
            "Group 1": {"sample_0": 0.5, "sample_1": 0.5, "sample_2": 0, "sample_3": 0, "sample_4": 0},
            "Group 2": {"sample_0": 0, "sample_1": 0, "sample_2": 1, "sample_3": 0, "sample_4": 0},
            "Group 3": {"sample_0": 0, "sample_1": 0, "sample_2": 0, "sample_3": 0.5, "sample_4": 0.5},
        },
        2: {
            "Group 1": {"Group 1": 1, "Group 2": 0, "Group 3": 0},
            "Group 2": {"Group 1": 0, "Group 2": 1, "Group 3": 0},
            "Group 3": {"Group 1": 0, "Group 2": 0, "Group 3": 1},
        },
        3: {
            "Group 1": {"Group 1": 0.25, "Group 2": 0, "Group 3": 0.75},
            "Group 2": {"Group 1": 0.25, "Group 2": 0.75, "Group 3": 0},
        },
        4: {
            "Group 1": {"Group 1": 1, "Group 2": 0},
            "Group 2": {"Group 1": 0.25, "Group 2": 0.75},
        },
    }

    test_example(input_dict)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)