python - Python的Keras:尽管我的检查显示不匹配,但仍声称形状不匹配
问题描述
当我逐行检查我的代码时,我表明我的形状符合要求:
num_samples=100;
input_shape = (num_samples,26,76,1);
x = tf.random.normal(input_shape); y=[];
for i in range(num_samples):
temp=[0,0,0]; proc = random.randint(0,len(alldata)-1);
temp[proc]=1; y.append(np.array(temp));
y=tf.convert_to_tensor(y);
model_conv = Sequential()
x2 = Conv2D(1, (3,3), activation='relu', input_shape=input_shape[1:])(x)
x3 = Conv2D(1, (5,5), activation='relu', input_shape=(20,70,1))(x2)
x4 = Conv2D(1, (5,21), activation='relu', input_shape=(16,50,1))(x3)
x5 = Conv2D(1, (1,50), activation='relu', input_shape=(16,1,1))(x4)
x6 = Flatten()(x5)
x7 = Dense(16, activation='relu')(x6)
x8 = Dense(3, activation='softmax')(x7)
print(np.shape(x8)); print(type(x8));
print(np.shape(y)); print(type(y))
...返回 x8 和 y 都是形状(100,3) 的 EagerTensor - 相同的形状。这是 y 的标签形状应该匹配的点,对吗?然而,当我构建此模型进行编译和拟合时,以下代码声称“ValueError:形状不匹配:标签的形状(收到 (30,))应该等于 logits 的形状,除了最后一个维度(收到 (10, 3) )。” 它认为这个 (30,) 和 (10,3) 来自哪里?为什么它的输出与我上面的输出有什么不同?我错过了什么?
model_conv = Sequential()
model_conv.add(Conv2D(1, (3,3), activation='relu', input_shape=input_shape[1:])); #--> (24,74)
model_conv.add(Conv2D(1, (5,5), activation='relu', input_shape=(20,70,1))); #--> (20,70)
model_conv.add(Conv2D(1, (5,21), activation='relu', input_shape=(16,50,1))); #--> (16,50)
model_conv.add(Conv2D(1, (1,50), activation='relu', input_shape=(16,1,1))); #--> (16,1)
model_conv.add(Flatten());
model_conv.add(Dense(16, activation='relu'));
model_conv.add(Dense(3, activation='softmax'));
model_conv.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(learning_rate=0.001,beta_1=0.8), metrics=['accuracy']);
#sparse_categorical since the labels are not one-hot encoded
model_conv.summary(); #tf.random.set_seed(47);
model_conv.fit(x,y, epochs=50, batch_size=10, verbose=0);
这些是我认为最不麻烦的事情,然后最终花时间尝试解决。请协助。谢谢
解决方案
推荐阅读
- java - JPA中如何实现复杂的实体关系
- javascript - 使用pixi视口时如何旋转容器内的对象?
- c# - 如何在浏览器中的真实 android 设备上运行 Selenium C# 测试?
- javascript - 在浏览器中使用 JavaScript 读取带有 for-await 的流
- ios - UNUserNotificationCenter.getDeliveredNotifications() 能否用于检索用户导向的推送通知
- java - Adobe Acrobat“另存为文本”解析器
- java - java.lang.NullPointerException:尝试在空对象引用上调用虚拟方法“double android.location.Location.getLongitude()”
- cmd - 在 main 方法中配置的 Cucumber 测试未从 jar 文件执行
- python - sys.excepthook 函数如何与 PyQt5 一起工作?
- r - 如何在单个帧中绘制两个二元直方图?