首页 > 解决方案 > 带有 MLLIB 的 pyspark 数据帧中的点积

问题描述

我在 pyspark 中有一个非常简单的数据框,如下所示:

from pyspark.sql import Row
from pyspark.mllib.linalg import DenseVector

row = Row("a", "b")
df = spark.sparkContext.parallelize([
    offer_row(DenseVector([1, 1, 1]), DenseVector([1, 0, 0])),
]).toDF()

我想计算这些向量的点积而不求助于 UDF 调用。

spark MLLIB文档引用了一个dot方法,DenseVectors但如果我尝试按如下方式应用它:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))

我收到如下错误:

TypeError: 'Column' object is not callable

有谁知道这些 mllib 方法是否可以在 DataFrame 对象上调用?

标签: pythonapache-sparkpysparkapache-spark-mllib

解决方案


在这里,您将dot方法应用于列而不是 a DenseVector,这确实不起作用:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))

您将不得不使用 udf :

from pyspark.sql.functions import udf, array
from pyspark.sql.types import DoubleType

def dot_fun(array):
    return array[0].dot(array[1])

dot_udf = udf(dot_fun, DoubleType())

df_offers = df_offers.withColumn("c", dot_udf(array('a', 'b')))

推荐阅读