python - Tensorflow 形状推断静态 RNN 编译器错误
问题描述
我正在开发针对手机摄像头图像优化的 OCR 软件。
目前,每个 300 x 1000 x 3 (RGB) 图像被重新格式化为 900 x 1000 numpy 数组。我有一个更复杂的模型架构的计划,但现在我只想让基线工作。我想从我生成的数据上训练一个静态 RNN 开始。
形式上,我在每个时间步 t 为 T 个时间步输入 n_t,其中 n_t 是一个 900 向量且 T = 1000(类似于从左到右读取整个图像)。这是我创建训练批次的 Tensorflow 代码:
sequence_dataset = tf.data.Dataset.from_generator(example_generator, (tf.int32,
tf.int32))
sequence_dataset = sequence_dataset.batch(experiment_params['batch_size'])
iterator = sequence_dataset.make_initializable_iterator()
x_batch, y_batch = iterator.get_next()
tf.nn.static_bidirectional_rnn 文档声称输入必须是“长度为 T 的输入列表,每个输入的形状为 [batch_size, input_size] 的张量,或此类元素的嵌套元组”。因此,我通过以下步骤将数据转换为正确的格式。
# Dimensions go from [batch, n , t] -> [t, batch, n]
x_batch = tf.transpose(x_batch, [2, 0, 1])
# Unpack such that x_batch is a length T list with element dims [batch_size, n]
x_batch = tf.unstack(x_batch, experiment_params['example_t'], 0)
在不进一步更改批次的情况下,我进行以下调用:
output, _, _ = tf.nn.static_rnn(lstm_fw_cell, x_batch, dtype=tf.int32)
请注意,我没有明确告诉 Tensorflow 矩阵的维度(这可能是问题所在)。它们都具有相同的维度,但我收到以下错误:
ValueError: Input size (dimension 0 of inputs) must be accessible via shape
inference, but saw value None.
我应该在堆栈中的哪个位置声明输入的维度?因为我正在使用数据集并希望将其批次直接发送到 RNN,所以我不确定“占位符 - > feed_dict”路线是否有意义。如果这实际上是最有意义的方法,请告诉我它是什么样的(我绝对不知道)。否则,如果您对该问题有任何其他见解,请告诉我。谢谢!
解决方案
缺少静态形状信息的原因是 TensorFlow 对函数的了解不够,example_generator
无法确定它产生的数组的形状,因此它假设形状可以从一个元素到下一个元素完全不同。限制这一点的最佳方法是指定可选output_shapes
参数,它接受与生成的元素(和参数)tf.data.Dataset.from_generator()
的结构匹配的形状的嵌套结构。output_types
在这种情况下,您将传递两个形状的元组,可以部分指定。例如,如果x
元素是900 x 1000
数组并且y
元素是标量:
sequence_dataset = tf.data.Dataset.from_generator(
example_generator, (tf.int32, tf.int32),
output_shapes=([900, 1000], []))
推荐阅读
- android - Android 应用程序未成功构建,出现“同步失败”
- html - Outlook 2016 CSS 问题
- odoo - 如何重新激活 odoo 11 或 12 中的会计引擎?
- c - 从语言 L 中的给定正则表达式创建字符串集
- java - 如何使用 Jackson 将蛇案例 yaml 映射到驼峰 Java 字段
- ruby-on-rails - 设计 Saml Authenticatable Gem 错误 - 不支持带有 AuthnContexts 的 AuthnRequest
- time-series - 我使用专家建模器在 SPSS 中创建了一个时间序列预测模型。我有 300 种产品,并希望将模型应用于所有这些产品
- android - 来自 main.xml 的嵌套布局的数据绑定
- shopify - 如何使用 Shopify 中的 API 在订单页面上创建批量操作链接?
- c# - 如何生成 IDbSet 而不是 DbSet