PySpark 3.0.1 cannot run distributed training with TensorFlow 2.1.0

Problem description

I'm trying to train a simple fashion_mnist model with TensorFlow, following the official TensorBoard API documentation on hyperparameter tuning, which you can find here.

For now, for testing purposes, I'm running in standalone mode with master = 'local[*]'.

I have installed pyspark==3.0.1 and tensorflow==2.1.0. Here is what I'm trying to run:

from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten

# For a given set of hyper parameters, this runs the training & returns the model + accuracy I'm looking for.
# This works when I run it without Spark.

def train(hparam) -> Tuple[Model, Any]:
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = Sequential([
        Flatten(),
        Dense(hparam['num_units'], activation=tf.nn.relu),
        Dropout(hparam['dropout']),
        Dense(10, activation=tf.nn.softmax),
    ])
    model.compile(
        optimizer=hparam['optimizer'],
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, epochs=1)  # Run with 1 epoch to speed things up for demo purposes
    _, accuracy = model.evaluate(x_test, y_test)
    return model, accuracy
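
For what it's worth, this is how I sanity-check train() outside of Spark (a minimal sketch; the specific hyper-parameter values are just one combination from the grid built below):

# Standalone sanity check (no Spark): calling train() directly with one
# hyper-parameter combination works fine in my environment.
model, accuracy = train({'num_units': 16, 'dropout': 0.1, 'optimizer': 'adam'})
print('standalone accuracy:', accuracy)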

Here is the Spark code I'm running.

from pyspark.sql import SparkSession
from tensorboard.plugins.hparams import api as hp

if __name__ == '__main__':

    hp_nums = hp.HParam('num_units', hp.Discrete([16, 32]))
    hp_dropouts = hp.HParam('dropout', hp.RealInterval(0.1, 0.2))
    hp_opts = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))

    all_hparams = []  # contains a list of the different hparam combinations

    for num_units in hp_nums.domain.values:
        for dropout_rate in (hp_dropouts.domain.min_value, hp_dropouts.domain.max_value):
            for optimizer in hp_opts.domain.values:
                hparams = {
                    'num_units': num_units,
                    'dropout': dropout_rate,
                    'optimizer': optimizer,
                }
                all_hparams.append(hparams)

    spark_sess = SparkSession.builder.master(
        'local[*]'
    ).appName(
        'LocalTraining'
    ).getOrCreate()

    res = spark_sess.sparkContext.parallelize(
        all_hparams, len(all_hparams)
    ).map(
        train  # the function defined above
    ).collect()

    best_acc = 0.0
    best_model = None
    for model, acc in res:
        if acc > best_acc:
            best_acc = acc
            best_model = model

    print("best accuracy is -> " + str(best_acc))


This looks fine to me, and it works for any simple map-reduce job (like the basic examples), which leads me to believe my environment is set up correctly.
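
For example, a trivial job like the one below (a minimal sketch of the kind of basic example I mean; the numbers are arbitrary) runs and collects on the same local[*] session without any issue:

# Basic map/collect sanity check on the same local[*] session.
# This completes without errors, so Spark itself appears to be working.
nums = spark_sess.sparkContext.parallelize([1, 2, 3, 4], 4)
print(nums.map(lambda x: x * x).collect())  # -> [1, 4, 9, 16]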

My environment:

java : Java 11.0.8 2020-07-14 LTS
python: Python 3.6.5
pyspark: 3.0.1
tensorflow: 2.1.0
Keras: 2.3.1
windows: 10 (if this really matters)
cores: 8 (i5 10th gen)
Memory: 6G

But when I run the code above, I get the following error. I can see the training run, and it stops after one executor has run:

59168/60000 [============================>.] - ETA: 0s - loss: 0.7350 - accuracy: 0.7471
60000/60000 [==============================] - 3s 42us/step - loss: 0.7331 - accuracy: 0.7477
20/12/05 14:03:57 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
20/12/05 14:03:57 ERROR TaskSetManager: Task 0 in stage 0.0 failed 1 times; aborting job
20/12/05 14:03:57 INFO TaskSchedulerImpl: Cancelling stage 0
20/12/05 14:03:57 INFO TaskSchedulerImpl: Killing all running tasks in stage 0: Stage cancelled
20/12/05 14:03:57 INFO Executor: Executor is trying to kill task 1.0 in stage 0.0 (TID 1), reason: Stage cancelled
20/12/05 14:03:57 INFO TaskSchedulerImpl: Stage 0 was cancelled
20/12/05 14:03:57 INFO DAGScheduler: ResultStage 0 (collect at C:/Users/<>/<>/<>/main.py:<>) failed in 7.506 s due to Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
    at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)

py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
    at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)

Driver stacktrace:
20/12/05 14:03:57 INFO DAGScheduler: Job 0 failed: collect at C:/<>/<>/<>/main.py, took 7.541442 s
Traceback (most recent call last):
  File "C:/<>/<>/<>/main.py", line 68, in main
    return res.collect()
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\rdd.py", line 889, in collect
    sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\java_gateway.py", line 1305, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\sql\utils.py", line 128, in deco
    return f(*a, **kw)
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)

The error is on the model.fit() line. [It only happens when I call model.fit; if I comment it out and put something else there, it works fine. I'm not sure why it fails on model.fit().]
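
To illustrate what I mean by "something else": a stub like the one below (a rough sketch; train_stub and its dummy return values are just placeholders I use for testing) collects fine when mapped over the same RDD:

# Skip model.fit() entirely and return dummy values. Mapping this over the
# same RDD collects without errors, so the failure seems tied to the
# model.fit() call itself.
def train_stub(hparam):
    return None, 0.0  # placeholder (model, accuracy)

res = spark_sess.sparkContext.parallelize(
    all_hparams, len(all_hparams)
).map(train_stub).collect()  # no Connection reset here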

Tags: java, python, tensorflow, pyspark, tensorflow2.0

Solution

