apache-spark - 如何加载火花模型
问题描述
我没有成功加载模型并保存了。我有一个奇怪的错误。
from transforms.api import Output, transform,transform_df
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LogisticRegressionModel
import logging
logger = logging.getLogger(__name__)
def save_model(spark_session, output, model, model_name='model4'):
foundry_file_system = output.filesystem()._foundry_fs
logger.info("The path 1 is : "+ str(foundry_file_system))
path = foundry_file_system._root_path + "/" + model_name
logger.info("The path 2 is : "+ str(path))
model.write().overwrite().session(spark_session).save(path)
model=LogisticRegressionModel.read().session(spark_session).load(path)
df_to_predict = spark_session.createDataFrame([(
Vectors.dense([0.0, 1.1, 0.1]),
Vectors.dense([2.0, 1.0, -1.0]),
Vectors.dense([2.0, 1.3, 1.0]),
Vectors.dense([0.0, 1.2, -0.5]),)], ["features"])
df_predicted = model.transform(df_to_predict)
logger.info(df_predicted.show())
logger.info(df_predicted.count())
def my_compute_function(ctx, output_model):
training = ctx.spark_session.createDataFrame([
(1.0, Vectors.dense([0.0, 1.1, 0.1])),
(0.0, Vectors.dense([2.0, 1.0, -1.0])),
(0.0, Vectors.dense([2.0, 1.3, 1.0])),
(1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"])
lr = LogisticRegression(maxIter=10, regParam=0.01)
model1 = lr.fit(training)
save_model(ctx.spark_session, output_model, model1, 'model4')
这是我得到的错误:
NonRetryableError: Py4JJavaError: 调用 o266.load 时出错。: scala.MatchError: [2,3,[1,null,null,WrappedArray(0.06817659473873602)],[1,1,3,null,null,WrappedArray(-3.1009356010205322, 2.6082147383214482, -0.38017912254303043],false) (类 org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)在 org.apache.spark.ml.classification.LogisticRegressionModel$LogisticRegressionModelReader.load(LogisticRegression.scala:1273) ....
解决方案
该错误表明使用与编写模型不同的方法来加载模型。
您应该使用LogisticRegressionModel.load而不是 LogisticRegression.read()
如果 parquet 元数据不匹配,也可能导致此问题。我建议您将摘要元数据级别设置为NONE
spark.conf.set("parquet.summary.metadata.level", "NONE")
推荐阅读
- python - Formatting pandas dataframe
- c++ - 检查值的辅助函数是它的任何参数
- python - 检查字典是否没有所有键的值
- python - 在 Python 中循环 ConfigParser 时,如何停止重复不需要的值?
- javascript - 标记自动滚动计数器或文本以调用 Javascript 函数
- java - 使用 Java 和 HTML 获得两个字符串输入之和的正确方法是什么?
- mysql - 它无法正确打开 xampp 我该怎么办
- c - 循环队列。检查它是否已满
- python - SQLITE - 非空约束失败 - 将一行从一个表传输到另一个表
- html - 如何正确对齐此 td 的第二个跨度的内容?