首页 > 解决方案 > 无法正确修改 Python 类中的字典

问题描述

我在尝试修改类中的字典列表时遇到了一个奇怪的问题。这是显示该行为的最小可重现代码:

from itertools import product

class Test():
    def __init__(self, grid):
        self.grid = grid
        self.pc =  [dict(zip(self.grid, v)) for v in product(*self.grid.values())]
        for i in range(0, len(self.pc)):
            self.pc[i]['sim_options']['id'] = i

grid = {
    'k': [5, 10, 15, 20],
    'sim_options': [
        {'name': 'cosine', 'batched': True},
        {'name': 'pearson', 'batched': True}
    ]
}
t = Test(grid)

我期望的输出如下:

[{'k': 5, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 0}},
 {'k': 5, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 1}},
 {'k': 10, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 2}},
 {'k': 10, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 3}},
 {'k': 15, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 4}},
 {'k': 15, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 5}},
 {'k': 20, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 6}},
 {'k': 20, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 7}}]

但我得到:

[{'k': 5, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 6}},
 {'k': 5, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 7}},
 {'k': 10, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 6}},
 {'k': 10, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 7}},
 {'k': 15, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 6}},
 {'k': 15, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 7}},
 {'k': 20, 'sim_options': {'name': 'cosine', 'batched': True, 'id': 6}},
 {'k': 20, 'sim_options': {'name': 'pearson', 'batched': True, 'id': 7}}]

我不明白我做错了什么,我不是在迭代一个列表,访问'sim_options'第 i 个列表的字段并在该字典中创建一个新的键值对 ('id':i) 吗?

标签: pythonpython-3.xclassdictionary

解决方案


似乎字典正在通过引用进行更新,所以无论你对一个做什么,都会发生在另一个身上。

您可以使用以下id()功能进行检查:

for i in range(0, len(self.pc)):
    print(id(self.pc[i]['sim_options']))
    self.pc[i]['sim_options']['id'] = i

这给了我相同的重复参考身份(20789353511042078964975104)两个字典'sim_options'

2078935351104
2078964975104
2078935351104
2078964975104
2078935351104
2078964975104
2078935351104
2078964975104

解决此问题的一种方法是复制选项,这将为您提供不同的参考身份。我已经稍微修改了您的代码,以便使用copy(). 它还用于enumerate()循环索引和项目,这比range(len(...))同时需要时更好用。

from itertools import product
from pprint import pprint

class Test:
    def __init__(self, grid):
        self.grid = grid

        self.pc = []
        for i, (k, options) in enumerate(product(self.grid["k"], self.grid["sim_options"])):
            temp = {"k": k, "sim_options": options.copy()}
            temp["sim_options"]["id"] = i
            self.pc.append(temp)

    def get(self):
        return self.pc

grid = {
    "k": [5, 10, 15, 20],
    "sim_options": [
        {"name": "cosine", "batched": True},
        {"name": "pearson", "batched": True},
    ],
}

t = Test(grid)

pprint(t.get())

或者只是使用列表理解构建一个新的字典列表。我通常觉得这种方式更可取。

self.pc = [
    {"k": k, "sim_options": {**option, "id": idx}}
    for idx, (k, option) in enumerate(
        product(self.grid["k"], self.grid["sim_options"])
    )
]

输出:

[{'k': 5, 'sim_options': {'batched': True, 'id': 0, 'name': 'cosine'}},
 {'k': 5, 'sim_options': {'batched': True, 'id': 1, 'name': 'pearson'}},
 {'k': 10, 'sim_options': {'batched': True, 'id': 2, 'name': 'cosine'}},
 {'k': 10, 'sim_options': {'batched': True, 'id': 3, 'name': 'pearson'}},
 {'k': 15, 'sim_options': {'batched': True, 'id': 4, 'name': 'cosine'}},
 {'k': 15, 'sim_options': {'batched': True, 'id': 5, 'name': 'pearson'}},
 {'k': 20, 'sim_options': {'batched': True, 'id': 6, 'name': 'cosine'}},
 {'k': 20, 'sim_options': {'batched': True, 'id': 7, 'name': 'pearson'}}]

推荐阅读