python - Running a script locally on AzureML with a remote dataset
Question
I have a script for development purposes that I want to run and debug locally. However, I do not want to store the data needed for the experiment on my local machine.
I am using the azureml library together with Azure Machine Learning Studio. See the code below:
# General
import os
import argparse

# Data analysis and wrangling
import pandas as pd

# Machine learning
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from azureml.core import Run

# Get the environment of this run
run = Run.get_context()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        type=str,
        help='Path to the training data',
        # The default path is on my local machine, however I would like to reference a remote datastore on Azure as a parameter to this script
        default=os.path.join(os.getcwd(), 'data')
    )
    args = parser.parse_args()

    # Obtain the data from the datastore
    train_df = pd.read_csv(os.path.join(args.data_path, os.listdir(args.data_path)[0]))

    # Drop unnecessary columns
    train_df = train_df.drop(['Name', 'PassengerId', 'Ticket', 'Cabin'], axis=1)

    # Encode non-numeric features as dummies
    train_df = pd.get_dummies(train_df)

    # Drop NA's
    train_df.dropna(inplace=True)

    # Use grid search CV to find the best parameters for the model
    parameters = {'kernel': ('linear', 'rbf'),
                  'C': [1, 10]}

    # Initialize the grid search
    search = GridSearchCV(SVC(), param_grid=parameters, cv=8)

    # Train the model
    search.fit(train_df.drop("Survived", axis=1), train_df["Survived"])
Right now, the script uses the local folder 'data'. However, I would like to pass this script a parameter indicating that I want to use a remote datastore in Azure Machine Learning Studio instead. How can I do that?
Solution
After some trial and error, I found a way to do this, although I suspect there should/may be a better approach.
Here I use argparse to switch between data sources, and an azureml function to read the remote data directly into memory (in this case, into a pandas dataframe).
Here is the code:
# General
import os
import argparse

# Data analysis and wrangling
import pandas as pd

# Machine learning
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Azure libraries
from azureml.core import Run, Workspace, Dataset

# Get the environment of this run
run = Run.get_context()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        type=str,
        help='Path to the training data',
        default='local'
    )
    args = parser.parse_args()

    if args.data_path == 'local':
        # Use the local folder
        data_path = os.path.join(os.getcwd(), 'data')
        train_df = pd.read_csv(os.path.join(data_path, os.listdir(data_path)[0]))
    elif args.data_path == 'remote':
        # Use the remote location
        # Get the workspace from the config file
        ws = Workspace.from_config()

        # Obtain the dataset from the default datastore
        datastore = ws.get_default_datastore()
        dataset = Dataset.Tabular.from_delimited_files(path=[(datastore, 'path_to_file')])

        # Read the data into memory
        train_df = dataset.to_pandas_dataframe()
    else:
        # Unknown source, create an empty dataframe
        train_df = pd.DataFrame()

    # Drop unnecessary columns
    train_df = train_df.drop(['Name', 'PassengerId', 'Ticket', 'Cabin'], axis=1)

    # Encode non-numeric features as dummies
    train_df = pd.get_dummies(train_df)

    # Drop NA's
    train_df.dropna(inplace=True)

    # Use grid search CV to find the best parameters for the model
    parameters = {'kernel': ('linear', 'rbf'),
                  'C': [1, 10]}

    # Initialize the grid search
    search = GridSearchCV(SVC(), param_grid=parameters, cv=8)

    # Train the model
    search.fit(train_df.drop("Survived", axis=1), train_df["Survived"])

    # Log the grid search results to the run
    run.log_table(name='Gridsearch results', value=pd.DataFrame(search.cv_results_).to_dict(orient="list"))
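One small refinement of the argument handling above (a sketch, not part of the original answer): passing choices= to add_argument makes argparse reject any value other than 'local' or 'remote' at parse time, so a typo fails fast instead of falling through to the empty-dataframe branch and failing later inside the preprocessing steps.

```python
import argparse

def build_parser():
    # choices= makes argparse exit with a usage error for any value other
    # than 'local' or 'remote', instead of letting the script continue
    # and train on an empty DataFrame.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        type=str,
        choices=['local', 'remote'],
        default='local',
        help='Where to load the training data from',
    )
    return parser

args = build_parser().parse_args(['--data_path', 'remote'])
print(args.data_path)  # remote
```

With this in place, the else branch in the script above becomes unreachable and could be dropped.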