首页 > 解决方案 > Chainer 和 float64 支持

问题描述

我想知道如何让chainer 支持float64 精度计算。下面是我在chainer 中编写的简单自动编码器代码。最初我使用 float32 dtype 运行它,但现在我想测试 float64 精度我该怎么做?

class AutoEncoder(chainer.Chain):
    def __init__(self, num_input=1024, num_hidden=512, activate=F.leaky_relu, dropout_ratio=0.0):
        super(AutoEncoder, self).__init__(
            encoder=L.Linear(num_input, num_hidden),
            decoder=L.Linear(num_hidden, num_input))
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.activate = activate
        self.dropout_ratio = dropout_ratio

    def __call__(self, x, hidden=False):
        h = F.dropout(self.activate(self.encoder(x)), ratio=self.dropout_ratio)
        if hidden:
            return h

        y = F.dropout((self.decoder(h)), ratio=self.dropout_ratio)
        return y

当我尝试使用 float64 数据运行它时,它会引发 dtype 不匹配错误。

标签: pythonchainer

解决方案


推荐阅读