amazon-web-services - sagemaker 中的逻辑回归
问题描述
我正在使用 aws sagemaker 进行逻辑回归。为了在测试数据上验证模型,使用以下代码
runtime= boto3.client('runtime.sagemaker')
payload = np2csv(test_X)
response = runtime.invoke_endpoint(EndpointName=linear_endpoint,
ContentType='text/csv',
Body=payload)
result = json.loads(response['Body'].read().decode())
test_pred = np.array([r['score'] for r in result['predictions']])
结果包含预测值和概率分数。我想知道如何运行预测模型来根据两个特定特征预测结果。例如。我在模型中有 30 个特征,并使用这些特征训练了模型。现在对于我的预测,我想知道 feature1='x' 和 feature2='y' 时的结果。但是当我将数据过滤到这些列并在相同的代码中传递时,我收到以下错误。
Customer Error: The feature dimension of the input: 4 does not match the feature dimension of the model: 30. Please fix the input and try again.
在 AWS Sagemaker 实施中,R 中的 say glm.predict('feature1','feature2') 等价物是什么?
解决方案
当您在数据上训练回归模型时,您正在学习从输入特征到响应变量的映射。然后,您可以使用该映射通过向模型提供新的输入特征来进行预测。
如果您在 30 个特征上训练了一个模型,则不可能使用相同的模型仅对其中的 2 个特征进行预测。您必须为其他 28 个特征提供值。
如果您只想知道这两个特征如何影响预测,那么您可以查看训练模型的权重(也称为“参数”或“系数”)。如果特征 1 的权重为 x,则当特征 1 增加 1.0 时,预测响应增加 x。
要查看在 Amazon SageMaker 中使用线性学习器算法训练的模型的权重,您可以下载 model.tar.gz 工件并在本地打开它。output
可以从您在该方法的参数中指定的 S3 位置下载模型工件sagemaker.estimator.Estimator
。
import os
import mxnet as mx
import boto3
bucket = "<your_bucket>"
key = "<your_model_prefix>"
boto3.resource('s3').Bucket(bucket).download_file(key, 'model.tar.gz')
os.system('tar -zxvf model.tar.gz')
# Linear learner model is itself a zip file, containing a mxnet model and other metadata.
# First unzip the model.
os.system('unzip model_algo-1')
# Load the mxnet module
mod = mx.module.Module.load("mx-mod", 0)
# model weights
weights = mod._arg_params['fc0_weight'].asnumpy().flatten()
# model bias
bias = mod._arg_params['fc0_bias'].asnumpy().flatten()
# weight for the first feature
weights[0]
推荐阅读
- installation - 在 R studio 1.2.1335 和 macOS Sierra 10.12.6 中使用 BiocManager::install 安装“org.Hs.eg.db”时出现问题
- sql - 有没有办法根据在 SELECT 语句中作为列名满足的多个条件来提取新计算?
- django - 使用频道 + Nginx + Daphne 部署 dockerized Django
- excel - 在用户窗体初始化上设置多个复选框
- rest - api 是否应该在返回响应之前更改其他系统状态?
- docker - 无法使用 shell 脚本启动 docker 容器
- faunadb - 如何在 FaunaDB 中获取嵌套文档?
- c++ - 浮点数组的 std::max_element
- angular - 使用 Amplify Codegen 自定义 GraphQL 查询
- c++ - 二维阵列反转线中的 SegFault