首页 > 解决方案 > Scala试图从数据框中的udt(数组)中提取值

问题描述

这似乎应该是一个简单的问题,但我找不到任何可以在文档或 stackoverflow 中工作的东西。

我正在尝试从继承的 Scala 代码中的数据框中的 udt 中提取值。我的目标是从(或Scala中的任何正确语法)中DataFrame提取适合的列yProbability.values(1)

DataFrame 具有以下结构:

outputDataAAL:org.apache.spark.sql.DataFrame
 - info_conversationid:string 
 - document:string 
 - yProbability:udt
 - yHat:double

yProbability 的示例元素为:

array
 - 0: 1
 - 1: 2
 - 2: []
 - 3:
     - 0: 0.8054468196483193
     - 1: 0.19455318035168068

在 r 我做了一个简单的:

outputDataAAL$fit <- outputDataAAL$yProbability %>% lapply(function(x) {x[[2]][2]}) %>% unlist

这很容易,但对于我正在查看的数据大小来说很慢。这就是为什么我想在 Scala 中做这件事。

我尝试仅提取值(这是 yProbability 中的数组元素 #3),但是以下两种方法都给我以下错误。

val newSample = outputDataAAL.select("yProbability.values(1)")
val newSample = outputDataAAL.select($"yProbability".getItem("values(1)"))

错误:

Can't extract value from yProbability#4404: need struct type but got 
struct<type:tinyint,size:int,indices:array<int>,values:array<double>>

我还尝试在使我失望时对 outputDataAAL 进行采样,但#4404我不知道这是否是由于列错误造成的。显然,没有运气。

非常感谢您的帮助。

-瑞克

标签: arraysscaladataframeapache-sparkextract

解决方案


好吧,我想出了一个解决方案,但它有点难看。我觉得应该有一个更简单的答案,但这有效。

import org.apache.spark.sql.functions._
import org.apache.spark.ml._
val vecToArray = udf( (xs: linalg.Vector) => xs.toArray )  
val newSample = outputDataAAL.withColumn("yProbabilityArr" , vecToArray($"yProbability") )
val outputDataAALNew = newSample.withColumn("fit", $"yProbabilityArr".getItem(1))

推荐阅读