apache-spark - 火花数据框 - 将非列类型变量传递给udf
问题描述
我正在尝试通过擦除一些特征(存储在 feature_idx_to_wipe
中)来修改 "features" 向量列。伪代码如下,问题是 udf 不接受 Set
类型的参数。我想知道如何解决这个问题,或者是否有更好的方法。
// Test data: a Set of feature indices to wipe, plus a DataFrame of
// (id, sparse feature vector) rows; every vector has size 6.
val feature_idx_to_wipe = Set(1, 2)
val dfA = spark.createDataFrame(Seq(
(0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
(1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
(2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
+---+-------------------------+
|id |features |
+---+-------------------------+
|0 |(6,[0,1,2],[1.0,1.0,1.0])|
|1 |(6,[2,3,4],[1.0,1.0,1.0])|
|2 |(6,[0,2,4],[1.0,1.0,1.0])|
+---+-------------------------+
//udf
/** Returns a copy of `v` with the active entries at the indices listed in
  * `idx2clean` removed; the vector's declared size is unchanged.
  *
  * @param v         input sparse vector
  * @param idx2clean indices whose entries should be wiped
  * @return a sparse vector of the same size without the wiped entries
  */
def wipe(v: NewSparseVector, idx2clean: Set[Int]): NewSparseVector = {
  val kept = ListBuffer.empty[(Int, Double)]
  // The binding was originally named `v`, shadowing the vector parameter;
  // renamed to `value` so `v.size` below is unambiguous to the reader.
  v.foreachActive { case (i, value) =>
    if (!idx2clean.contains(i)) {
      kept += ((i, value))
    }
  }
  // Vectors.sparse returns the abstract Vector type; .toSparse narrows it.
  NewVectors.sparse(v.size, kept.toSeq).toSparse
}
// Register the Scala function as a Spark UDF. NOTE(review): the second
// parameter is a plain Set[Int]; a UDF can only be applied to Column
// arguments, so the call below fails to compile — this is exactly the
// problem the question describes (pass the set via lit/typedLit or a
// broadcast variable instead, as shown in the solutions below).
val udf_wipe = udf((x: NewSparseVector, idx2clean:Set[Int]) => wipe(x, idx2clean))
// Apply the UDF — fails: feature_idx_to_wipe is a Set[Int], not a Column.
dfA.withColumn("features_wiped", udf_wipe(col("features"), feature_idx_to_wipe))
// Compiler error:
// scala> dfA.withColumn("nf", udf_wipe(col("features"), tc))
// <console>:98: error: type mismatch;
// found : scala.collection.immutable.Set[Int]
// required: org.apache.spark.sql.Column
// dfA.withColumn("nf", udf_wipe(col("features"), tc))
// Desired result: a new vector column in which the entries at indices 1 and 2
// have been removed from every row's feature vector.
dfA.select("id","features_wiped").show(false)
+---+-------------------------+
|id |features_wiped |
+---+-------------------------+
|0 |(6,[0],[1.0]) |
|1 |(6,[3,4],[1.0,1.0]) |
|2 |(6,[0,4],[1.0,1.0]) |
+---+-------------------------+
解决方案
另一种选择——
测试数据
// Test data for the solution: the same DataFrame of (id, sparse vector) rows.
val dfA = spark.createDataFrame(Seq(
(0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
(1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))),
(2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0))))
)).toDF("id", "features")
dfA.show(false)
/**
* +---+-------------------------+
* |id |features |
* +---+-------------------------+
* |0 |(6,[0,1,2],[1.0,1.0,1.0])|
* |1 |(6,[2,3,4],[1.0,1.0,1.0])|
* |2 |(6,[0,2,4],[1.0,1.0,1.0])|
* +---+-------------------------+
*/
Alternative-1:使用 lit 将索引数组作为字面量列传入,
如下 -
// Alternative-1: pass the indices as a literal array Column via lit().
// lit accepts an Array, which Spark wraps as an array Column.
val feature_idx_to_wipe = Array(1, 2)
import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector}
/** Returns a copy of `v` with the active entries at the indices listed in
  * `idx2clean` removed; the vector's declared size is unchanged.
  *
  * @param v         input sparse vector
  * @param idx2clean indices whose entries should be wiped
  * @return a sparse vector of the same size without the wiped entries
  */
def wipe(v: NewSparseVector, idx2clean: Seq[Int]): NewSparseVector = {
  val kept = ListBuffer.empty[(Int, Double)]
  // The binding was originally named `v`, shadowing the vector parameter;
  // renamed to `value` so `v.size` below is unambiguous to the reader.
  v.foreachActive { case (i, value) =>
    if (!idx2clean.contains(i)) {
      kept += ((i, value))
    }
  }
  // Vectors.sparse returns the abstract Vector type; .toSparse narrows it.
  Vectors.sparse(v.size, kept.toSeq).toSparse
}
// Wrap `wipe` as a UDF; the second argument is now a Seq[Int] column value,
// which Spark can deserialize from an array Column.
val udf_wipe = udf((x: NewSparseVector, idx2clean:Seq[Int]) => wipe(x, idx2clean))
// Apply the UDF: lit(...) turns the driver-side Array into an array Column,
// so BOTH UDF arguments are Columns — this fixes the original type mismatch.
val newDF = dfA.withColumn("features_wiped", udf_wipe(col("features"), lit(feature_idx_to_wipe)))
// Target: a new vector column with the features at indices 1 and 2 removed.
newDF.select("id","features_wiped").show(false)
/**
* +---+-------------------+
* |id |features_wiped |
* +---+-------------------+
* |0 |(6,[0],[1.0]) |
* |1 |(6,[3,4],[1.0,1.0])|
* |2 |(6,[0,4],[1.0,1.0])|
* +---+-------------------+
*/
Alternative-2:使用广播变量 sparkContext.broadcast,
如下 -
// Alternative-2: ship the Set to the executors as a broadcast variable
// instead of passing it through the UDF signature.
val feature_idx_to_wipe1 = Set(1, 2)
// NOTE(review): "broabcastSet" is a typo for "broadcastSet"; kept as-is
// because the name is referenced inside wipe1 below.
val broabcastSet = spark.sparkContext.broadcast(feature_idx_to_wipe1)
// UDF helper import.
import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector}
/** Returns a copy of `v` with the active entries at the indices held in the
  * broadcast variable `broabcastSet` removed; the declared size is unchanged.
  *
  * @param v input sparse vector
  * @return a sparse vector of the same size without the wiped entries
  */
def wipe1(v: NewSparseVector): NewSparseVector = {
  // Read the index set from the broadcast variable on the executor.
  // (The driver-side name `broabcastSet` is typo'd but must match.)
  val idx2clean = broabcastSet.value
  val kept = ListBuffer.empty[(Int, Double)]
  // The binding was originally named `v`, shadowing the vector parameter;
  // renamed to `value` so `v.size` below is unambiguous to the reader.
  v.foreachActive { case (i, value) =>
    if (!idx2clean.contains(i)) {
      kept += ((i, value))
    }
  }
  // Vectors.sparse returns the abstract Vector type; .toSparse narrows it.
  Vectors.sparse(v.size, kept.toSeq).toSparse
}
// The UDF now takes only the vector column; the index set is obtained
// inside wipe1 from the broadcast variable, so no non-Column argument
// is needed — this also fixes the original type mismatch.
val udf_wipe1 = udf((x: NewSparseVector) => wipe1(x))
// Apply the UDF to the features column.
val newDF1 = dfA.withColumn("features_wiped", udf_wipe1(col("features")))
// Target: a new vector column with the features at indices 1 and 2 removed.
newDF1.select("id","features_wiped").show(false)
/**
* +---+-------------------+
* |id |features_wiped |
* +---+-------------------+
* |0 |(6,[0],[1.0]) |
* |1 |(6,[3,4],[1.0,1.0])|
* |2 |(6,[0,4],[1.0,1.0])|
* +---+-------------------+
*/
推荐阅读
- sql - 将枢轴上的计数转换为位/标志?
- reactjs - 这个例子中的 currCount 是从哪里来的?
- php - SESSION maxlifetime 问题
- r - as.Date 覆盖旧日期时返回数字
- f# - 是否可以计算 f# 中负数的连分数?
- sql-server - 外键连接加范围条件的最佳索引
- python - 无法在 python 中使用 netCDF4 读取 .nc 文件
- php - 在 Ubuntu 上安装 moodle 3.9 时出错
- python - SQLAlchemy 混合属性日期时间到假期
- python - Python 脚本在 Sublime Text 中运行时不起作用,但可以从命令行运行