python - Tensorflow 模型训练:缺少 1 个必需的位置参数:'self'
问题描述
我正在尝试按照训练你的第一个神经网络的示例练习神经网络分类器:基本分类,这是我的代码,直到模型训练点:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib.pyplot import show
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from matplotlib.pyplot import imshow
from matplotlib.pyplot import colorbar
from matplotlib.pyplot import axis
from matplotlib.pyplot import plot
from matplotlib.pyplot import show
print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
#figure(); imshow(train_images[1]); colorbar(); axis('auto')
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
N1, N2, N3 = test_images.shape
train_images = train_images / 255.0
test_images = test_images / 255.0
model = keras.Sequential
([
keras.layers.Flatten(input_shape=(N2, N3)),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5)
它返回错误
TypeError: _method_wrapper() missing 1 required positional argument: 'self'
这发生在
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
我google了一下,好像
m = model()
m.compile()
可以避免“自我”错误。然而,它得到了新的错误,训练仍然没有发生。
我只是想知道我应该如何修改代码,以便我可以让模型像这样训练:
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
解决方案
对您的代码进行了一些细微的修改。希望你能跟进。我没有Sequential()
在model
.
import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
from matplotlib.pyplot import show
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from matplotlib.pyplot import imshow
from matplotlib.pyplot import colorbar
from matplotlib.pyplot import axis
from matplotlib.pyplot import plot
from matplotlib.pyplot import show
print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
#figure(); imshow(train_images[1]); colorbar(); axis('auto')
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
N1, N2, N3 = test_images.shape
train_images = train_images / 255.0
test_images = test_images / 255.0
# model = Sequential
# ([
# keras.layers.Flatten(input_shape=(N2, N3)),
# keras.layers.Dense(128, activation=tf.nn.relu),
# keras.layers.Dense(10, activation=tf.nn.softmax)
# ])
model= Sequential()
model.add(Dense(128, activation=tf.nn.relu))
model.add(Dense(10, activation=tf.nn.softmax))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_images.reshape(len(train_images),784), train_labels, epochs=5)
使用此代码,它运行如下。
32/60000 [.......................] - ETA:3:02 - 损失:2.6468 - acc:0
1344/60000 [.......................] - ETA:6s - 损失:1.3037 - acc:0.5
2816/60000 [>.......................] - ETA:4s - 损失:1.0207 - acc:0.6
4256/60000 [=>.......................] - ETA:3s - 损失:0.9073 - acc:0.6
5632/60000 [=>.......................] - ETA:2s - 损失:0.8394 - acc:0.7
7104/60000 [==>.......................] - ETA:2s - 损失:0.7912 - acc:0.7
推荐阅读
- angular - 路由使用参数在父组件之前执行子组件
- swift - 获取图像后,uiimageview 不显示任何内容
- javascript - 如何为材料表 [dataSource] 正确格式化此数据?
- python - 与不同的模块共享一个功能
- mysql - 在 Mysql 8 中选择顺序表时出现语法错误
- php - phpmyadmin 中的错误 - `缺少 mysqli 扩展`
- serial-port - rs-232的通讯参数是由我们决定还是由设备决定?
- extjs - 根据其他组合框中的值填充组合框存储
- java - 在 Spring Boot 2.0.4 中从 http 重定向到 https
- javascript - 什么是打字稿编号范围