首页 > 解决方案 > TFJS:来自 HTTP 服务器的 loadGraphModel

问题描述

我正在使用 React 和 Tensorflow JS 实现一个对象检测 Web 应用程序。我将我的模型转换为 tensorflow JS 模型,这样我就可以将它加载到我的 React 应用程序中。我想使用一个简单的 HTTP 端点加载模型,该端点是当前托管在我的本地计算机上的 Flask 服务器。Flask 主文件如下所示:

from flask import Flask
from flask_cors import CORS, cross_origin
import os

app = Flask(__name__)
cors = CORS(app)

@app.route('/')
def hello_world():
    return 'Hello, World!'

@app.route('/model', methods=['GET'])
def get_modeljson():
    """
    Get the model.json file and return it's contents.
    """
    current_dir = os.getcwd()
    file_path = os.path.join(current_dir, "models", "model.json")
    with open(file_path, "r") as f:
        return f.read()


if __name__ == '__main__':
    app.run(debug=True, host="0.0.0.0", threaded=True)

我在我的 React 应用程序中编写了一个函数,该函数使用上面代码中定义的端点 /model 加载图形模型。React 函数如下所示:

import {useEffect, useState} from 'react';
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

function Model(props) {
    const [model, setModel] = useState();

    async function loadModel() {
        try {
          const model_url = "http://127.0.0.1:5000/model";
          const result = await fetch(model_url);
          const result_json = await result.json();
          
          const model = await loadGraphModel(result_json);
          console.log('model loaded...')
          setModel(model);
          console.log("Model correctly loaded");
        } catch (err) {
          console.log(err);
          console.log("failed load model");
        }
      }

    useEffect(() => {
        tf.ready().then(() => {
          loadModel();
        });
      }, []);

    async function predictFunction() {
        // use model to make predictions
    }

    return (      
        <Button onClick={() => {
            predictFunction();
        }}
        />
    );
}

export default Model;

FLASK API 正确返回 model.json 文件,但是 loadGraphModel 返回以下错误:

TypeError: url.startsWith is not a function
    at indexedDBRouter (indexed_db.ts:215)
    at router_registry.ts:95
    at Array.forEach (<anonymous>)
    at Function.getHandlers (router_registry.ts:94)
    at Function.getLoadHandlers (router_registry.ts:84)
    at Module.getLoadHandlers (router_registry.ts:110)
    at GraphModel.findIOHandler (graph_model.ts:107)
    at GraphModel.load (graph_model.ts:126)
    at loadGraphModel (graph_model.ts:440)
    at loadModel (Model.js:16)

我找不到任何关于 url.startsWith 的文档。谁知道这里出了什么问题?

标签: reactjsreact-nativetensorflowflasktensorflow.js

解决方案


通过代码,我发现它存在一个主要问题,您尝试基本上model.json从后端发送一个到前端,然后从中加载模型model.json并对其执行推理。它会起作用,但它根本没有效率。想象一下必须这样做几百次,我知道model.json文件可能很大。相反,您可以选择两条路线:

  1. 将模型托管在后端,通过 POST 请求将数据发送到后端,然后对来自请求的数据进行预测。
  2. 在前端使用模型,然后从那里对输入数据进行预测。

代码中存在一些导致错误的错误,但这是您需要首先解决的问题。如果你能给我更多关于你正在使用的输入的信息,我可以起草一个可行的解决方案。


推荐阅读