python - InvalidArgumentError:预期 'tf.Tensor(False, shape=(), dtype=bool)' 为真
问题描述
在使用结构相似性指数进行比较之前,我正在使用 PCA 来减小图像的尺寸。使用 PCA 后,tf.image.ssim 会抛出错误。
我在这里比较图像而不使用 PCA。这完美地工作 -
import numpy as np
import tensorflow as tf
import time
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
start = time.time()
for i in range(1,6000):
x_train_zero = np.expand_dims(x_train[0], axis=2)
x_train_expanded = np.expand_dims(x_train[i], axis=2)
print(tf.image.ssim(x_train_zero, x_train_expanded, 255))
print(time.time()-start)
我在这里应用了 PCA 来减小图像的尺寸,这样 SSIM 比较图像所需的时间更少——
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
x_train = x_train.reshape(60000,-1)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(x_train)
pca = PCA()
pca = PCA(n_components = 11)
X_pca = pca.fit_transform(X_scaled).reshape(60000,11,1)
start = time.time()
for i in range(1,6000):
X_pca_zero = np.expand_dims(X_pca[0], axis=2)
X_pca_expanded = np.expand_dims(X_pca[i], axis=2)
print(tf.image.ssim(X_pca_zero, X_pca_expanded, 255))
print(time.time()-start)
这段代码会引发错误 - InvalidArgumentError: Expected 'tf.Tensor(False, shape=(), dtype=bool)' to be true。汇总数据:11、1、1 11
解决方案
因此,简而言之,发生该错误是因为 in tf.image.ssim
,输入X_pca_zero
和X_pca_expanded
大小不匹配filter_size
,如果您有,filter_size=11
则X_pca_zero
andX_pca_expanded
必须至少为11x11,如何更改代码的示例:
import tensorflow as tf
import time
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
x_train = x_train.reshape(60000,-1)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(x_train)
pca = PCA()
pca = PCA(n_components = 16) # or 12 -> 3, 4 filter_size=3
X_pca = pca.fit_transform(X_scaled).reshape(60000, 4, 4, 1)
start = time.time()
X_pca_zero = X_pca[0]
for i in range(1,6000):
X_pca_expanded = X_pca[i]
print(tf.image.ssim(X_pca_zero, X_pca_expanded, 255, filter_size=4))
print(time.time()-start)
推荐阅读
- excel - vba 数据透视表 - 应用程序定义或对象定义错误
- ios - [__NSCFString objectForKey:]:文本字段搜索崩溃
- reactive-programming - Spring WebFlux/ 反应堆核心
- jquery - Angular 7上的引导程序出现错误,因为'JQuery类型上不存在属性折叠
` - cluster-analysis - 单个“多输入 Tx”与多个“单输入 Tx”
- git - 在本地撤消 git rebase 跳过
- dart - 如何处理可扩展表单上的控制器 TextField
- c# - 试图从 SQL Server Profiler 获取过程名称,而是显示“sp_reset_connection”
- php - AWS SDK PHP - 如何通过一个链接加载多个文件?如何创建一个 zip 文件?
- java - 如何创建从 Activity 底部绘制的 listView?