python - 如何在自定义 keras 模型中添加 tf.keras.Input?
问题描述
我正在编写一个自定义 keras 模型,这是我的代码:
class Model(tf.keras.Model):
def __init__(self, first_layer, num_classes):
super(Model, self).__init__()
self.layer_1 = tf.keras.layers.Dense(first_layer, activation='relu')
self.layer_2 = tf.keras.layers.Dense(num_classes, activation='softmax')
def call(self,inp):
output = self.layer_1(inp)
output = self.layer_2(output)
return output
但我收到此错误:
ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
我找到了 tf.keras.input,但是所有的例子都是顺序模型,例如顺序模型,这是 keras.input 的解决方案:
encoder_input = keras.Input(shape=(28, 28, 1), name="img")
x = layers.Conv2D(16, 3, activation="relu")(encoder_input)
x = layers.Conv2D(32, 3, activation="relu")(x)
但是我如何在自定义 keras 模型中引入它?
请
解决方案
版本:TF 2.6
编辑代码:
import tensorflow as tf
class Model(tf.keras.Model):
def __init__(self, first_layer, num_classes):
super(Model, self).__init__()
self.layer_1 = tf.keras.layers.Dense(first_layer, activation='relu')
self.layer_2 = tf.keras.layers.Dense(num_classes, activation='softmax')
def call(self,inp):
output = self.layer_1(inp)
output = self.layer_2(output)
return output
encoder_input = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(16, 3, activation="relu")
#x = layers.Conv2D(32, 3, activation="relu")(x)
model = Model(5,3) #instantiated the model
y=model(encoder_input)#took input
print(y)
输出:
KerasTensor(type_spec=TensorSpec(shape=(None, 28, 28, 3), dtype=tf.float32, name=None), name='model_4/dense_6/Softmax:0', description="created by layer 'model_4' ")
参考: https ://www.tensorflow.org/guide/keras/custom_layers_and_models
推荐阅读
- php - 使用 laravel excel 或 php 编辑现有的 excel 文件
- css - 表格中的div错位
- javascript - 如何获取javascript的输入值
- node.js - 如何将 Post Item ID 传递给 getInitialProps?
- ubuntu - 连接到 MariaDb 时出现 MySql Workbench 错误。表“performance_schema.user_variables_by_thread”不存在
- python - 使用pandas groupby不工作计数不同
- javascript - React 纯 Javascript 上的 Photo Sphere Viewer
- postgresql - Postgresql MD5 VS 信任功能
- java - Jsoup 不适用于包含非字母字符的编码链接
- python - 如何对 pandas 列中的列表执行 One Hot Encoding?