首页 > 解决方案 > 将嵌套字典重构为 CSV 文件中的三角矩阵

问题描述

我有以下 CSV 文件:

A,B,0.5
A,C,0.4
A,D,0.2
B,C,0.7
B,D,0.6
C,D,0.9

任务是将文件转换为矩阵格式,如下所示(为了便于阅读,我插入了空格):

 ,A  ,B  ,C  ,D
B,0.5,   ,   ,
C,0.4,0.7,   ,
D,0.2,0.6,0.9,

矩阵可以是上三角形或下三角形。我知道使用 Pandas 有一种超级简单的方法,但我希望这个程序只依赖 Python 标准库(因为我无法控制的环境问题)。

这是我到目前为止开发的代码:

from csv import reader
from collections import defaultdict

# reading of CLI arguments to args.path omitted for brevity

# Part 1
with open(args.path, 'r') as infile:
    matrix = defaultdict(dict)
    for line in reader(infile):
        matrix[line[0]][line[1]] = float(line[2])

# Part 2
cols = list(matrix)
rows = set()
for col in cols:
    for j in matrix[col]:
        rows.add(j)
rows = sorted(list(rows), key=(lambda x: len(matrix[x].keys())), reverse=True)

# Part 3
print(',' + ','.join(cols))
for row in rows:
    curr_row = []
    for col in cols:
        if row in matrix[col]:
            curr_row.append(str(matrix[col][row]))
        else:
            curr_row.append('')
    curr_row = ','.join(curr_row)
    print(f'{row},{curr_row}')

它吐出以下输出:

,A,B,C
B,0.5,,
C,0.4,0.7,
D,0.2,0.6,0.9

虽然它做了我想要的,但我认为# Part 2它是一团糟,# Part 3可以改进。我想重构代码,使其更具可读性和性能(如果可能的话)。你知道一些我可以用来改进这段代码的 Python 技巧吗?请记住,我需要一个仅使用标准库的解决方案。

更好的是具有以下格式的输出,但这不是本文的目标。尽管如此,建议还是非常受欢迎的。

[   A    B    C    D   ]
[B] 0.5
[C] 0.4 0.7
[D] 0.2 0.6 0.9

标签: pythondictionaryexport-to-csv

解决方案


我想出了一个更具可读性的解决方案,尽管在性能方面不是最好的(它需要 O(n^2) 的时间和空间......)。我认为这足以满足我的需求,但如果有人知道更好的解决方案,请分享。

from csv import reader
from collections import defaultdict

# reading of CLI arguments to args.path omitted for brevity

with open(args.path, 'r') as infile:
    matrix = defaultdict(dict)
    for line in reader(infile):
        matrix[line[0]][line[1]] = line[2]
        matrix[line[1]][line[0]] = line[2]

labels = sorted(matrix)
print(',' + ','.join(labels))
for row in labels:
    curr_row = []
    repeated = False
    for col in labels:
        if row in matrix[col] and not repeated:
            curr_row.append(matrix[col][row])
        else:
            curr_row.append('')
            repeated = True
    curr_row = ','.join(curr_row)
    print(row, curr_row, sep=',')

推荐阅读