首页 > 解决方案 > 我需要 MNIST 打印大量数字

问题描述

所以我有这个代码:

with torch.no_grad():
  X_test = mnist_test.test_data.view(-1, 28 * 28).float().to(device)
  Y_test = mnist_test.test_labels.to(device)
  prediction = linear(X_test)
  correct_prediction = torch.argmax(prediction, 1) == Y_test
  accuracy = correct_prediction.float().mean()
  print('Accuracy:', accuracy.item())
  r = 1504207845
  X_single_data = mnist_test.test_data[r:r + 1].view(-1, 28 * 28).float().to(device)
  Y_single_data = mnist_test.test_labels[r:r + 1].to(device)
  print('Label: ', Y_single_data.item())
  single_prediction = linear(X_single_data)
  print('Prediction: ', torch.argmax(single_prediction, 1).item())
  plt.imshow(mnist_test.test_data[r:r + 1].view(28, 28), cmap='Greys', interpolation='nearest')
  plt.show()

我收到此错误:ValueError: only one element tensors can be converted to Python scalars

如果r是一位数字,则有效。但是,如果r超过四位数字,它不会移动。我希望这个设备能识别更多的数字。很高兴能识别出我想要的位置的数字。

标签: pythonconv-neural-networktorchmnist

解决方案


推荐阅读