apache-spark - 无法在 spark 中读取 libsvm 文件
问题描述
我试图使用 Spark 和 pyspark 读取 .txt 文件,但我得到了我无法理解的错误。我已经正确安装了 py4j,而且我可以毫无问题地读取 csv 文件。
这是我的代码:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("test").getOrCreate()
my_data = spark.read.format("libsvm").load("sample_libsvm_data.txt")
我得到的错误是这样的:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-4-3347b4cad068> in <module>
----> 1 my_data = spark.read.format("libsvm").load("sample_libsvm_data.txt")
C:\ProgramData\Anaconda3\lib\site-packages\pyspark\sql\readwriter.py in load(self, path, format, schema, **options)
164 self.options(**options)
165 if isinstance(path, basestring):
--> 166 return self._df(self._jreader.load(path))
167 elif path is not None:
168 if type(path) != list:
C:\ProgramData\Anaconda3\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
1255 answer = self.gateway_client.send_command(command)
1256 return_value = get_return_value(
-> 1257 answer, self.gateway_client, self.target_id, self.name)
1258
1259 for temp_arg in temp_args:
谢谢您的帮助。
解决方案
之前也遇到过同样的问题。通过设置“numFeatures”选项解决了这个问题。
my_data = spark.read.format('libsvm').option("numFeatures", "692").load('sample_libsvm_data.txt')
如果不知道 numFeatures 将很困难。您可以使用此自定义函数来读取 libsvm 文件。
from pyspark.sql import Row
from pyspark.ml.linalg import SparseVector
def read_libsvm(filepath, spark_session):
'''
A utility function that takes in a libsvm file and turn it to a pyspark dataframe.
Args:
filepath (str): The file path to the data file.
spark_session (object): The SparkSession object to create dataframe.
Returns:
A pyspark dataframe that contains the data loaded.
'''
with open(filepath, 'r') as f:
raw_data = [x.split() for x in f.readlines()]
outcome = [int(x[0]) for x in raw_data]
index_value_dict = list()
for row in raw_data:
index_value_dict.append(dict([(int(x.split(':')[0]), float(x.split(':')[1]))
for x in row[1:]]))
max_idx = max([max(x.keys()) for x in index_value_dict])
rows = [
Row(
label=outcome[i],
feat_vector=SparseVector(max_idx + 1, index_value_dict[i])
)
for i in range(len(index_value_dict))
]
df = spark_session.createDataFrame(rows)
return df
用法:
my_data = read_libsvm(filepath="sample_libsvm_data.txt", spark_session=spark)
推荐阅读
- caesar-cipher - 凯撒密码网站中潜在的错误内容
- c++ - 对我的 if 语句给出错误的答案
- c - ***检测到堆栈粉碎***:终止
- javascript - d3 geopath,geojsonlint,CW vs CCW
- c# - .Net Core Razor 页面 - 服务器端包含
- sql - 从 BLOB 中提取多个值作为 XML
- csh - 在 csh 中运行别名命令并将其存储在变量中
- python - 如何通过用户输入检查txt文件中是否存在多个字符串?
- python - BeautifulSoup/Python 解析网站的问题
- linux - 通过 HCI 更改远程设备名称和类别