首页 > 解决方案 > 用 CSV 文件读写 NumPy 数组字典

问题描述

我正在使用 Python 2.7 和 networkx 来制作我的网络的弹簧布局图。为了比较不同的设置,我想将 networkx 计算和使用的位置存储到一个文件中(我目前的选择:csv)并在每次制作新图时读取它。听起来很简单,我的代码如下所示:

pos_spring = nx.spring_layout(H, pos=fixed_positions, fixed = fixed_nodes, k = 4, weight='passengers')

这条线计算稍后用于绘图的位置,我想存储。字典 (pos_spring) 如下所示:

{1536: array([ 0.53892015,  0.984306  ]), 
1025: array([ 0.12096853,  0.82587976]), 
1030: array([ 0.20388712,  0.7046137 ]),

写入文件:

w = csv.writer(open("Mexico_spring_layout_positions_2.csv", "w"))
for key, val in pos_spring.items():
    w.writerow([key, val])

该文件的内容如下所示:

1536,[ 0.51060853  0.80129841]
1025,[ 0.47442269  0.99838177]
1030,[ 0.02952256  0.45073233]

读取文件:

with open('Mexico_spring_layout_positions_2.csv', mode='r') as infile:
    reader = csv.reader(infile)
    pos_spring = dict((rows[0],rows[1]) for rows in reader)

pos_spring 的内容现在看起来像这样:

{'2652': '[ 0.78480322  0.103894  ]', 
'1260': '[ 0.8834103   0.82542163]', 
'2969': '[ 0.33044548  0.31282113]',

不知何故,这些数据看起来与存储在 csv 文件中的原始字典不同。写入和/或读取数据以解决此问题时需要更改哪些内容?提前致谢。

亲切的问候,弗兰克

标签: pythonpython-2.7csvnumpydictionary

解决方案


您不能将 NumPy 数组存储在 CSV 文件中并维护数据类型。请记住,CSV 文件只能存储文本。您看到的是 NumPy 数组的文本表示。

相反,您可以在写入 csv 文件时解压缩 NumPy 数组:

import csv

d = {1536: np.array([ 0.53892015,  0.984306  ]), 
     1025: np.array([ 0.12096853,  0.82587976]), 
     1030: np.array([ 0.20388712,  0.7046137 ])}

fp = r'C:\temp\out.csv'

with open(fp, 'w', newline='') as fout:
    w = csv.writer(fout)
    for key, val in d.items():
        w.writerow([key, *val])

然后在您回读时转换回 NumPy。对于此步骤,您可以使用字典理解:

with open(fp, 'r') as fin:
    r = csv.reader(fin)
    res = {int(k): np.array(list(map(float, v))) for k, *v in r}

print(res)

{1536: array([ 0.53892015,  0.984306  ]),
 1025: array([ 0.12096853,  0.82587976]),
 1030: array([ 0.20388712,  0.7046137 ])}

推荐阅读