ValueError in TensorFlow when loading a CSV with the iris example

Problem description

When I try to run a simple data-loading sequence in Jupyter, so that the system can recognize iris types through Fisher's table, the following error is displayed:

  ValueError                                Traceback (most recent call last)
 <ipython-input-12-269564554b65> in <module>
 10 training_set = base.load_csv_with_header(filename=IRIS_TRAINING,
 11                                      features_dtype=np.float32,
 ---> 12                                      target_dtype=np.float32)
 13 test_set = base.load_csv_with_header(filename=IRIS_TEST,
 14                                  features_dtype=np.float32,

 /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py in load_csv_with_header(filename, target_dtype, features_dtype, target_column)
 46     data_file = csv.reader(csv_file)
 47     header = next(data_file)
 ---> 48     n_samples = int(header[0])
 49     n_features = int(header[1])
 50     data = np.zeros((n_samples, n_features), dtype=features_dtype)

ValueError: invalid literal for int() with base 10: '5.1'

The error says that int() could not be applied, even though my code does not call int anywhere. Here is the code:

import tensorflow as tf
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base

# Data files
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets.
training_set = base.load_csv_with_header(filename=IRIS_TRAINING,
                                     features_dtype=np.float32,
                                     target_dtype=np.float32)
test_set = base.load_csv_with_header(filename=IRIS_TEST,
                                 features_dtype=np.float32,
                                 target_dtype=np.float32)

print(training_set.data)

print(training_set.target)

Why does target_dtype=np.int not work, as the error suggests? Thanks in advance.

Tags: python, numpy, tensorflow, jupyter-notebook, artificial-intelligence

Solution


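The answer is right there in the error message: a string containing a decimal number cannot be converted directly to an integer. Try using the numpy data type float32 instead: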

target_dtype=np.float32
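
The distinction can be checked in plain Python, independent of TensorFlow. This is only a minimal sketch; the string '5.1' is the value quoted in the error message, and the variable names are made up for illustration.

```python
# The string quoted in the traceback's error message.
raw = "5.1"

try:
    int(raw)  # int() only parses integer literals such as "5"
    int_error = None
except ValueError as exc:
    int_error = str(exc)

as_float = float(raw)      # parsing the same string as a float works
as_int = int(float(raw))   # converting via float truncates toward zero

print(int_error)
print(as_float, as_int)
```

The first print shows the same message as the traceback, which suggests the failure comes from parsing the text "5.1" as an integer rather than from the dtype arguments themselves.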

Update

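The base module has several load_csv... functions. You could try base.load_csv_without_header or base.load_csv (which of these are available depends on your TensorFlow version).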
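
If none of those loaders is available, a header-less file can also be read with the standard library alone. The sketch below is not TensorFlow's implementation: read_csv_without_header is a hypothetical helper, the two sample rows are illustrative, and it assumes every value is numeric with the target in the last column.

```python
import csv
import io


def read_csv_without_header(csv_file, target_column=-1):
    """Hypothetical helper: split every row into features and a target."""
    features, targets = [], []
    for row in csv.reader(csv_file):
        if not row:  # skip blank lines
            continue
        targets.append(float(row.pop(target_column)))
        features.append([float(value) for value in row])
    return features, targets


# Illustrative data with no header row: four measurements, then the class.
sample = io.StringIO("5.1,3.5,1.4,0.2,0\n6.4,3.2,4.5,1.5,1\n")
features, targets = read_csv_without_header(sample)
print(features)
print(targets)
```

With a real file, open("iris_training.csv", newline="") could be passed in place of the StringIO object.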

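The loader from the TensorFlow iris example is being applied to your own dataset here. It expects the first column of the header row to hold the number of samples, which leads to the error below: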

 46     data_file = csv.reader(csv_file)
 47     header = next(data_file)
 ---> 48     n_samples = int(header[0])
 49     n_features = int(header[1])

header holds the first row of the csv file, and n_samples is read from its first column as the number of samples.
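
The header handling can be reproduced without TensorFlow. The sketch below only mimics the two lines marked in the traceback; read_header_counts and both sample files are made up for illustration, and the header layout (sample count, feature count, then class names) is an assumption based on those lines.

```python
import csv
import io


def read_header_counts(csv_file):
    """Mimic the traceback lines: read the first row as two counts."""
    header = next(csv.reader(csv_file))
    return int(header[0]), int(header[1])


# First row carries the counts, as the loader expects.
with_counts = io.StringIO(
    "2,4,setosa,versicolor\n5.1,3.5,1.4,0.2,0\n6.4,3.2,4.5,1.5,1\n"
)
# First row is already data, so "5.1" is parsed as the sample count.
without_counts = io.StringIO("5.1,3.5,1.4,0.2,0\n6.4,3.2,4.5,1.5,1\n")

counts = read_header_counts(with_counts)
try:
    read_header_counts(without_counts)
    message = None
except ValueError as exc:
    message = str(exc)

print(counts)
print(message)
```

The second file triggers the same message as in the question, so either adding a counts row to the file or switching to a loader that does not expect one should avoid the error.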

