首页 > 解决方案 > 无法将 Spark 数据帧转换为 Pandas 数据帧

问题描述

我有一个 spark 数据框 Df,它有大约 130 000 行、5000 个客户 ID 和 7000 个产品 ID。我正在使用交叉连接生成所有可能的客户 ID 和产品 ID 组合(3400 万行)并将其存储在 fullouter 中。我正在从 Df 中已经存在的 fullouter 中删除组合,然后使用我的模型查找 allPredictions。

到目前为止,一切都很好。但我想将 allPredictions(3000 万行)转换为 pandas 数据框。我知道toPandas()由于没有行,转换将很困难。所以我所做的是我只对每个客户 ID 进行了前 1 个预测 - 使用 windows 函数和行号函数来做到这一点。

我假设 allPredictions 的大小应该大大减少到 5000 个客户 * 每个客户 1 个预测 = 5000 行。我“假设”是因为count()返回行数也需要很长时间。toPandas()应该在 topPredictions 数据框上工作。但它不起作用。花费了超过 40 分钟的时间,并且由于我在 google colab 工作,因此会话在一段时间后变得不活跃。

我是 Spark 的新手。我在这里做错了吗?我应该对我的代码进行哪些更改?另外,我试着把它写成镶木地板 - 时间太长了。我也试过 write.csv - 同样的问题。

conf = SparkConf().setAppName("trial")
conf.set("spark.sql.execution.arrow.enabled",'true')
conf.set("spark.rpc.message.maxSize",'1024mb')
conf.set("spark.executor.memory", '8g')
conf.set('spark.executor.cores', '8')
conf.set('spark.cores.max', '8')
conf.set("spark.driver.memory", '45g')
conf.set('spark.driver.maxResultSize', '21G')
conf.set("spark.driver.bindAddress", '127.0.0.1')
conf.set("spark.worker.cleanup.enabled",True)
conf.set("spark.executor.heartbeatInterval", "200000")
conf.set("spark.network.timeout", "300000")
self.sparkContext = SparkContext().getOrCreate(conf=conf)
self.sparkContext.setCheckpointDir('/checkpoint')
best_als = ALS(rank=10, maxIter=20, regParam=1.0,alpha=200.0, userCol="customerId",itemCol="productId",ratingCol="purch",implicitPrefs=True)

model=best_als.fit(Df)

df1 = Df.select("customerId")
df2 = Df.select("productId")

fullouter = df1.crossJoin(df2)

bigtest=fullouter.join(data, ["customerId","productId"],"left_anti")

allPredictions=model.transform(bigtest)

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col, row_number

window = Window.partitionBy(allPredictions['customerId']).orderBy(allPredictions['prediction'].desc())

top_allPredictions=allPredictions.select('*', row_number().over(window).alias('rank')).filter(col('rank') <= 1)

dataframe=top_allPredictions.toPandas()

标签: pandasdataframeapache-sparkpysparkapache-spark-sql

解决方案


尝试这个

from pyspark.sql import *  
from pyspark.sql.functions import *  
from pyspark.sql.types import *  
import numpy as np    
import pandas as pd


dataframe= top_allPredictions.select("*").toPandas()

推荐阅读