TFJS: saving a model over HTTP with headers

Problem description

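I'm trying to save and upload a tfjs model with an additional header (holding the class name), following the guide at https://www.tensorflow.org/js/guide/save_load and using the backend from https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864. However, following the guide does not work as expected or as the guide describes. Where am I going wrong? Thanks.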

My browser code is:

const saveResult = await model.save(tf.io.http('http://localhost:5000/upload', {method: 'POST', headers: {'class': 'Dog'}}));

The server code is:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io

from flask import Flask, Response, request
from flask_cors import CORS, cross_origin
import tensorflow as tf
import tensorflowjs as tfjs
import werkzeug.formparser

class ModelReceiver(object):

  def __init__(self):
    self._model = None
    self._model_json_bytes = None
    self._model_json_writer = None
    self._weight_bytes = None
    self._weight_writer = None

  @property
  def model(self):
    self._model_json_writer.flush()
    self._weight_writer.flush()
    self._model_json_writer.seek(0)
    self._weight_writer.seek(0)

    json_content = self._model_json_bytes.read()
    weights_content = self._weight_bytes.read()
    return tfjs.converters.deserialize_keras_model(
        json_content,
        weight_data=[weights_content],
        use_unique_name_scope=True)

  def stream_factory(self,
                     total_content_length,
                     content_type,
                     filename,
                     content_length=None):
    # Note: this example code is *not* thread-safe.
    if filename == 'model.json':
      self._model_json_bytes = io.BytesIO()
      self._model_json_writer = io.BufferedWriter(self._model_json_bytes)
      return self._model_json_writer
    elif filename == 'model.weights.bin':
      self._weight_bytes = io.BytesIO()
      self._weight_writer = io.BufferedWriter(self._weight_bytes)
      return self._weight_writer


def main():
  app = Flask('model-server')
  CORS(app)
  app.config['CORS_HEADER'] = 'Content-Type'

  model_receiver = ModelReceiver()

  @app.route('/upload', methods=['POST'])
  @cross_origin()
  def upload():
    print('headers are:')
    print(request.headers)
    print('Handling request...')
    werkzeug.formparser.parse_form_data(
        request.environ, stream_factory=model_receiver.stream_factory)
    print('Received model:')
    with tf.Graph().as_default(), tf.Session():
      model = model_receiver.model
      model.summary()
      # You can perform `model.predict()`, `model.fit()`,
      # `model.evaluate()` etc. here.
    return Response(status=200)

  app.run('localhost', 5000)


if __name__ == '__main__':
  main()
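One likely cause, worth checking against the tfjs version in use: the save/load guide nests custom headers inside a requestInit field, i.e. tf.io.http(url, {requestInit: {method: 'POST', headers: {'class': 'Dog'}}}), whereas the browser code above passes method and headers at the top level of the options object. Once the header is actually sent, the Flask handler can read it with request.headers.get('class'). The stdlib sketch below shows where that value lives in the WSGI environ that upload() already hands to parse_form_data; header_from_environ is a hypothetical helper, not part of Flask or werkzeug.

```python
from typing import Mapping, Optional

# WSGI (PEP 3333) keeps these two headers without the "HTTP_" prefix.
_UNPREFIXED = {"CONTENT_TYPE", "CONTENT_LENGTH"}


def header_from_environ(environ: Mapping[str, str], name: str) -> Optional[str]:
    """Look up an HTTP request header in a WSGI environ (hypothetical helper).

    WSGI servers upper-case the header name, replace "-" with "_" and, except
    for Content-Type/Content-Length, prefix it with "HTTP_", so a "class"
    request header arrives as environ["HTTP_CLASS"].
    """
    key = name.upper().replace("-", "_")
    if key not in _UNPREFIXED:
        key = "HTTP_" + key
    return environ.get(key)


# Inside the Flask upload() handler this matches request.headers.get('class'),
# because werkzeug builds request.headers from the same environ.
example_environ = {"HTTP_CLASS": "Dog", "CONTENT_TYPE": "multipart/form-data"}
print(header_from_environ(example_environ, "class"))  # Dog
```

If print(request.headers) in the handler shows no Class entry, the header never left the browser, which points back at the client-side options rather than the server.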

Tags: python, flask, tensorflow.js

Solution


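If your goal is to store some auxiliary information (for example, class labels) with the model, tf.LayersModel in TensorFlow.js has a relatively little-known feature that can make your life easier. It is simpler than using headers.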

That feature is the pair of methods setUserDefinedMetadata() and getUserDefinedMetadata().

On the JavaScript side, do the following:

// The argument to setUserDefinedMetadata() can be any JSON-serializable
// object of a reasonable size.
model.setUserDefinedMetadata({outputClassLabels: ['Cat', 'Dog', 'Turtle']});

// The user metadata is stored with the model itself, so there is no need
// to specify additional headers.
await model.save('http://localhost:5000/upload');

The server that receives the model artifacts can then simply read the "userDefinedMetadata" field of the model.json payload in the request.
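A minimal server-side sketch, assuming (as the answer describes) that the metadata arrives under a top-level userDefinedMetadata key of the uploaded model.json. The helper read_user_defined_metadata and the trimmed example payload are hypothetical, for illustration only.

```python
import json
from typing import Any, Dict


def read_user_defined_metadata(model_json_bytes: bytes) -> Dict[str, Any]:
    """Return the userDefinedMetadata object from an uploaded model.json.

    Returns an empty dict when the model was saved without metadata.
    """
    model_json = json.loads(model_json_bytes.decode("utf-8"))
    return model_json.get("userDefinedMetadata", {})


# Illustrative, heavily trimmed model.json; a real upload also carries the
# full modelTopology and weightsManifest produced by model.save().
example_model_json = json.dumps({
    "modelTopology": {},
    "weightsManifest": [],
    "userDefinedMetadata": {"outputClassLabels": ["Cat", "Dog", "Turtle"]},
}).encode("utf-8")

metadata = read_user_defined_metadata(example_model_json)
print(metadata["outputClassLabels"])  # ['Cat', 'Dog', 'Turtle']
```

In the Flask server from the question, one way to wire this in is inside upload() right after parse_form_data(): call model_receiver._model_json_writer.flush() (the BufferedWriter may still hold data) and then pass model_receiver._model_json_bytes.getvalue() to the helper.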

