Serialization of a pytorch state_dict changes after loading into a new model instance

Problem description

Why do the bytes obtained from serializing a pytorch state_dict change after loading that state_dict into a new instance of the same model architecture?

Take a look:

import binascii
import torch.nn as nn
import pickle

lin1 = nn.Linear(1, 1, bias=False)
lin1s = pickle.dumps(lin1.state_dict())
print("--- original model ---")
print(f"hash of state dict: {hex(binascii.crc32(lin1s))}")
print(f"weight: {lin1.state_dict()['weight'].item()}")

lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(pickle.loads(lin1s))
lin2s = pickle.dumps(lin2.state_dict())
print("\n--- model from deserialized state dict ---")
print(f"hash of state dict: {hex(binascii.crc32(lin2s))}")
print(f"weight: {lin2.state_dict()['weight'].item()}")

This prints:

--- original model ---
hash of state dict: 0x4806e6b6
weight: -0.30337071418762207

--- model from deserialized state dict ---
hash of state dict: 0xe2881422
weight: -0.30337071418762207

As you can see, the hashes of (the pickles of) the state_dicts differ, while the weight is copied over correctly. I would have assumed that the state_dict of the new model is equal to the old one in every respect. Apparently it is not, hence the different hashes.

Tags: python, pytorch, pickle

Solution


This is probably because pickle does not produce a representation that is suitable for hashing (see Using pickle.dumps to hash mutable objects). A better idea is to compare the keys of the two state dicts and then check that the tensors stored under those keys are equal/close.

Below is a rough implementation of that idea.

import torch

def compare_state_dict(dict1, dict2):
    # the two state dicts must contain exactly the same keys
    for key in dict1:
        if key not in dict2:
            return False

    for key in dict2:
        if key not in dict1:
            return False

    # every tensor must be element-wise close to its counterpart
    for (k, v) in dict1.items():
        if not torch.all(torch.isclose(v, dict2[k])):
            return False

    return True
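
For example, applied to the two models from the question (lin1 and lin2 as created and loaded above), this should return True, since the weight was copied over exactly:

print(compare_state_dict(lin1.state_dict(), lin2.state_dict()))  # True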

However, if you still want to hash the state dict and avoid the isclose comparison above, you can use the function below.

def dict_hash(dictionary):
    # tensors hash by object identity, so hash their contents instead;
    # building a new dict also avoids mutating the original state_dict
    hashed = {k: hash(v.detach().cpu().numpy().tobytes()) for (k, v) in dictionary.items()}

    # dictionaries are not hashable and need to be converted to a frozenset
    return hash(frozenset(hashed.items()))
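
As a rough usage sketch (again reusing lin1 and lin2 from the question): because this hash depends only on the keys and the tensor contents rather than on the pickled bytes, the two state dicts should now hash to the same value:

print(dict_hash(lin1.state_dict()) == dict_hash(lin2.state_dict()))  # True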
