scala - How to speed up windowing of text files and machine learning in Spark
Question
I am trying to learn multi-class logistic regression over windowed text files using Spark. What I do is first create the windows and explode them into $"word_winds", then move the center word of each window into $"word". To fit a LogisticRegression model, I convert each distinct word into a class ($"label") and learn on those. I count the distinct labels so I can filter out those with fewer than minF samples (and, when stratifying, weight the sample in favor of rare labels).
The problem is that some parts of the code are very slow, even for small input files (you can test the code with some README file). Googling around, I found other users having trouble with explode. They suggest some modifications to the code that give roughly a 2x speedup. However, I think that for a 100MB input file this would not be enough. Please suggest something different, probably aimed at avoiding the operations that slow the code down. I am using Spark 2.4.0 and sbt 1.2.8 on a 24-core machine.
import org.apache.spark.sql.functions._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, IDF}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.types._

object SimpleApp {
  def main(args: Array[String]) {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    spark.sparkContext.setCheckpointDir("checked_dfs")

    val in_file = "sample.txt"
    val stratified = true
    val wsize = 7
    val ngram = 3
    val minF = 2
    val windUdf = udf{s: String => s.sliding(ngram).toList.sliding(wsize).toList}
    val get_mid = udf{s: Seq[String] => s(s.size / 2)}
    val rm_punct = udf{s: String => s.replaceAll("""([\p{Punct}|¿|\?|¡|!]|\p{C}|\b\p{IsLetter}{1,2}\b)\s*""", "")}

    // Read input and remove punctuation
    var df = spark.read.text(in_file)
      .withColumn("value", rm_punct($"value"))

    // Create the windows, explode them, and put the center word into $"word"
    df = df.withColumn("char_nGrams", windUdf('value))
      .withColumn("word_winds", explode($"char_nGrams"))
      .withColumn("word", get_mid('word_winds))

    val indexer = new StringIndexer().setInputCol("word")
      .setOutputCol("label")
    df = indexer.fit(df).transform(df)

    val hashingTF = new HashingTF().setInputCol("word_winds")
      .setOutputCol("freqFeatures")
    df = hashingTF.transform(df)

    val idf = new IDF().setInputCol("freqFeatures")
      .setOutputCol("features")
    df = idf.fit(df).transform(df)

    // Remove words whose frequency is less than minF
    var counts = df.groupBy("label").count
      .filter(col("count") > minF)
      .orderBy(desc("count"))
      .withColumn("id", monotonically_increasing_id())
    var filtro = df.groupBy("label").count.filter(col("count") <= minF)
    df = df.join(filtro, Seq("label"), "leftanti")

    var dfs = if (stratified) {
      // Create stratified sample 'dfs'
      var revs = counts.orderBy(asc("count")).select("count")
        .withColumn("id", monotonically_increasing_id())
      revs = revs.withColumnRenamed("count", "ascc")
      // Weight each label ("ascc") inversely (linearly) proportional to its
      // word frequency, then normalize the weights
      counts = counts.join(revs, Seq("id"), "inner").withColumn("weight", col("ascc") / df.count)
      val minn = counts.select("weight").agg(min("weight")).first.getDouble(0) - 0.01
      val maxx = counts.select("weight").agg(max("weight")).first.getDouble(0) - 0.01
      counts = counts.withColumn("weight_n", (col("weight") - minn) / (maxx - minn))
      counts = counts.withColumn("weight_n", when(col("weight_n") > 1.0, 1.0)
        .otherwise(col("weight_n")))
      var fractions = counts.select("label", "weight_n").rdd.map(x => (x(0), x(1)
        .asInstanceOf[scala.Double])).collectAsMap.toMap
      df.stat.sampleBy("label", fractions, 36L).select("features", "word_winds", "word", "label")
    } else { df }
    dfs = dfs.checkpoint()

    val lr = new LogisticRegression().setRegParam(0.01)
    val Array(tr, ts) = dfs.randomSplit(Array(0.7, 0.3), seed = 12345)
    val training = tr.select("word_winds", "features", "label", "word")
    val test = ts.select("word_winds", "features", "label", "word")
    val model = lr.fit(training)

    def mapCode(m: scala.collection.Map[Any, String]) = udf((s: Double) =>
      m.getOrElse(s, "")
    )
    var labels = training.select("label", "word").distinct.rdd
      .map(x => (x(0), x(1).asInstanceOf[String]))
      .collectAsMap
    var predictions = model.transform(test)
    predictions = predictions.withColumn("pred_word", mapCode(labels)($"prediction"))
    predictions.write.format("csv").save("spark_predictions")

    spark.stop()
  }
}
Answer
Since your data is rather small, it might help to use coalesce before the explode. Having too many partitions can sometimes be inefficient, especially when there is a lot of shuffling in the code.
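A minimal sketch of that suggestion against the pipeline in the question; the partition count 8 and master("local[*]") are illustrative assumptions, not recommendations from the answer, so tune them to your cluster:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: coalesce to fewer, larger partitions before the expensive
// explode step, so each task amortizes more work per partition.
val spark = SparkSession.builder().master("local[*]").getOrCreate()

val raw = spark.read.text("sample.txt")
  .coalesce(8) // 8 is arbitrary; fewer partitions -> less per-task overhead
// ...then apply rm_punct, windUdf, explode and get_mid as in the question.
```

Unlike repartition, coalesce avoids a full shuffle, which is why it is the cheaper way to shrink the partition count here.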
Like you said, many people seem to have problems with explode. I looked at the link you provided, but no one there mentioned trying flatMap instead of explode.
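To sketch what replacing explode with flatMap could look like: the per-line window extraction is plain Scala (so it can be checked without a cluster), and the Spark part is shown as a comment. The column names and the ngram/wsize values mirror the question; everything else is an assumption, not the answerer's tested code.

```scala
// Plain-Scala core of the per-line work from the question: character
// n-grams, then sliding windows of those n-grams, each paired with the
// window's center n-gram (the future $"word").
def windows(line: String, ngram: Int, wsize: Int): List[(List[String], String)] =
  line.sliding(ngram).toList      // character n-grams of the line
    .sliding(wsize).toList        // windows of wsize consecutive n-grams
    .map(w => (w, w(w.size / 2))) // keep each window and its center

// In Spark, one flatMap would replace windUdf + explode + get_mid
// (sketch, assuming the same SparkSession and in_file as the question):
//   import spark.implicits._
//   val exploded = spark.read.textFile(in_file)
//     .flatMap(line => windows(line, 3, 7))
//     .toDF("word_winds", "word")
```

For example, windows("abcdefghij", 3, 7) yields two windows of seven 3-grams each, centered on "def" and "efg" respectively.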