python - 如何从pyspark中的每一行中减去spark数据帧中的每一行?
问题描述
我有一个带有 3 列的 spark 数据框,表示原子的位置,即位置 X、Y 和 Z。现在要找到我需要应用距离公式的每 2 个原子之间的距离。距离公式为 d= sqrt((x2−x1)^2+(y2−y1)^2+(z2-z1)^2)
所以要应用上面的公式,我需要从 x 中的每一行中减去 x 中的每一行,从 y 中的每一行中减去 y 中的每一行,依此类推。然后对每两个原子应用上述公式。
我试图制作一个用户定义的函数(udf),但我无法将整个火花数据帧传递给它,我只能分别传递每一列而不是整个数据帧。因此,我无法遍历整个数据框,而是必须在每一列上应用循环。下面的代码显示了我只为 Position_X 所做的迭代。
@udf
def Distance(Position_X,Position_Y, Position_Z):
try:
for x,z in enumerate(Position_X) :
firstAtom = z
for y, a in enumerate(Position_X):
if (x!=y):
diff = firstAtom - a
return diff
except:
return None
newDF1 = atomsDF.withColumn("Distance", Distance(*atomsDF.columns))
My atomDF spark dataframe look like this, each row shows the x,y,z coordinates of one atom in space. Right now we are taking only 10 atoms.
Position_X|Position_Y|Position_Z|
+----------+----------+----------+
| 27.545| 6.743| 12.111|
| 27.708| 7.543| 13.332|
| 27.640| 9.039| 12.970|
| 26.991| 9.793| 13.693|
| 29.016| 7.166| 14.106|
| 29.286| 8.104| 15.273|
| 28.977| 5.725| 14.603|
| 28.267| 9.456| 11.844|
| 28.290| 10.849| 11.372|
| 26.869| 11.393| 11.161|
+----------+----------+----------+
如何在 pyspark 中解决上述问题,即。如何从每一行中减去每一行?如何将整个火花数据框传递给 udf 而不是它的列?以及如何避免使用太多 for 循环?
每两个原子(行)的预期输出将是使用上述距离公式计算的两行之间的距离。我不需要保留那个距离,因为我将使用它的另一个势能公式。或者,如果它可以保留在单独的数据框中,我不介意。
解决方案
我你想对执行交叉连接所需的原子(线)进行 2 对 2 的比较......这是不推荐的。
您可以使用该函数monotonically_increasing_id
为每一行生成一个 id。
from pyspark.sql import functions as F
df = df.withColumn("id", F.monotonically_increasing_id())
然后你自己交叉加入你的数据框,然后用“id_1 > id_2”的行过滤
df_1 = df.select(*(F.col(col).alias("{}_1".format(col)) for col in df.columns))
df_2 = df.select(*(F.col(col).alias("{}_2".format(col)) for col in df.columns))
df_3 = df_1.crossJoin(df_2).where("id_1 > id_2")
df_3 包含您需要的 45 行。你只需要应用你的公式:
df_4 = df_3.withColumn(
"distance",
F.sqrt(
F.pow(F.col("Position_X_1") - F.col("Position_X_2"), F.lit(2))
+ F.pow(F.col("Position_Y_1") - F.col("Position_Y_2"), F.lit(2))
+ F.pow(F.col("Position_Z_1") - F.col("Position_Z_2"), F.lit(2))
)
)
df_4.orderBy('id_2', 'id_1').show()
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
|Position_X_1|Position_Y_1|Position_Z_1| id_1|Position_X_2|Position_Y_2|Position_Z_2|id_2| distance|
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
| 27.708| 7.543| 13.332| 1| 27.545| 6.743| 12.111| 0|1.4688124454810418|
| 27.64| 9.039| 12.97| 2| 27.545| 6.743| 12.111| 0| 2.453267616873462|
| 26.991| 9.793| 13.693| 3| 27.545| 6.743| 12.111| 0| 3.480249991020759|
| 29.016| 7.166| 14.106| 4| 27.545| 6.743| 12.111| 0|2.5145168522004355|
| 29.286| 8.104| 15.273|8589934592| 27.545| 6.743| 12.111| 0|3.8576736513085175|
| 28.977| 5.725| 14.603|8589934593| 27.545| 6.743| 12.111| 0| 3.049100195139542|
| 28.267| 9.456| 11.844|8589934594| 27.545| 6.743| 12.111| 0|2.8200960976534106|
| 28.29| 10.849| 11.372|8589934595| 27.545| 6.743| 12.111| 0| 4.237969089080287|
| 26.869| 11.393| 11.161|8589934596| 27.545| 6.743| 12.111| 0| 4.793952023122468|
| 27.64| 9.039| 12.97| 2| 27.708| 7.543| 13.332| 1|1.5406764747993003|
| 26.991| 9.793| 13.693| 3| 27.708| 7.543| 13.332| 1|2.3889139791964036|
| 29.016| 7.166| 14.106| 4| 27.708| 7.543| 13.332| 1|1.5659083625806454|
| 29.286| 8.104| 15.273|8589934592| 27.708| 7.543| 13.332| 1|2.5636470115833037|
| 28.977| 5.725| 14.603|8589934593| 27.708| 7.543| 13.332| 1|2.5555676473143896|
| 28.267| 9.456| 11.844|8589934594| 27.708| 7.543| 13.332| 1| 2.48720606303539|
| 28.29| 10.849| 11.372|8589934595| 27.708| 7.543| 13.332| 1| 3.88715319996524|
| 26.869| 11.393| 11.161|8589934596| 27.708| 7.543| 13.332| 1| 4.498851186691999|
| 26.991| 9.793| 13.693| 3| 27.64| 9.039| 12.97| 2|1.2298154333069653|
| 29.016| 7.166| 14.106| 4| 27.64| 9.039| 12.97| 2|2.5868902180030737|
| 29.286| 8.104| 15.273|8589934592| 27.64| 9.039| 12.97| 2|2.9811658793163454|
+------------+------------+------------+----------+------------+------------+------------+----+------------------+
only showing top 20 rows
它适用于少量数据,但大量数据crossJoin
会破坏性能。
推荐阅读
- git - 使用 git 进行更改然后回滚
- c++ - 如果我在多线程中重置相同的 shared_ptr 不会崩溃
- c# - 如何在 C# 中将两个 StringCollection 合并/合并为一个
- java - Are there ways set a local variable of a function inside stream java 8
- grails - Grails 3.3.9 从服务调用 taglib
- aws-amplify - AWS AppSync Null @connection 查询结果
- c++ - 在一行中从类构造中初始化抽象类引用
- r - 使用 varlist 在 r studio 上循环
- python-3.x - 无法从 python lambda 将 XML 写入 S3
- python - 在所有 Anaconda 环境中,Spyder 在加载期间不断崩溃