python - Pytorch 错误:输入应为标量类型 Long 但发现 Float
问题描述
我正在尝试创建一个深度学习算法来玩蛇。我正在尝试使用 PyTorch 来实现这一点。这是我的(凌乱的,稍后会修复)代码片段:
## DOUBLE Q DEEP LEARNING NETWORK
class SnakeNet(nn.Module):
"""mini cnn structure
input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.online = nn.Sequential(
# nn.Conv2d(in_channels=input_dim, out_channels=32, kernel_size=8, stride=4),
# nn.ReLU(),
# nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
# nn.ReLU(),
# nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
# nn.ReLU(),
# nn.Flatten(),
# nn.Linear(3136, 512),
# nn.ReLU(),
# nn.Linear(512, output_dim),
nn.Linear(input_dim, 200),
nn.Linear(200, 20),
nn.Linear(20, 50),
nn.Linear(50, output_dim),
)
self.target = copy.deepcopy(self.online)
# Q_target parameters are frozen.
for p in self.target.parameters():
p.requires_grad = False
def forward(self, input, model):
input = input.long()
if model == "online":
return self.online(input)
elif model == "target":
return self.target(input)
# EXPLOIT
else:
state = torch.tensor(state)
state = state.unsqueeze(0)
action_values = self.net(state, model="online")
dir = torch.argmax(action_values, axis=1).item()
我在第 221 行收到错误:action_values = self.net(state, model="online")
声明我的输入(状态)是一个浮点数,尽管它是一个 tensorLong,我通过打印 type() 证明了这一点。在建议添加state = state.type.tensorLong()
这个之前没有用,主要是因为它已经很长了。
错误:
Traceback (most recent call last):
File "snakeGame.py", line 324, in <module>
prev_location, action = snake.act(current_state)
File "snakeGame.py", line 222, in act
action_values = self.net(state, model="online")
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "snakeGame.py", line 57, in forward
return self.online(input)
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
return F.linear(input, self.weight, self.bias)
File "/Users/gavinhartog/opt/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Long but found Float
这是在 torch.tensor 之前的状态的原始内容和形状:
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], etc, etc, etc
我尝试了不同的东西,比如 Conv2d 和不同的损失函数,都是同样的错误。提前致谢。
解决方案
该错误有点令人困惑。但我想你有类型state
转换 tofloat
而不是 tolong
state = state.float()
因为nn.Linear
总是需要浮点数。
推荐阅读
- java - Zip4j 允许用户通过 7zip 更新加密的 zip
- java - Spring-Cloud-Stream-Kafka 自定义健康检查未提供 Kafka 状态
- symfony - Symfony Doctrine - 防止 slug 为空
- java - 如何在 FileInputStream 中加载外部图像
- node.js - IBM Watson 双字节字符串转换
- dart - 双击/单击颤动中的onBack按钮时,我需要关闭我的应用程序
- authentication - 第三方 API 访问的 OAuth 流程
- python - Django 数据库连接库问题
- highcharts - 3D Highcharts 中的数据标签错位
- sql - 正则表达式