首页 > 解决方案 > Pytorch 中的缓冲区是什么?

问题描述

我了解register_buffer的作用以及register_buffer 和 register_parameters之间的区别。

但是 PyTorch 中缓冲区的精确定义是什么?

标签: pythonpytorch

解决方案


这可以通过查看实现来回答:

def register_buffer(self, name, tensor):
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor

也就是说,缓冲区的名称:

  • 必须是一个字符串:not isinstance(name, torch._six.string_classes)
  • 不能包含.(点):'.' in name
  • 不能为空字符串:name == ''
  • 不能是模块的属性:hasattr(self, name)
  • 应该是唯一的:name not in self._buffers

tensor(你猜怎么着?):

  • 应该是张量:isinstance(tensor, torch.Tensor)

因此,缓冲区只是具有这些属性的张量,注册在_buffersa 的属性中Module


推荐阅读