python - 使用 iris 示例加载 csv 时 tensorflow 的值错误
问题描述
当试图在 Jupyter 上运行一个简单的获取数据序列时,为了让系统识别鸢尾花 tyoes teough Fisher's table,错误:
ValueError Traceback (most recent call last)
<ipython-input-12-269564554b65> in <module>
10 training_set = base.load_csv_with_header(filename=IRIS_TRAINING,
11 features_dtype=np.float32,
---> 12 target_dtype=np.float32)
13 test_set = base.load_csv_with_header(filename=IRIS_TEST,
14 features_dtype=np.float32,
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py in load_csv_with_header(filename, target_dtype, features_dtype, target_column)
46 data_file = csv.reader(csv_file)
47 header = next(data_file)
---> 48 n_samples = int(header[0])
49 n_features = int(header[1])
50 data = np.zeros((n_samples, n_features), dtype=features_dtype)
ValueError: invalid literal for int() with base 10: '5.1'
正在显示。该错误表明它无法使用 int() 函数,尽管代码中根本没有 int 。这是代码:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
# Data files
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load datasets.
training_set = base.load_csv_with_header(filename=IRIS_TRAINING,
features_dtype=np.float32,
target_dtype=np.float32)
test_set = base.load_csv_with_header(filename=IRIS_TEST,
features_dtype=np.float32,
target_dtype=np.float32)
print(training_set.data)
print(training_set.target)
为什么target_dytype=np.int
不工作,如错误所示?提前致谢。
解决方案
答案就在那里,您无法将小数转换为整数。尝试改用 numpy 数据类型float32
。
target_dtype=np.float32
更新
tensorflow.base
有几个load_csv..
,你可以试试base.load_csv_without_header
或者base.load_csv
。
具体的iris
tensorflow例子是应用到自己的数据集上的,这里第一列的header存放的是例子的个数,导致下面的错误:
46 data_file = csv.reader(csv_file)
47 header = next(data_file)
---> 48 n_samples = int(header[0])
49 n_features = int(header[1])
header
返回 csv 文件的第一行,n_samples
用于存储第一列的样本数。
推荐阅读
- javascript - 如何根据您所在的当前时间使水平线向上或向下移动?
- powerbi - 如何在 Power BI 图表中控制轴中的年龄组?
- node.js - 如何将 Google 用户(oauth2)集成到 Nodejs express API
- multidimensional-array - 如何为xarray中的特定变量点查找时间、纬度、经度的索引
- c++ - 为什么不可能为 QPropertyAnimation 制作 qobject_cast
- django - Django Postgres 应用程序在 localhost 上运行良好,但在从 GitHub 部署后无法在 Heroku 上加载(错误代码 H10,状态 503)
- vba - 来自外部模块的事件回调(VB6 示例不适用于 VBA)
- oracle - 在 oracle 中打印进度
- python - 如何在 HTML 中访问查询集的所有 value_list 值的值?
- kubernetes - 基于 JVM 堆内存的 Kubernetes HPA