apache-spark - Scala 和 Python API 中的 LSH
问题描述
我正在关注这篇 SO post Efficient string matching in Apache Spark,以使用 LSH 算法获得一些字符串匹配。出于某种原因,通过 python API 获得结果,但不是在 Scala 中。我真的看不出 Scala 代码中缺少什么。
下面是这两个代码:
from pyspark.ml import Pipeline
from pyspark.ml.feature import RegexTokenizer, NGram, HashingTF, MinHashLSH
query = spark.createDataFrame(["Bob Jones"], "string").toDF("text")
db = spark.createDataFrame(["Tim Jones"], "string").toDF("text")
model = Pipeline(stages=[
RegexTokenizer(
pattern="", inputCol="text", outputCol="tokens", minTokenLength=1
),
NGram(n=3, inputCol="tokens", outputCol="ngrams"),
HashingTF(inputCol="ngrams", outputCol="vectors"),
MinHashLSH(inputCol="vectors", outputCol="lsh")
]).fit(db)
db_hashed = model.transform(db)
query_hashed = model.transform(query)
model.stages[-1].approxSimilarityJoin(db_hashed, query_hashed, 0.75).show()
它返回:
> +--------------------+--------------------+-------+ | datasetA| datasetB|distCol| > +--------------------+--------------------+-------+ |[Tim Jones, [t, i...|[Bob Jones, [b, o...| 0.6| > +--------------------+--------------------+-------+
然而 Scala 什么也没返回,这里是代码:
import org.apache.spark.ml.feature.RegexTokenizer
val tokenizer = new RegexTokenizer().setPattern("").setInputCol("text").setMinTokenLength(1).setOutputCol("tokens")
import org.apache.spark.ml.feature.NGram
val ngram = new NGram().setN(3).setInputCol("tokens").setOutputCol("ngrams")
import org.apache.spark.ml.feature.HashingTF
val vectorizer = new HashingTF().setInputCol("ngrams").setOutputCol("vectors")
import org.apache.spark.ml.feature.{MinHashLSH, MinHashLSHModel}
val lsh = new MinHashLSH().setInputCol("vectors").setOutputCol("lsh")
import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline().setStages(Array(tokenizer, ngram, vectorizer, lsh))
val query = Seq("Bob Jones").toDF("text")
val db = Seq("Tim Jones").toDF("text")
val model = pipeline.fit(db)
val dbHashed = model.transform(db)
val queryHashed = model.transform(query)
model.stages.last.asInstanceOf[MinHashLSHModel].approxSimilarityJoin(dbHashed, queryHashed, 0.75).show
我正在使用 Spark 3.0,我知道它是一个测试,但无法在不同版本上真正测试它。我怀疑是否存在这样的错误:)
解决方案
如果正确设置 numHashTables,此代码将在 Spark 3.0.1 中运行。
val lsh = new MinHashLSH().setInputCol("vectors").setOutputCol("lsh").setNumHashTables(3)
推荐阅读
- sql - 循环所有表并执行存储过程
- javascript - PHP将文件发送到Web服务器无法访问的浏览器
- jquery - 当您在每个循环中修改 jQuery 集合时,是包含修改后的集合还是仅包含原始集合?
- vue.js - 无法在 vuetify 项目中添加自定义颜色
- git - Git rebase to orphan 导致二进制文件冲突
- c# - 可空值类型上的提升运算符是否使用短路?
- php - 在每个函数调用中重新加载使用“require”检索的 php 文件
- arrays - 在 python 逻辑中接受多个输入
- c# - 我应该如何编写一个函数(n)来计算仅包含 1、3、4 的所有可能序列
- c++ - 在 Eigen3 中实现 Bartels–Stewart 算法——仅实矩阵?