java - PySpark 3.0.1 cannot run distributed training with TensorFlow 2.1.0
Problem description
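I am trying to train a simple fashion_mnist model with TensorFlow, following the original TensorBoard API documentation on hyperparameter tuning, which you can find here.
For now, for testing purposes, I am running in standalone mode with master = 'local[*]'.
I have installed pyspark==3.0.1 and tensorflow==2.1.0. Here is what I am trying to run: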
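# For a given set of hyperparameters, this runs the training and returns the model + accuracy I'm looking for.
# This works when I run it without Spark.
# NOTE: the imports below are assumed; the original post did not show them.
from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten


def train(hparam) -> Tuple[Model, Any]:
    # The post had `fashion_mnist = fashion`; the Keras dataset module is assumed here.
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = Sequential([
        Flatten(),
        Dense(hparam['num_units'], activation=tf.nn.relu),
        Dropout(hparam['dropout']),
        Dense(10, activation=tf.nn.softmax),
    ])
    model.compile(
        optimizer=hparam['optimizer'],
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, epochs=1)  # Run with 1 epoch to speed things up for demo purposes
    _, accuracy = model.evaluate(x_test, y_test)
    return model, accuracy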
Here is the Spark code I run.
# NOTE: the imports below are assumed; the original post did not show them.
from pyspark.sql import SparkSession
from tensorboard.plugins.hparams import api as hp

if __name__ == '__main__':
    hp_nums = hp.HParam('num_units', hp.Discrete([16, 32]))
    hp_dropouts = hp.HParam('dropout', hp.RealInterval(0.1, 0.2))
    hp_opts = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))
    all_hparams = []  # contains a list of different hparams
    for num_units in hp_nums.domain.values:
        for dropout_rate in (hp_dropouts.domain.min_value, hp_dropouts.domain.max_value):
            for optimizer in hp_opts.domain.values:
                hparams = {
                    'num_units': num_units,
                    'dropout': dropout_rate,
                    'optimizer': optimizer,
                }
                all_hparams.append(hparams)
    spark_sess = SparkSession.builder.master(
        'local[*]'
    ).appName(
        'LocalTraining'
    ).getOrCreate()
    # One partition per hyperparameter combination, so each task trains one model.
    res = spark_sess.sparkContext.parallelize(
        all_hparams, len(all_hparams)
    ).map(
        train  # the function above
    ).collect()
    temp = 0.0
    best_model = None
    for model, acc in res:
        if acc > temp:
            temp = acc  # remember the best accuracy seen so far
            best_model = model
    print("best accuracy is -> " + str(temp))
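This looks fine to me, and the same setup works for any simple map-reduce job (such as the basic examples), which leads me to believe my environment is set up correctly.
My environment: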
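As a side note, the grid construction and best-result selection in the driver code can be written without nested loops or a running maximum. The following is only a minimal sketch using the standard library: the three lists are hypothetical stand-ins for the `hp.HParam` domains in the post, `build_grid` and `pick_best` are illustrative names that do not appear in the original, and the accuracies are made up purely to exercise the selection logic.

```python
from itertools import product

# Hypothetical stand-ins for the hp.HParam domains used in the post.
NUM_UNITS = [16, 32]
DROPOUTS = [0.1, 0.2]
OPTIMIZERS = ['adam', 'sgd']


def build_grid(num_units, dropouts, optimizers):
    """Return one dict per hyperparameter combination."""
    return [
        {'num_units': n, 'dropout': d, 'optimizer': o}
        for n, d, o in product(num_units, dropouts, optimizers)
    ]


def pick_best(results):
    """Return the (payload, accuracy) pair with the highest accuracy, or None if empty."""
    return max(results, key=lambda pair: pair[1], default=None)


all_hparams = build_grid(NUM_UNITS, DROPOUTS, OPTIMIZERS)

# Made-up accuracies, NOT real training results; they only exercise pick_best.
fake_results = [(h, 0.5 + 0.01 * i) for i, h in enumerate(all_hparams)]
best_hparams, best_acc = pick_best(fake_results)
print(len(all_hparams), best_hparams, best_acc)
```

Using `max` with a key keeps the best accuracy and its payload together, so the two cannot drift apart the way a separately tracked running maximum can.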
java : Java 11.0.8 2020-07-14 LTS
python: Python 3.6.5
pyspark: 3.0.1
tensorflow: 2.1.0
Keras: 2.3.1
windows: 10 (if this really matters)
cores : 8 (i5 10th gen)
Memory: 6G
But when I run the code above, I get the following error. I can see the training run, and it stops after one executor has run:
59168/60000 [============================>.] - ETA: 0s - loss: 0.7350 - accuracy: 0.7471
60000/60000 [==============================] - 3s 42us/step - loss: 0.7331 - accuracy: 0.7477
20/12/05 14:03:57 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
20/12/05 14:03:57 ERROR TaskSetManager: Task 0 in stage 0.0 failed 1 times; aborting job
20/12/05 14:03:57 INFO TaskSchedulerImpl: Cancelling stage 0
20/12/05 14:03:57 INFO TaskSchedulerImpl: Killing all running tasks in stage 0: Stage cancelled
20/12/05 14:03:57 INFO Executor: Executor is trying to kill task 1.0 in stage 0.0 (TID 1), reason: Stage cancelled
20/12/05 14:03:57 INFO TaskSchedulerImpl: Stage 0 was cancelled
20/12/05 14:03:57 INFO DAGScheduler: ResultStage 0 (collect at C:/Users/<>/<>/<>/main.py:<>) failed in 7.506 s due to Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
Driver stacktrace:
20/12/05 14:03:57 INFO DAGScheduler: Job 0 failed: collect at C:/<>/<>/<>/main.py, took 7.541442 s
Traceback (most recent call last):
File "C:/<>/<>/<>/main.py", line 68, in main
return res.collect()
File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\rdd.py", line 889, in collect
sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\java_gateway.py", line 1305, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\sql\utils.py", line 128, in deco
return f(*a, **kw)
File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\protocol.py", line 328, in get_return_value
format(target_id, ".", name), value)
The error occurs at the line model.fit(). [It only happens when I call model.fit; if I comment it out and put something else there, it works fine. I'm not sure why it fails on model.fit().]
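One thing that can be checked independently of Spark: PySpark serializes the values returned by a `map` function with pickle before `collect()` hands them to the driver, so whatever `train` returns has to be picklable. The log above shows the progress bar reaching 60000/60000 before the exception, so the failure may come after training rather than inside it. Whether pickling is the cause of the `Connection reset` here is not established by the post; the sketch below is only a way to rule it out. `pickling_error` is a hypothetical helper name, and the `threading.Lock` is a stand-in for an object with unpicklable state, not a claim about what a Keras model contains.

```python
import pickle
import threading


def pickling_error(obj):
    """Return None if obj can be pickled, otherwise the exception that was raised."""
    try:
        pickle.dumps(obj)
    except Exception as exc:  # pickle can raise several different exception types
        return exc
    return None


# Plain Python data such as (hparams, accuracy) pickles without trouble.
# The accuracy value is illustrative only.
plain_result = ({'num_units': 16, 'dropout': 0.1, 'optimizer': 'adam'}, 0.7477)

# A lock stands in for any object that holds unpicklable state.
unpicklable_result = (threading.Lock(), 0.7477)

print(pickling_error(plain_result))
print(type(pickling_error(unpicklable_result)).__name__)
```

Calling the same helper on the `(model, accuracy)` tuple returned by `train` in a plain Python session (no Spark) would show whether the model itself survives pickling. If it does not, one option is to return only `(hparam, accuracy)` from the mapped function and retrain or reload the best model on the driver.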