tensorflow.js - 如何在 TFJS hub 模型上使用 Model.save 功能?
问题描述
我不知道 javascript,所以我想将仅在 JS 中可用的 HUB 模型移动到 SavedModel 格式。
我从教程中复制了这个脚本并尝试添加 model.save 函数,但它不起作用。
这是脚本:
<html><head>
<!-- Load the latest version of TensorFlow.js -->
<script src="https://unpkg.com/@tensorflow/tfjs"></script>
<script src="https://unpkg.com/@tensorflow-models/mobilenet"></script>
</head>
<body>
<div id="console"></div>
<!-- Add an image that we will use to test -->
<img id="img" src="https://i.imgur.com/JlUvsxa.jpg" width="227" height="227">
<script>
let net;
async function app() {
console.log('Loading mobilenet..');
// Load the model.
net = await mobilenet.load();
console.log('Successfully loaded model');
// Make a prediction through the model on our image.
const imgEl = document.getElementById('img');
const result = await net.classify(imgEl);
console.log(result);
console.log('Saving mobilenet...');
const saveResults = await net.save('downloads://my-model-1');
console.log('Mobilenet saved');
}
app();
</script>
</body></html>
这是我得到的错误:
Uncaught (in promise) TypeError: net.save is not a function
at app (TFjsmodelSaver.html:27)
app @ TFjsmodelSaver.html:27
async function (async)
app @ TFjsmodelSaver.html:19
(anonymous) @ TFjsmodelSaver.html:30
该错误清楚地表明 net.save 不是应用程序中的功能,但同时 net.classify 有效,并且保存在 tfjs 中:https ://js.tensorflow.org/api/0.12.5/# tf.Model.save
我错过了什么?
顺便说一句,如果有一种方法可以在 SavedModel 中获取 HUB 模型而无需通过此操作,请指出它。我假设模型首先在 TF 中创建,然后移植到 TFJS,所以它们可能在某处可用......
解决方案
mobilenet.load()
返回 MobileNet 类型的承诺。这是接口定义:
export interface MobileNet {
load(): Promise<void>;
infer(
img: tf.Tensor|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
embedding?: boolean): tf.Tensor;
classify(
img: tf.Tensor3D|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
topk?: number): Promise<Array<{className: string, probability: number}>>;
}
加载的模型不包含save
因此引发错误的方法。
保存不是函数
保存模型值得吗?加载的模型不用于训练。因此,每次需要时,都可以使用mobilenet.load
.
mobilenet 包只是mobilet savedModel 的包装器。github repo 包含不同版本的 mobilenet的url,可以从中下载 savedModel。可以使用本地加载模型tf.loadGraphModel
。但是这个本地加载的模型将是类型tf.GraphModel
并且不包含方法classify
和infer
下一个版本将提供保存功能tf.GraphModel
推荐阅读
- javascript - 如何检查数组的至少一个键值是否包含特定字符串
- c -
使用 va_list 的函数的混合输入类型 - swift - 根据tableview语句将变量快速保存到JSON文件中
- python - 替换 Pandas DF 中的窗口
- c++ - C++ 使用 std::enable_if 创建最多 10 个参数的 std::tuple 特化
- jquery - jQuery选择整个类,解决办法是什么?
- python - Python中两个图的子图未正确显示
- r - 使用列名进行“映射”时如何使用准引用/整洁评估
- javascript - 反应路由器侧边栏正确路由但不显示顶部菜单下的组件
- python - pymysql - 外键创建表