keras - ValueError in model.fit() with validation_data: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Problem description
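I'm trying to run a simple autoencoder model. I'm reading the training data from a CSV file that contains word embeddings. I have the code below, but the error in the title is raised in the model.fit() function and is related to my validation data. I have tried many things, but the error persists. I'm new to NLP, and maybe my logic is completely wrong, I don't know. If anyone can help, I would be grateful. Here is my code: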
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import Model

def train_predict(df):
    X_train, X_validation = train_test_split(df, test_size=0.3, random_state=42, shuffle=True)
    X = X_train.iloc[:, :-1].to_numpy()  # shape is (1880, 220) here
    X = tf.expand_dims(X, axis=-1)  # shape is (1880, 220, 1)
    X_val = X_validation.iloc[:, :-1].to_numpy()  # shape is (300, 220)
    X_val = tf.expand_dims(X_val, axis=-1)  # shape is (300, 220, 1)

    inputs, decoder_output, visualization = autoEncoder(X)
    model = Model(inputs=inputs, outputs=decoder_output)
    encoder_model = Model(inputs=inputs, outputs=visualization)

    batch_size = 128
    train_steps = len(X) // batch_size
    val_steps = len(X_val) // batch_size

    model.summary()
    model.compile(optimizer='adam', metrics=['accuracy'], loss='mean_squared_error')
    model.fit(X, steps_per_epoch=train_steps, validation_data=X_val, validation_steps=val_steps, epochs=100)
    result = model.evaluate(X_val, steps=10)
Here are the details of my autoEncoder function:
import tensorflow as tf
from tensorflow.keras.layers import Conv1D, MaxPool1D, UpSampling1D

def autoEncoder(X_train):
    inputs = tf.keras.layers.Input(shape=(X_train.shape[1], 1))
    # ENCODER
    conv_1 = Conv1D(filters=64, kernel_size=3, activation='relu', padding='same')(inputs)
    max_pool_1 = MaxPool1D(pool_size=2)(conv_1)
    conv_2 = Conv1D(filters=128, kernel_size=3, activation='relu', padding='same')(max_pool_1)
    max_pool_2 = MaxPool1D(pool_size=2)(conv_2)
    # BOTTLE NECK
    bottle_neck = Conv1D(filters=256, kernel_size=3, activation='relu', padding='same')(max_pool_2)
    visualization = Conv1D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(bottle_neck)
    # DECODER
    conv_3 = Conv1D(filters=128, kernel_size=3, activation='relu', padding='same')(bottle_neck)
    upsample_1 = UpSampling1D(size=2)(conv_3)
    conv_4 = Conv1D(filters=64, kernel_size=3, activation='relu', padding='same')(upsample_1)
    upsample_2 = UpSampling1D(size=2)(conv_4)
    decoder_output = Conv1D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(upsample_2)
    return inputs, decoder_output, visualization
Solution
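It would be great if you could copy and paste the entire error stack that your code produces. Everyone should do this for error-related questions, because it makes debugging much easier.
Here is an attempt to reproduce the same error with a dummy dataset: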
import numpy as np
import tensorflow as tf

np.random.seed(11)
np.set_printoptions(precision=2)

def autoEncoder(X_train):
    inputs = tf.keras.layers.Input(shape=(X_train.shape[1], 1))
    conv_1 = tf.keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same')(inputs)
    max_pool_1 = tf.keras.layers.MaxPool1D(pool_size=2)(conv_1)
    conv_2 = tf.keras.layers.Conv1D(filters=128, kernel_size=3, activation='relu', padding='same')(max_pool_1)
    max_pool_2 = tf.keras.layers.MaxPool1D(pool_size=2)(conv_2)
    bottle_neck = tf.keras.layers.Conv1D(filters=256, kernel_size=3, activation='relu', padding='same')(max_pool_2)
    visualization = tf.keras.layers.Conv1D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(bottle_neck)
    conv_3 = tf.keras.layers.Conv1D(filters=128, kernel_size=3, activation='relu', padding='same')(bottle_neck)
    upsample_1 = tf.keras.layers.UpSampling1D(size=2)(conv_3)
    conv_4 = tf.keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same')(upsample_1)
    upsample_2 = tf.keras.layers.UpSampling1D(size=2)(conv_4)
    decoder_output = tf.keras.layers.Conv1D(filters=1, kernel_size=3, activation='sigmoid', padding='same')(upsample_2)
    return inputs, decoder_output, visualization

X = np.random.randn(1880, 220)
X_val = np.random.randn(300, 220)
X = np.expand_dims(X, axis=-1)
X = tf.convert_to_tensor(X)  # (1880, 220, 1)
X_val = np.expand_dims(X_val, axis=-1)
X_val = tf.convert_to_tensor(X_val)  # (300, 220, 1)

inputs, decoder_output, visualization = autoEncoder(X)
model = tf.keras.Model(inputs=inputs, outputs=decoder_output)
encoder_model = tf.keras.Model(inputs=inputs, outputs=visualization)

batch_size = 128
train_steps = len(X) // batch_size
val_steps = len(X_val) // batch_size
model.compile(optimizer='adam', metrics=['accuracy'], loss='mean_squared_error')
model.fit(X, steps_per_epoch=train_steps, validation_data=X_val, validation_steps=val_steps, epochs=100)
On Google Colab, this gives the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-29-a889c5a46f35> in <module>()
3 val_steps = len(X_val) // batch_size
4 model.compile(optimizer='adam', metrics=['accuracy'], loss='mean_squared_error')
----> 5 model.fit(X, steps_per_epoch=train_steps, validation_data = X_val, validation_steps=val_steps, epochs=100)
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1041 (x, y, sample_weight), validation_split=validation_split))
1042
-> 1043 if validation_data:
1044 val_x, val_y, val_sample_weight = (
1045 data_adapter.unpack_x_y_sample_weight(validation_data))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in __bool__(self)
990
991 def __bool__(self):
--> 992 return bool(self._numpy())
993
994 __nonzero__ = __bool__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
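This is the same error as in your question. The reason it helps to post the error stack is that the answer is hidden in these lines, specifically: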
1043 if validation_data:
1044 val_x, val_y, val_sample_weight = (
1045 data_adapter.unpack_x_y_sample_weight(validation_data))
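The failing line is `if validation_data:`. According to the traceback, `Tensor.__bool__` just calls `bool()` on the underlying NumPy array, so the same ValueError can be reproduced with NumPy alone. A minimal sketch; the helper `truthiness` is hypothetical and exists only for this illustration:

```python
import numpy as np

def truthiness(value):
    """Return the outcome of `if value:` or the error message it raises."""
    try:
        return "truthy" if value else "falsy"
    except ValueError as exc:
        return f"ValueError: {exc}"

x_val = np.random.randn(300, 220, 1)  # dummy validation data with the question's shape

print(truthiness(x_val))           # a bare multi-element array cannot be truth-tested
print(truthiness((x_val, x_val)))  # a non-empty tuple is truthy without inspecting its items
print(truthiness(None))            # no validation data at all
```

Python decides the truth value of a tuple from its length alone, so passing a tuple never reaches the ambiguous array check.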
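validation_data is expected in the same (x, y, sample_weight) format as the training data, but you passed a single tensor. Before anything is unpacked, Python has to evaluate `if validation_data:` on that tensor; Tensor.__bool__ converts it to a NumPy array and calls bool() on it, which is ambiguous for an array with more than one element, hence the ValueError. This is what the documentation of the fit method says: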
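validation_data will override validation_split. validation_data could be:
- A tuple (x_val, y_val) of Numpy arrays or tensors.
- A tuple (x_val, y_val, val_sample_weights) of Numpy arrays.
- A dataset.
For the first two cases, batch_size must be provided. For the last case, validation_steps could be provided.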
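To see why a tuple is expected, here is a rough sketch of how a `validation_data` value is split into `(x, y, sample_weight)`. The function `unpack_validation_data` is hypothetical: it is modeled on the behavior of Keras' `data_adapter.unpack_x_y_sample_weight` as I understand it, not copied from the library.

```python
def unpack_validation_data(data):
    """Split validation_data into (x, y, sample_weight).

    Hypothetical helper modeled on Keras' data_adapter.unpack_x_y_sample_weight;
    not the library code itself.
    """
    if not isinstance(data, tuple):
        # A bare array or tensor is treated as inputs only: there is no target.
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(
        "Expected x, (x,), (x, y) or (x, y, sample_weight), "
        f"got a tuple of length {len(data)}"
    )


x_val = [[0.1], [0.2], [0.3]]  # stand-in for X_val

print(unpack_validation_data(x_val))           # y is None: nothing to compare predictions against
print(unpack_validation_data((x_val, x_val)))  # y is x_val: the reconstruction target
```

In the real fit() call a bare tensor fails even earlier, at the `if validation_data:` check, but even if it got past that check there would still be no y to compute the validation loss against.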
I think you now see why you get the error: your autoencoder has no Y. That is not a problem, because your X is itself your Y.
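Here is a quote from the autoencoder tutorial that helps in this case:
Train the model using x_train as both the input and the target. The encoder will learn to compress the dataset from 784 dimensions to the latent space, and the decoder will learn to reconstruct the original images.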
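Using the input as the target simply means the loss compares the reconstruction with the input itself. A small NumPy sketch of what 'mean_squared_error' measures in that setup, on dummy data; `reconstruction_mse` is a hypothetical helper, and it takes one overall mean, which for equally shaped samples should match Keras' per-sample-then-batch averaging:

```python
import numpy as np

def reconstruction_mse(x, x_hat):
    # For an autoencoder the target is the input itself, so the squared
    # error is taken between the reconstruction x_hat and x, then averaged.
    return float(np.mean((np.asarray(x) - np.asarray(x_hat)) ** 2))

rng = np.random.default_rng(11)
x_val = rng.standard_normal((300, 220, 1))         # dummy validation data
perfect = x_val.copy()                             # a perfect reconstruction
noisy = x_val + rng.normal(0.0, 0.1, x_val.shape)  # an imperfect reconstruction

print(reconstruction_mse(x_val, perfect))  # 0.0
print(reconstruction_mse(x_val, noisy))    # a small positive number
```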
So what you should write is:
model.fit(X, X, steps_per_epoch=train_steps, validation_data=(X_val, X_val), validation_steps=val_steps, epochs=100)
And this does start training!
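Passing X as its own target only works because the decoder output has exactly the same shape as the input. The sketch below traces the sequence length through the layers of autoEncoder, assuming the standard Keras shape rules: Conv1D with padding='same' and stride 1 keeps the length, MaxPool1D(pool_size=2) uses strides equal to pool_size and 'valid' padding by default, and UpSampling1D(size=2) doubles the length. All function names here are hypothetical.

```python
def conv1d_same(length):
    # Conv1D with padding='same' and the default stride of 1 keeps the length.
    return length

def max_pool1d(length, pool_size=2):
    # MaxPool1D defaults: strides=pool_size, padding='valid'.
    return (length - pool_size) // pool_size + 1

def upsampling1d(length, size=2):
    # UpSampling1D repeats every time step `size` times.
    return length * size

def trace_lengths(input_length):
    """Sequence length after each stage of the autoEncoder from the question."""
    stages = [
        ("input", lambda n: n),
        ("conv_1", conv1d_same),
        ("max_pool_1", max_pool1d),
        ("conv_2", conv1d_same),
        ("max_pool_2", max_pool1d),
        ("bottle_neck", conv1d_same),
        ("conv_3", conv1d_same),
        ("upsample_1", upsampling1d),
        ("conv_4", conv1d_same),
        ("upsample_2", upsampling1d),
        ("decoder_output", conv1d_same),
    ]
    lengths = []
    n = input_length
    for name, fn in stages:
        n = fn(n)
        lengths.append((name, n))
    return lengths

for name, n in trace_lengths(220):
    print(f"{name:>15}: {n}")

# An odd input length would not round-trip: 221 -> 110 -> 55 -> 110 -> 220.
print(trace_lengths(221)[-1])
```

With 220 embedding values the path is 220, 110, 55 and back to 220, so the (batch, 220, 1) output lines up with the (batch, 220, 1) target. An input length that is not divisible by 4 would come back shorter than the input, and X could then no longer serve directly as Y.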