python - 复制 AutoKeras StructuredDataClassifier 的问题
问题描述
我有一个使用 AutoKeras 生成的模型,我想复制该模型,以便可以使用 keras Tuner 构建它以进行进一步的超参数调整。但是我在复制模型时遇到了问题。autokeras模型的模型总结是:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 11)] 0
_________________________________________________________________
multi_category_encoding (Mul (None, 11) 0
_________________________________________________________________
normalization (Normalization (None, 11) 23
_________________________________________________________________
dense (Dense) (None, 16) 192
_________________________________________________________________
re_lu (ReLU) (None, 16) 0
_________________________________________________________________
dense_1 (Dense) (None, 32) 544
_________________________________________________________________
re_lu_1 (ReLU) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 3) 99
_________________________________________________________________
classification_head_1 (Softm (None, 3) 0
=================================================================
Total params: 858
Trainable params: 835
Non-trainable params: 23
层配置
{'batch_input_shape': (None, 11), 'dtype': 'string', 'sparse': False, 'ragged': False, 'name': 'input_1'}
{'name': 'multi_category_encoding', 'trainable': True, 'dtype': 'float32', 'encoding': ListWrapper(['int', 'int', 'int', 'int', 'int', 'int', 'int', 'int', 'int', 'int', 'int'])}
{'name': 'normalization', 'trainable': True, 'dtype': 'float32', 'axis': (-1,)}
{'name': 'dense', 'trainable': True, 'dtype': 'float32', 'units': 16, 'activation': 'linear', 'use_bias': True, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
{'name': 're_lu', 'trainable': True, 'dtype': 'float32', 'max_value': None, 'negative_slope': array(0., dtype=float32), 'threshold': array(0., dtype=float32)}
{'name': 'dense_1', 'trainable': True, 'dtype': 'float32', 'units': 32, 'activation': 'linear', 'use_bias': True, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
{'name': 're_lu_1', 'trainable': True, 'dtype': 'float32', 'max_value': None, 'negative_slope': array(0., dtype=float32), 'threshold': array(0., dtype=float32)}
{'name': 'dense_2', 'trainable': True, 'dtype': 'float32', 'units': 3, 'activation': 'linear', 'use_bias': True, 'kernel_initializer': {'class_name': 'GlorotUniform', 'config': {'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
{'name': 'classification_head_1', 'trainable': True, 'dtype': 'float32', 'axis': -1}
我的训练数据是一个数据框,它转换为包含数字和分类数据的字符串类型。由于输出是softmax,我用来LabelBinarizer
转换目标类。
为了确保模型被正确复制,我曾经keras.clone_model
创建模型的副本并尝试自己进行训练。但是当我尝试自己训练它时,尽管达到了 500 个 epoch,但准确性并没有提高。
在从头开始训练模型时,我有什么遗漏吗?
解决方案
AutoKeras 不支持任何直接转换 - 它的依赖关系过于内置,无法与包本身隔离。上面的答案表明缺乏softmax
激活是错误的,因为确实存在:
classification_head_1 (Softm
--> 可能文本被截断
接下来-您是否注意到缺少参数?858
是一个非常小的数字 - 那是因为大多数图层都有0
参数 - Autokeras 使用自定义图层来构成他们的自定义块(更多关于他们的块的信息来自他们的文档)
您可以看到,要重新创建这些自定义层,您需要它们的确切代码 - 在撰写本文时无法隔离(尽管 @haifeng-jin 正在讨论它),因为它们使用特定的包来处理输入数据以及他们的 NAS(神经架构搜索)和他们执行的优化例程的动力。
除非您可以研究他们的代码和自定义层的实现并重新创建它(这本身就是一些工作,但由于代码已经可用,所以工作量不大),如果您将keras.clone_model
其与 pre 一起使用,那将是徒劳的尝试- 定义的 keras 层。这显然会导致模型损坏(例如您目前拥有的模型)。
更重要的是,AutoKeras
HyperParameter 是否可以自行调整 - 如果您想进一步调整模型,只需运行 AutoKeras 更长的时间以获得更好的结果。
tl;博士您不能直接克隆具有包内依赖项的自定义层和块。但是如果你想进行超参数调优,你可以运行更长时间的搜索以获得更好的模型。
推荐阅读
- java - 如何每 8 个输入创建一个新行
- javascript - 如何使用没有箭头函数的 javascript 将此 javascript 对象转换为数组?
- bash - 使用 curl 提交多个 json 有效负载
- typescript - 打字稿需要字段的子集
- typescript - TypeScript 无法选择正确的函数签名
- html - 细节元素似乎忽略了显示弹性或网格?
- java - 积极向后看行为不正确
- android - 如何在 Android 中读取和显示 csv 文件中的数据?
- python - 我们可以使用包含标签图像的 caltech101 数据集检测图像中的多个对象吗?
- python-3.x - 如何访问从 kaggle 下载到 Colaboratory notebook 的文件?