python - ValueError:标签形状不匹配。预期的标签维度 = 1。收到 10
问题描述
不知道为什么会出现这个错误。
from __future__ import absolute_import, division, print_function, unicode_literals
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from six.moves import urllib
from IPython.display import clear_output
import tensorflow.compat.v2.feature_column as fc
import tensorflow as tf
colnames = ['face_shape', 'eye_shape','eye_color','nose_shape','lip_shape','skin_color','hair_type','height','weight','ethnicity']
Ethnicity = ['NativeBangladeshi','BengaliIndianAncestors)','Barua','Chakma','BihariStandedPakistani','TribalGaroMruSantalTripuri']
trainurl = "https://github.com/nurunnabi-cse/neural-network/blob/73e84b92292ed40c260e18479b6b75b9cba74aef/edataset/etrain.csv"
traindl = tf.keras.utils.get_file(fname=os.path.basename(trainurl), origin=trainurl)
print("Local copy of the dataset file: {}".format(traindl))
testurl = "https://github.com/nurunnabi-cse/neural-network/blob/73e84b92292ed40c260e18479b6b75b9cba74aef/edataset/etest.csv"
testdl = tf.keras.utils.get_file(fname=os.path.basename(testurl),origin=testurl)
print("Local copy of the dataset file: {}".format(testdl))
train = pd.read_csv(traindl,names=colnames, na_values=[])
test = pd.read_csv(testdl,names=colnames, na_values=[])
train.dropna(axis = 0, inplace = True)
test.dropna(axis = 0, inplace = True)
train_y = train#.pop('ethnicity')
test_y = test.pop('ethnicity')
def input_func(features, labels, trainning=True,batch_size = 10):
dataset = tf.data.Dataset.from_tensor_slices((dict(features),labels))
if trainning:
dataset = dataset.shuffle(100).repeat()
return dataset.batch(batch_size)
featurecols = []
for key in train.keys():
featurecols.append(tf.feature_column.numeric_column(key=key))
print(featurecols)
分类器弹出错误。它说它想要一维数据,但找到了 10 维数据。
classifier = tf.estimator.DNNClassifier(feature_columns=featurecols,hidden_units=[30,10],n_classes=5)
classifier.train(input_fn=lambda: input_func(train,train_y,trainning=True),steps = 5000)
eval_result = classifier.evaluate(input_fn=lambda: input_func(test,test_y, trainning=False))
print('\nTest Set Acc: {accuracy:0.3f}\n'.format(**eval_result))
我的数据集与谷歌的鸢尾花数据集完全匹配。
解决方案
推荐阅读
- python - 错误:gevent 1.4.0 要求 greenlet>=0.4.14,但您将拥有不兼容的 greenlet 0.4.13
- ssh - 非 22 端口的 VS 代码 ssh 问题
- selenium - 无法使用 selenium java 上传文件
- javascript - 如何在 D3js 中折叠圆圈组?
- android - Kotlin Room 错误:实体和 pojo 应该有一个构造函数
- python - 两列相同时如何合并两个 Pandas DataFrame
- aws-lambda - 对 lambda 函数的调用者进行身份验证
- android - android - 如何从状态栏中删除黑暗
- https - Traefik 反向代理后的 Plone 的混合内容问题
- python - 按出现顺序获取 numpy 数组索引