tensorflow - `get_variable()` 中的 `tf.zeros_initializer` 有什么问题?
问题描述
我想使用 CW 算法训练一些对抗性示例,我使用了这里的示例和这里的CW 实现。但是我遇到了一个错误:tf.zeros_initializer
ValueError: The initializer passed is not valid. It should be a callable with no arguments and the shape should not be provided or an instance of
'tf.keras.initializers.*' and `shape` should be fully defined.
编辑:似乎未完全定义的形状与使用初始化程序冲突。我该如何解决?
这是一段代码:
# ... omitted
with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
# CW
_, env.adv_cw, _ = cw.cw(model, env.x)
这是env.x
:
env.x = tf.placeholder(tf.float32, (None, width, height, channels), name='x')
当我运行代码时,我收到错误消息:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-39-712c8b007d37> in <module>()
8 with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
9 # CW
---> 10 _, env.adv_cw, _ = cw.cw(model, env.x)
5 frames
/content/cw.py in cw(model, x, y, eps, ord_, T, optimizer, alpha, min_prob, clip)
50 """
51 xshape = x.get_shape().as_list()
---> 52 noise = tf.get_variable('noise', shape=xshape, dtype=tf.float32,
53 initializer=tf.zeros_initializer)
54
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter, constraint, synchronization, aggregation)
1494 constraint=constraint,
1495 synchronization=synchronization,
-> 1496 aggregation=aggregation)
1497
1498
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in get_variable(self, var_store, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter, constraint, synchronization, aggregation)
1237 constraint=constraint,
1238 synchronization=synchronization,
-> 1239 aggregation=aggregation)
1240
1241 def _get_partitioned_variable(self,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in get_variable(self, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter, constraint, synchronization, aggregation)
560 constraint=constraint,
561 synchronization=synchronization,
--> 562 aggregation=aggregation)
563
564 def _get_partitioned_variable(self,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in _true_getter(name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, constraint, synchronization, aggregation)
512 constraint=constraint,
513 synchronization=synchronization,
--> 514 aggregation=aggregation)
515
516 synchronization, aggregation, trainable = (
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py in _get_single_variable(self, name, shape, dtype, initializer, regularizer, partition_info, reuse, trainable, collections, caching_device, validate_shape, use_resource, constraint, synchronization, aggregation)
906 variable_dtype = None
907 else:
--> 908 raise ValueError("The initializer passed is not valid. It should "
909 "be a callable with no arguments and the "
910 "shape should not be provided or an instance of "
ValueError: The initializer passed is not valid. It should be a callable with no arguments and the shape should not be provided or an instance of `tf.keras.initializers.*' and `shape` should be fully defined.
但是谷歌的 TensorFlow Guide给出了一个使用的例子get_variable
:
my_int_variable = tf.get_variable("my_int_variable", [1, 2, 3], dtype=tf.int32,
initializer=tf.zeros_initializer)
环境: Google Colab、TensorFlow 1.14.0-rc1、Python 3.6
解决方案
只需根据您的占位符尺寸进行更改,让我以您的占位符变量为例。
** x = placeholder(t f. float 32, (None, width, height, channels), name='x')**.
它有 4 个维度:[无,宽度,高度,通道],但是没有定义宽度,高度,通道,意味着对于图像宽度 = 6,高度 = 6,通道 = 3 已定义,因此张量维度为 [6 x 6 x 3]。
您可以做的是,您读取的图像,获取不同变量中的所有三个维度值并将其传递给您的占位符变量。前任。
Image A = 32 x 32 x 3
width = A.shape[0]
height = A.shape[1]
channels =A.shape[2]
或者您可以通过这种方式定义占位符,直接将值提供给宽度、高度、通道(如果您知道输入数据的形状)。
推荐阅读
- python - Istio 策略未对 JWT 进行身份验证
- react-native - 带有浮动内联后缀的 React-native TexField
- r - 完善每日销售预测
- sql-server - 将 sql server AlphaToDate4 转换为具有 5 个数字的 oracle 日期
- python - 在熊猫数据框中使用多列作为索引
- vba - 如何使用资本和利率计算贷款成本?Vba
- android - 将动态标头添加到带有正文的 POST 请求(Retrofit v2.4.0)
- elixir - 埃克托 | 如何获取应用程序中的所有模式模块
- android - 适用于某些智能手机的 Google Play 隐形应用
- android - 如何修复getResourceID在android中找不到问题?