tensorflow - TypeError: ('不是 JSON 可序列化的:',)
问题描述
我有一个生产模型,可以在ModelConfig
我自己编写的一个class Config(dict)
. ModelConfig
我正在设置hidden_size=123
. 以下是我的生产模型代码的简化:
class Config(dict):
# ..
def as_dict(self) -> dict:
return self._serialize()
class ModelConfig(Config):
hidden_size: int = 123
现在,我有一个自定义的 keras 模型,它是这样实现的:
class MyModel(keras.Model):
def __init__(self, config: ModelConfig):
self.config = config
self.dense = layers.Dense(config.hidden_size)
# ...
def get_config():
return self.config.as_dict()
一切正常,除了调用MyModel#save
:
model.save(model_path, save_format='tf')
我得到一个TypeError
说法:
TypeError: ('Not JSON Serializable:', <tf.Tensor: shape=(), dtype=float32, numpy=123.0>)
我知道这ModelConfig#hidden_size
是成为 Tensor/EagerTensor 的属性。
这里奇怪的是,在我的一个测试脚本中,一切都按预期工作。我创建它是为了查看是否Config
会导致问题,save()
但情况似乎并非如此:
因此,以下内容按预期工作:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from config import Config
class ModelConfig(Config):
hidden_size: int = 123
class MyLayer(layers.Layer):
def __init__(self, config, other=None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.other = other if other is not None else None
self.dense = layers.Dense(config.hidden_size + 1)
def call(self, inputs, **kwargs):
x = self.other(inputs, **kwargs) if self.other is not None else inputs
return self.dense(x, **kwargs)
class MyModel(keras.Model):
def __init__(self, config: ModelConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.other = MyLayer(config, other=MyLayer(config))
self.dense = layers.Dense(config.hidden_size)
def call(self, inputs, **kwargs):
x = self.dense(inputs, **kwargs)
return self.other(x, **kwargs)
def get_config(self):
return self.config.to_dict()
@tf.function(
input_signature=[tf.TensorSpec(shape=(None, 100))]
)
def infer(self, inputs):
return self.call(inputs)
def main():
fp = '/tmp/mymodel'
config = ModelConfig()
model = MyModel(config)
model(np.random.rand(1, 100))
model.save(fp, save_format='tf')
model = tf.saved_model.load(fp)
print(model.infer(np.random.rand(1, 100)))
print('All done.')
if __name__ == '__main__':
main()
这意味着问题不是Config
对象,而是由于某种原因在我的生产模型中,它的一个属性导致它在调用时失败save()
。
这可能是一个大问题,但有人知道在哪里寻找问题吗?
在一种情况下,该属性成为无法序列化的 Tensor/EagerTensor(这是有道理的),而在测试脚本情况下,它保持int
按预期工作。
我也尝试tf.saved_model.save
过同样的结果。
解决方案
我猜是 tf.Tensor 不是 JSON 可序列化的。您可以尝试将 tf.Tensor 的对象转换为 numpy 数组,然后您应该能够保存配置以供以后重新加载。
推荐阅读
- sql - 带有用户提示的 PL SQL 记录填充
- php - Yajra 数据表未在数据表中正确加载数据
- php - 给定流在 FPDI 中不可搜索
- excel - 索引超出范围必须小于或为负
- flutter - 添加 # 使用 Url_launcher 登录 tel uri
- css - 如何在 Firefox 中裁剪和对齐带有绝对单位的淡出 SVG 蒙版到右下角?
- sql-server - SQL 服务器负载平衡选项?
- angular - 从查看子组件中获取值(从子到父)
- paypal - Sandbox paypal php 付款发送到 null 未验证
- java - Spring 在手动类实例化期间注入 @Autowired 字段