首页 > 解决方案 > 在集群上部署 PySpark 作业

问题描述

我是 Apache Spark 的新手,并且正在使用 Python 3.8 和 Pyspark 3.1.2,并使用以下代码使用 MNIST 多类分类:

from pyspark.sql import SparkSession
import os
import logging
import pandas as pd
import pyspark

# Create a new Spark session-
spark = SparkSession\
        .builder.master('local[2]')\
        .appName('MLOps')\
        .getOrCreate()


# Read train CSV file-
train = spark.read.csv(
        "mnist_train.csv",
        inferSchema = True, header = True
        )

# Read test CSV file-
test = spark.read.csv(
        "mnist_test.csv",
        inferSchema = True, header = True
        )


print(f"number of partitions in 'train' = {train.rdd.getNumPartitions()}")
print(f"number of partitions in 'test' = {test.rdd.getNumPartitions()}")
# number of partitions in 'train' = 2
# number of partitions in 'test' = 2 

def shape(df):
    '''
    Function to return shape/dimension.
    
    Input
    df - pyspark.sql.dataframe.DataFrame
    '''
    return (df.count(), len(df.columns))


print(f"df.shape = {shape(df)}")
# df.shape = (10000, 785)

# Get distribution of 'label' attribute-
train.groupBy('label').count().show()
'''
+-----+-----+                                                                   
|label|count|
+-----+-----+
|    1| 6742|
|    6| 5918|
|    3| 6131|
|    5| 5421|
|    9| 5949|
|    4| 5842|
|    8| 5851|
|    7| 6265|
|    2| 5958|
|    0| 5923|
+-----+-----+
'''

test.groupBy('label').count().show()
'''
+-----+-----+
|label|count|
+-----+-----+
|    1| 1135|
|    6|  958|
|    3| 1010|
|    5|  892|
|    9| 1009|
|    4|  982|
|    8|  974|
|    7| 1028|
|    2| 1032|
|    0|  980|
+-----+-----+
'''


# Split data into training and evaluation datasets using 'randomSplit()'-
train_df, test_df = train.randomSplit(weights = [0.7, 0.3], seed = 223)
# 'randomSplit()' - randomly splits this 'DataFrame' with the provided weights

# Count number of rows-
train_df.count(), test_df.count(), train.count()
# (41840, 18160, 60000)

到目前为止,它在我的桌面上以独立模式本地运行,因此在创建新的 Spark 会话时使用了“local[2]”,其中 2 表示使用 RDD、DataFrame 和 Dataset 时要创建的分区数。理想情况下,“x”应该是可用 CPU 内核的数量。

但是,如何在具有 20 个计算节点的集群上部署此批处理作业?

谢谢!

标签: python-3.xapache-sparkpyspark

解决方案


推荐阅读