首页 > 解决方案 > 将表加载到有限制的 PySpark Dataframe 中

问题描述

PySpark 是否可以在从数据库读取数据时将一定数量的数据加载到数据框中?通过某个数字,我的意思是如果可以在sqlContext从数据库中读取它时对它进行限制,以便不必读取整个表(因为迭代 750K 行非常昂贵)。

这是我目前用来过滤所需数据的代码。除了 PySpark,我还使用了 Python3.7 和 Cassandra DB:

def connect_cassandra():
    spark = SparkSession.builder \
      .appName('SparkCassandraApp') \
      .config('spark.cassandra.connection.host', 'localhost') \
      .config("spark.driver.memory","15g") \
      .config("spark.executor.memory","15g") \
      .config("spark.driver.cores","4") \
      .config("spark.num.executors","6") \
      .config("spark.executor.cores","4") \
      .config('spark.cassandra.connection.port', '9042') \
      .config('spark.cassandra.output.consistency.level','ONE') \
      .master('local[*]') \
      .getOrCreate()

    sqlContext = SQLContext(spark)
    return sqlContext

def total_bandwidth(start_date, end_date):
    sqlContext = connect_cassandra()

    try:
        df = sqlContext \
          .read \
          .format("org.apache.spark.sql.cassandra") \
          .options(table="user_info", keyspace="acrs") \
          .load()
    except Exception as e:
        print(e)

    rows = df.where(df["created"] > str(start_date)) \
            .where(df["created"] < str(end_date)) \
            .groupBy(['src_ip', 'dst_ip']) \
            .agg(_sum('data').alias('total')) \
            .collect()

    data_dict = []
    for row in rows:
        src_ip = row['src_ip']
        dst_ip = row['dst_ip']
        data = row['total']
        data = {'src_ip' : src_ip, 'dst_ip' : dst_ip, 'data' : data}
        data_dict.append(data)

    print(data_dict)

正如你们所看到的,我正在尝试使用start_dateand过滤掉数据end_date。但这需要太多时间,导致操作缓慢。我想知道在将表加载到数据框中时是否有任何可用的 DataFrameReader 选项,以便减少所花费的时间(指数首选:p)。

我阅读了 Data-Frame-Reader 文档并找到了option(String key, String value)这些选项,但这些选项没有记录,因此无法找出 Cassandra 数据库有哪些选项以及如何使用它们。

标签: pythonpython-3.xdataframecassandrapyspark

解决方案


您的主要问题是您使用的是 append 方法。由于您的数据框中有大量行,因此效率非常低。我宁愿使用专用的 pyspark 方法来达到预期的结果。

我在本地机器上创建了一些包含 100 万行的临时数据框(我假设您已经创建了 SparkSession)

>>> import pandas as pd

>>> n = 1000000
>>> df = spark.createDataFrame(
        pd.DataFrame({
            'src_ip': n * ['192.160.1.0'],
            'dst_ip': n * ['192.168.1.1'],
            'total': n * [1]
        })
    )
>>> df.count()
1000000

让我们从表中只选择所需的列。

>>> import pyspark.sql.functions as F
>>> df.select('src_ip', 'dst_ip', F.col('total').alias('data')).show(5)
+-----------+-----------+----+
|     src_ip|     dst_ip|data|
+-----------+-----------+----+
|192.160.1.0|192.168.1.1|   1|
|192.160.1.0|192.168.1.1|   1|
|192.160.1.0|192.168.1.1|   1|
|192.160.1.0|192.168.1.1|   1|
|192.160.1.0|192.168.1.1|   1|
+-----------+-----------+----+
only showing top 5 rows

最后,让我们创建所需的数据字典列表。收集所有数据的最简单方法是使用列表推导。一旦我们选择了要组合到字典中的列,我们就可以toDict()在每个 DataFrame 行上使用方法。

挑剔:

  • 如果要收集所有值,请使用collect()DataFrame 上的方法。
  • 如果您不知道 DataFrame 的确切大小,您可以使用从 DataFrametake(n)返回n元素的方法。
>>> dict_list = [i.asDict() for i in df.select('src_ip', 'dst_ip', F.col('total').alias('data')).take(5)]
>>> dict_list
[{'data': 1, 'dst_ip': '192.168.1.1', 'src_ip': '192.160.1.0'},
 {'data': 1, 'dst_ip': '192.168.1.1', 'src_ip': '192.160.1.0'},
 {'data': 1, 'dst_ip': '192.168.1.1', 'src_ip': '192.160.1.0'},
 {'data': 1, 'dst_ip': '192.168.1.1', 'src_ip': '192.160.1.0'},
 {'data': 1, 'dst_ip': '192.168.1.1', 'src_ip': '192.160.1.0'}]

推荐阅读