python - 填充批次不处理最后一批小于批次大小的数据
问题描述
我有一个数据集,其中有图像,每张图像有 10 个问题,因此有 10 个答案。我已经成功地训练并检查了模型。该模型由两个输入组成,一个输入作为 CNN 的图像,另一个输入作为 LSTM 的问题。因此,对于每张图片,我都会提出 10 个问题。然后将两者的结果连接起来并输入到 FC 层。
考虑到我的批量大小为 64,我将输入 64 张图像和 640 个问题。在连接时,我需要使维度axis:0
相等,以避免由于不同维度导致的连接错误。因此,我将 CNN 网络输出展平并重复 10 次,然后将其连接到 LSTM 输出。
在model_rn.py中,我执行以下操作:
class Model:
def __init__(self):
self.img = tf.placeholder(
name='img',
dtype=tf.float32,
shape=[self.batch_size, self.img_size, self.img_size, 3]
)
self.q = tf.placeholder(
name='ques',
dtype=tf.float32,
shape=[self.batch_size * 10, self.ques_dim]
)
self.ans = tf.placeholder(
name='ans',
dtype=tf.float32,
shape=[self.batch_size * 10, self.ans_dim]
)
# and some more class variables
self.build()
def build(self):
def cnn(img, q, scope):
# some Conv2D and BatchNormalization
flat = Flatten(name='flatten')(bn_4) # layer where data is flattened before concatenate
flat = tf.keras.backend.repeat_elements(flat, 10, axis=0) # repeat 10 times
# some statements to feed data into LSTM and CNN
然后我加载我的模型并尝试在包含 20 个图像、200 个问题和 200 个答案的测试数据集上运行它。但后来我得到了错误:
ValueError:无法为具有形状“(640、128、128、3)”的张量“img_1:0”提供形状(20、128、128、3)的值
从我使用的测试数据集中提供批次padded_batch
。
dataset_img = Dataset.from_tensor_slices((images)).padded_batch(
64, padded_shapes=(128, 128, 3)
)
dataset_ques = Dataset.from_tensor_slices((questions)).padded_batch(
64 * 10, padded_shapes=(14)
)
dataset_ans = Dataset.from_tensor_slices((answers)).padded_batch(
64 * 10, padded_shapes=(22)
)
有人可以帮我弄这个吗?
谢谢!
解决方案
您已将占位符限制为恰好采用“batch_size”行数。要获取任何行,您可以像这样创建占位符
self.img = tf.placeholder(
name='img',
dtype=tf.float32,
shape=[None, self.img_size, self.img_size, 3]
)
self.q 和 self.ans 也是如此
推荐阅读
- android - 如何在 Kotlin 中获取 Retrofit 的原始 json 响应?
- docusignapi - 如何使用信封-java从docusign下载签名文档
- pyspark - 用点“。”计算一列数据框的 approxQuantile
- java-native-interface - Android JNI 开发时我调用 (*env)->CallVoidMethod 函数结果崩溃
- laravel - 如何设置 nginx laravel + vue cli 项目
- javascript - 如何在three.js中保存加载的模型?
- javascript - 如何使用 javascript 重置 sap ui5 输入字段?
- php - Codeception 未定义索引:ELEMENT 错误
- vba - 如何使用 VBA 在 Sharepoint Online 中删除 Excel 文件
- python - 我们如何从pickle文件中获取注释,这将告诉我们存储的对象数量和pickle文件中的详细信息?