首页 > 解决方案 > SageMaker 本地模式下 estimator.fit 挂起

问题描述

我正在尝试使用 SageMaker 本地模式(local mode)训练 PyTorch 模型,但每当我调用 `estimator.fit` 时,代码都会无限期挂起,我只能中断笔记本内核。这种情况在我的本地机器和 SageMaker Studio 上都会发生。但当我使用 EC2 时,训练可以正常运行。

下面是调用估计器(estimator)的代码,以及我中断内核后得到的堆栈跟踪:

import sagemaker
from sagemaker.pytorch import PyTorch

# Reproduction script: train `src/train.py` with the SageMaker PyTorch
# estimator in *local* mode (instance_type="local"), reading the "training"
# channel from S3 and writing model artifacts to the current directory.

bucket = "bucket-name"
role = sagemaker.get_execution_role()
training_input_path = "s3://" + bucket + "/dataset/path"

# A LocalSession makes create_training_job go through sagemaker.local
# (local_session.py -> entities.py -> image.py in the traceback) instead of
# the managed SageMaker service.
sagemaker_session = sagemaker.LocalSession()
# NOTE: this assignment replaces the session's whole config dict; it does
# not merge into it. "local_code" presumably tells local mode to use
# source_dir straight from disk -- confirm against the local-mode docs.
sagemaker_session.config = {"local": {"local_code": True}}

# Artifacts are written to the local filesystem, not S3.
output_path = "file://."

estimator_settings = {
    "entry_point": "train.py",
    "source_dir": "src",
    "hyperparameters": {"max-epochs": 1},
    "framework_version": "1.8",
    "py_version": "py3",
    "instance_count": 1,
    "instance_type": "local",
    "role": role,
    "output_path": output_path,
    "sagemaker_session": sagemaker_session,
}
estimator = PyTorch(**estimator_settings)

# With an s3:// channel, local mode first copies every object under the
# prefix into a temporary directory (S3DataSource -> download_folder ->
# obj.download_file in the traceback). That happens synchronously, before
# any container starts. A large dataset, or a network path to S3 that
# stalls, therefore makes fit() appear to hang at this call.
training_channels = {"training": training_input_path}
estimator.fit(training_channels)

堆栈跟踪:

    ---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-9-35cdd6021288> in <module>
----> 1 estimator.fit({"training": training_input_path})

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name, experiment_config)
    678         self._prepare_for_training(job_name=job_name)
    679 
--> 680         self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
    681         self.jobs.append(self.latest_training_job)
    682         if wait:

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in start_new(cls, estimator, inputs, experiment_config)
   1450         """
   1451         train_args = cls._get_train_args(estimator, inputs, experiment_config)
-> 1452         estimator.sagemaker_session.train(**train_args)
   1453 
   1454         return cls(estimator.sagemaker_session, estimator._current_job_name)

/opt/conda/lib/python3.7/site-packages/sagemaker/session.py in train(self, input_mode, input_config, role, job_name, output_config, resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions, enable_network_isolation, image_uri, algorithm_arn, encrypt_inter_container_traffic, use_spot_instances, checkpoint_s3_uri, checkpoint_local_path, experiment_config, debugger_rule_configs, debugger_hook_config, tensorboard_output_config, enable_sagemaker_metrics, profiler_rule_configs, profiler_config, environment, retry_strategy)
    572         LOGGER.info("Creating training-job with name: %s", job_name)
    573         LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
--> 574         self.sagemaker_client.create_training_job(**train_request)
    575 
    576     def _get_train_request(  # noqa: C901

/opt/conda/lib/python3.7/site-packages/sagemaker/local/local_session.py in create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig, ResourceConfig, InputDataConfig, **kwargs)
    184         hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
    185         logger.info("Starting training job")
--> 186         training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
    187 
    188         LocalSagemakerClient._training_jobs[TrainingJobName] = training_job

/opt/conda/lib/python3.7/site-packages/sagemaker/local/entities.py in start(self, input_data_config, output_data_config, hyperparameters, job_name)
    219 
    220         self.model_artifacts = self.container.train(
--> 221             input_data_config, output_data_config, hyperparameters, job_name
    222         )
    223         self.end_time = datetime.datetime.now()

/opt/conda/lib/python3.7/site-packages/sagemaker/local/image.py in train(self, input_data_config, output_data_config, hyperparameters, job_name)
    200         data_dir = self._create_tmp_folder()
    201         volumes = self._prepare_training_volumes(
--> 202             data_dir, input_data_config, output_data_config, hyperparameters
    203         )
    204         # If local, source directory needs to be updated to mounted /opt/ml/code path

/opt/conda/lib/python3.7/site-packages/sagemaker/local/image.py in _prepare_training_volumes(self, data_dir, input_data_config, output_data_config, hyperparameters)
    487             os.mkdir(channel_dir)
    488 
--> 489             data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
    490             volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
    491 

/opt/conda/lib/python3.7/site-packages/sagemaker/local/data.py in get_data_source_instance(data_source, sagemaker_session)
     52         return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path)
     53     if parsed_uri.scheme == "s3":
---> 54         return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session)
     55     raise ValueError(
     56         "data_source must be either file or s3. parsed_uri.scheme: {}".format(parsed_uri.scheme)

/opt/conda/lib/python3.7/site-packages/sagemaker/local/data.py in __init__(self, bucket, prefix, sagemaker_session)
    183             working_dir = "/private{}".format(working_dir)
    184 
--> 185         sagemaker.utils.download_folder(bucket, prefix, working_dir, sagemaker_session)
    186         self.files = LocalFileDataSource(working_dir)
    187 

/opt/conda/lib/python3.7/site-packages/sagemaker/utils.py in download_folder(bucket_name, prefix, target, sagemaker_session)
    286                 raise
    287 
--> 288     _download_files_under_prefix(bucket_name, prefix, target, s3)
    289 
    290 

/opt/conda/lib/python3.7/site-packages/sagemaker/utils.py in _download_files_under_prefix(bucket_name, prefix, target, s3)
    314             if exc.errno != errno.EEXIST:
    315                 raise
--> 316         obj.download_file(file_path)
    317 
    318 

/opt/conda/lib/python3.7/site-packages/boto3/s3/inject.py in object_download_file(self, Filename, ExtraArgs, Callback, Config)
    313     return self.meta.client.download_file(
    314         Bucket=self.bucket_name, Key=self.key, Filename=Filename,
--> 315         ExtraArgs=ExtraArgs, Callback=Callback, Config=Config)
    316 
    317 

/opt/conda/lib/python3.7/site-packages/boto3/s3/inject.py in download_file(self, Bucket, Key, Filename, ExtraArgs, Callback, Config)
    171         return transfer.download_file(
    172             bucket=Bucket, key=Key, filename=Filename,
--> 173             extra_args=ExtraArgs, callback=Callback)
    174 
    175 

/opt/conda/lib/python3.7/site-packages/boto3/s3/transfer.py in download_file(self, bucket, key, filename, extra_args, callback)
    305             bucket, key, filename, extra_args, subscribers)
    306         try:
--> 307             future.result()
    308         # This is for backwards compatibility where when retries are
    309         # exceeded we need to throw the same error from boto3 instead of

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    107         except KeyboardInterrupt as e:
    108             self.cancel()
--> 109             raise e
    110 
    111     def cancel(self):

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    104             # however if a KeyboardInterrupt is raised we want want to exit
    105             # out of this and propogate the exception.
--> 106             return self._coordinator.result()
    107         except KeyboardInterrupt as e:
    108             self.cancel()

/opt/conda/lib/python3.7/site-packages/s3transfer/futures.py in result(self)
    258         # possible value integer value, which is on the scale of billions of
    259         # years...
--> 260         self._done_event.wait(MAXINT)
    261 
    262         # Once done waiting, raise an exception if present or return the

/opt/conda/lib/python3.7/threading.py in wait(self, timeout)
    550             signaled = self._flag
    551             if not signaled:
--> 552                 signaled = self._cond.wait(timeout)
    553             return signaled
    554 

/opt/conda/lib/python3.7/threading.py in wait(self, timeout)
    294         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    295             if timeout is None:
--> 296                 waiter.acquire()
    297                 gotit = True
    298             else:

KeyboardInterrupt: 

标签: python, amazon-web-services, pytorch, amazon-sagemaker

解决方案


推荐阅读