python - Running TensorFlow random forest but getting a ValueError?
Question
I am trying out TensorFlow's random forest, but I get the following error:
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(num_trees=100, max_nodes=1000, num_classes=len(le.classes_), num_features=119)
classifier = tf.contrib.tensor_forest.client.random_forest.TensorForestEstimator(params)
classifier.fit(x=X_train, y=y_train)
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype float64: 'Tensor("concat:0", shape=(?, 119), dtype=float64)'
But it works when I run the same data through scikit-learn:
clf = RandomForestClassifier(n_estimators=n_estms, n_jobs=n_jobs)
clf = clf.fit(X_train, y_train)
Update: this is how the data is split:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)
Update: I tried this style, but it still does not work:
input_fn = numpy_io.numpy_input_fn(
    x=X_train.astype(np.float32),
    y=y_train.astype(np.float32),
    num_epochs=None,
    shuffle=True)
classifier.fit(input_fn=input_fn, steps=None)
ValueError: Features are incompatible with given information. Given features: Tensor("fifo_queue_DequeueMany:1",
shape=(128, 119), dtype=float32), required signatures: TensorSignature(dtype=tf.float64, shape=TensorShape([Dimension(128), Dimension(119)]), is_sparse=False).
Dataset:
X_train,y_train,len(X_train),len(y_train)
(array([[ 3.3042e-01, 2.4995e-01, -6.0874e-01, ..., 3.0400e+02,
5.0000e+00, 1.0000e+00],
[ 4.2466e-01, 8.5174e-01, 8.6044e-01, ..., 1.0000e+00,
7.8000e+01, 1.0000e+00],
[ 6.1890e-01, -1.1185e+00, 5.8483e-01, ..., 1.4000e+01,
7.0000e+00, 1.0000e+00],
...,
[ 9.0512e-01, 1.3008e-01, 1.0917e+00, ..., 1.7000e+01,
2.0000e+00, 1.0000e+00],
[-1.4751e-01, 5.5556e-01, 1.0764e+00, ..., 1.8000e+01,
1.3000e+01, 1.0000e+00],
[-5.0246e-01, 1.2178e+00, -8.0065e-01, ..., 1.0000e+00,
3.0000e+00, 0.0000e+00]]),
array([1, 0, 5, ..., 8, 5, 9]),
510281,
510281)
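Answer
The error message describes the problem quite clearly: the estimator expects float32 features, but X_train is float64. Try: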
classifier.fit(x=tf.cast(X_train, tf.float32), y=y_train)
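Also, it looks like you are using the deprecated input format that passes x and y directly. For an example of how to convert your input to the input_fn format, see here.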
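As a rough illustration of the dtype fix on the NumPy side, here is a minimal sketch that casts the features to float32 before they reach the estimator. The helper name to_float32_features and the random stand-in data are hypothetical and not part of the original post; the labels are left unchanged. The second error in the question suggests the estimator had already recorded float64 features from an earlier fit call, so this assumes the cast array is passed to a newly constructed classifier.

```python
import numpy as np


def to_float32_features(X):
    """Return a float32 copy of the feature matrix (hypothetical helper)."""
    return np.asarray(X, dtype=np.float32)


# Stand-in for X_train: NumPy creates float64 arrays by default,
# which matches the dtype reported in the first error message.
X_demo = np.random.RandomState(0).rand(8, 119)
X32 = to_float32_features(X_demo)

print(X_demo.dtype, "->", X32.dtype)
```

With the question's objects, the cast array would then be passed in place of X_train, for example classifier.fit(x=X32, y=y_train), assuming classifier was created after the cast rather than reused from the failed run.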