Error TreeNodeException: execute, tree in PipelineModel.transform with PySpark

Problem description

So I am doing one-hot encoding inside a pipeline and calling its fit method.

I have a dataframe with both categorical and numeric columns, so I one-hot encode each categorical column after indexing it with a StringIndexer.

from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler

# Index each categorical column, then one-hot encode the resulting index.
categoricalColumns = ['IncomeDetails', 'B2C', 'Gender', 'Occupation', 'POA_Status']
stages = []

for categoricalCol in categoricalColumns:
    stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol + 'Index')
    encoder = OneHotEncoderEstimator(inputCols=[stringIndexer.getOutputCol()],
                                     outputCols=[categoricalCol + "classVec"])
    stages += [stringIndexer, encoder]

# Index the target column to produce the label.
label_stringIdx = StringIndexer(inputCol='target', outputCol='label')
stages += [label_stringIdx]

# new_col_array.remove("client_id")

# Numeric feature columns; new_col_array is defined earlier in the notebook.
numericCols = new_col_array  # note: this aliases new_col_array, so the append below mutates it too
numericCols.append('age')

# Assemble the one-hot vectors and numeric columns into a single feature vector.
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]

from pyspark.ml import Pipeline
pipeline = Pipeline(stages=stages)
pipelineModel = pipeline.fit(new_df1)
new_df1 = pipelineModel.transform(new_df1)
selectedCols = ['label', 'features'] + cols  # 'cols' also comes from earlier in the notebook
I am getting this error:

Py4JJavaError: An error occurred while calling o2053.fit.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:
Exchange hashpartitioning(client_id#*****, 200)
+- *(4) HashAggregate(keys=[client_id#*****], functions=[], output=[client_id#*****])
   +- Exchange hashpartitioning(client_id#*****, 200)
      +- *(3) HashAggregate(keys=[client_id#*****], functions=[], output=[client_id#*****])
         +- *(3) HashAggregate(keys=[client_id#*****, event_name#27993], functions=[], output=[client_id#27980])
            +- Exchange hashpartitioning(client_id#*****, event_name#27993, 200)
               +- *(2) HashAggregate(keys=[client_id#*****, event_name#27993], functions=[], output=[client_id#*****, event_name#27993])
                  +- *(2) Project [client_id#*****, event_name#27993]
                     +- *(2) BroadcastHashJoin [client_id#*****], [Party_Code#*****], LeftSemi, BuildRight, false
                        :- *(2) Project [client_id#*****, event_name#27993]
                        :  +- *(2) Filter isnotnull(client_id#*****)
                        :     +- *(2) FileScan orc dbo.dp_clickstream_[client_id#*****,event_name#27993,dt#28010] Batched: true, Format: ORC, Location: **PrunedInMemoryFileIndex**[s3n://processed/db-dbo-..., PartitionCount: 6, PartitionFilters: [isnotnull(dt#28010), (cast(dt#28010 as timestamp) >= 1610409600000000), (cast(dt#28010 as timest..., PushedFilters: [IsNotNull(client_id)], ReadSchema: struct<client_id:string,event_name:string>
                        +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]),false)


at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.doExecute(ShuffleExchangeExec.scala:83)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:173)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:169)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:197)


Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: execute, tree:

My Spark version is 2.4.3.

Tags: apache-spark, pyspark, pyspark-dataframes, apache-spark-ml

Solution
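The original post does not include an answer body, so what follows is not a confirmed fix. As a general debugging step, a TreeNodeException thrown while executing an Exchange in Spark 2.4 often surfaces when an estimator's fit re-executes a long upstream lineage (here, the chained aggregations and the broadcast join visible in the failed plan). A minimal sketch of one commonly suggested workaround, assuming a SparkSession named spark and a hypothetical checkpoint path, is to materialize new_df1 and truncate its lineage before fitting:

# Hedged sketch, not from the original post: cut new_df1's lineage so that
# pipeline.fit() does not re-execute the full upstream plan.
# The checkpoint directory below is a hypothetical path.
spark.sparkContext.setCheckpointDir('s3n://some-bucket/tmp/checkpoints')

# Option 1: checkpoint eagerly, which writes the data out and cuts the plan tree.
new_df1 = new_df1.checkpoint(eager=True)

# Option 2: cache and force one full evaluation before fitting.
new_df1 = new_df1.cache()
new_df1.count()

pipelineModel = pipeline.fit(new_df1)
new_df1 = pipelineModel.transform(new_df1)

If the error persists, since the failed plan contains a BroadcastExchange, another frequently reported workaround on Spark 2.4 is to disable automatic broadcast joins with spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1) and rerun the fit. Both steps are general Spark debugging practice, not a solution confirmed for this specific question.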

