python - 我的自定义 keras 损失函数的问题
问题描述
我是 Tensorflow/Keras 的初学者,我想自定义我的损失函数,我的代码在这里:
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=input_shape),
tf.keras.layers.Conv2D(64, kernel_size=(3, 3),
activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Dropout(0.25),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(4, activation=tf.nn.softmax)
])
def lossFunction(y_true, y_pred):
y_true = tf.placeholder(shape=[330,28,28,3], dtype=tf.float32)
y_pred = tf.placeholder(shape=[1,28,28,3], dtype=tf.float32)
y_true = tf.convert_to_tensor(y_true, dtype=tf.float32)
y_pred = tf.convert_to_tensor(y_pred, dtype=tf.float32)
loss=(1/1568)*(K.abs(K.pow((y_true - y_pred),2)))
return loss
lossFunction(x_train,arrTest)
with tf.device('/gpu:0'):
model.compile(optimizer = 'adam',loss=lossFunction, metrics=['accuracy'])
model_log = model.fit(x_train, arrTest, batch_size=x_train.shape[0], epochs=x_train.shape[0],verbose=1)
score = model.evaluate(x_test, y_test, verbose=1)
我收到这些错误:
File "Classification.py", line 184, in <module>
model_log = model.fit(x_train, arrTest, batch_size=x_train.shape[0], epochs=x_train.shape[0],verbose=1)
File "/home/sabrinamehlal/.local/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.py", line 776, in fit
shuffle=shuffle)
File "/home/sabrinamehlal/.local/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.py", line 2436, in _standardize_user_data
training_utils.check_array_lengths(x, y, sample_weights)
File "/home/sabrinamehlal/.local/lib/python2.7/site-packages/tensorflow/python/keras/engine/training_utils.py", line 456, in check_array_lengths
'and ' + str(list(set_y)[0]) + ' target samples.')
ValueError: Input arrays should have the same number of samples as target arrays. Found 330 input samples and 1 target samples.
我的 x_train =[330,28,28,3] 的形状和我的 arrTest 的形状 = [1,28,28,3]
请问你能帮帮我吗?
解决方案
推荐阅读
- python - 相当于scitkit-learn的decision_function?
- javascript - 格式通过浏览器存储用户的语言
- javascript - Angular Highcharts - 点击后图表参考消失了
- android - 访问 Android MainActivity“布局”
- ios - 我正在尝试将 firebase 合并到我的统一项目中,但不断收到关于我的“Firebase.Crashlytics.Editor.dll”的错误
- javascript - 刽子手游戏,我的代码选择一个单词然后生成一个隐藏的单词长度,然后出于某种原因选择另一个单词?
- html - 中心标志与背景图像一起显示
- html - 协助构建线框部分
- composer-php - autoload.php中ComposerAutoloaderInit后的随机字符串有什么作用?
- javascript - NodeJS Observable 订阅不返回任何内容