python - 如何将我的腌制 ML 模型从 GCS 加载到 Dataflow/Apache Beam
问题描述
我在本地开发了一个 apache 光束管道,在其中对示例文件运行预测。
在我的计算机上本地我可以像这样加载模型:
with open('gs://newbucket322/my_dumped_classifier.pkl', 'rb') as fid:
gnb_loaded = cPickle.load(fid)
但是在谷歌数据流上运行时显然不起作用。我尝试将路径更改为 GS:// 但这显然也不起作用。
我还尝试了这个用于加载文件的代码片段(来自这里) :
class ReadGcsBlobs(beam.DoFn):
def process(self, element, *args, **kwargs):
from apache_beam.io.gcp import gcsio
gcs = gcsio.GcsIO()
yield (element, gcs.open(element).read())
model = (p
| "Initialize" >> beam.Create(["gs://bucket/file.pkl"])
| "Read blobs" >> beam.ParDo(ReadGcsBlobs())
)
但这在想要加载我的模型时不起作用,或者至少我不能使用这个模型变量来调用 predict 方法。
应该是一个非常简单的任务,但我似乎无法找到一个简单的答案。
解决方案
您可以如下定义 ParDo
class PerdictOutcome(beam.DoFn):
""" Format the input to the desired shape"""
def __init__(self, project=None, bucket_name=None, model_path=None, destination_name=None):
self._model = None
self._project = project
self._bucket_name = bucket_name
self._model_path = model_path
self._destination_name = destination_name
def download_blob(bucket_name=None, source_blob_name=None, project=None, destination_file_name=None):
"""Downloads a blob from the bucket."""
destination_file_name = source_blob_name
storage_client = storage.Client(<gs://path">)
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)
# Load once or very few times
def setup(self):
logging.info(
"Model Initialization {}".format(self._model_path))
download_blob(bucket_name=self._bucket_name, source_blob_name=self._model_path,
project=self._project, destination_file_name=self._destination_name)
# unpickle model model
self._model = pickle.load(open(self._destination_name, 'rb'))
def process(self, element):
element["prediction"] = self._model.predict(element["data"])
return [element]
然后你可以在你的管道中调用这个 ParDo,如下所示: -
model = (p
| "Read Files" >> TextIO...
| "Run Predictions" >> beam.ParDo(PredictSklearn(project=known_args.bucket_project_id, bucket_name=known_args.bucket_name, model_path=known_args.model_path, destination_name=known_args.destination_name)
)
推荐阅读
- composer-php - 作曲家可以使用局部变量来指定存储库版本吗?
- vue.js - 如何将 codepen.io 示例复制到 .vue
- ignite - Apache Ignite 缓存到 SQL,反之亦然
- tensorflow - 将模型加载到 TensorFlow 服务容器中并使用 protobufs 与之通信
- amazon-dynamodb - AWS Amplify + Appsync - 是否可以使用@connection 转换级联删除相关数据?
- python-3.x - 如何在 beautifulsoup 中过滤掉缩略图附件?
- java - 使用 ECDSA 公钥验证 JWT 签名 - 解码签名字节时出错
- java - 如何使用java从csv文件的每一行中删除逗号
- php - 当所选行具有 NULL 值时,允许的内存大小已耗尽
- rest - 意外路径变量类型上的 400 与 404