首页 > 解决方案 > 是否可以通过 keras 图更改 batch_size 中途?

问题描述

是否可以在图表的中途动态更改批量大小?

我的系统使用单词生成句子表示,然后使用句子生成文档表示。

如果一个文档包含 20 个句子,每个句子有 50 个单词(为简单起见,单词向量大小为 1)。我有 10 个文档的批量大小:

我尝试了一个重塑层和 keras.backend 重塑层,但是 keras 似乎坚持我的批量大小在整个图表中保持不变(200),即使操作本身感觉它们应该是合法的。

实际错误是: ValueError:无法将输入数组从形状(10、20、100)广播到形状(200、20、100)。即在让我重塑我的张量之后,它现在正试图将它改组为批量大小为 200 的张量

标签: pythontensorflowkerastensorflow2.0

解决方案


你可以tf.reshpae在你的模型中使用它,它可以让你改变张量的形状,甚至是批量维度,但你必须使一切保持一致,以便在训练期间数据流是正确的。

这是一个虚拟网络:

from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow as tf

init_batch_sz = 10 # let's assume initial batch size is 10
ip1 = layers.Input((20,10)) 
dense = layers.Dense(10)(ip1)
res = tf.reshape(dense, (init_batch_sz//2, -1, -1))

model = models.Model(ip1, res)
model.summary()
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_8 (InputLayer)         [(None, 20, 10)]          0         
_________________________________________________________________
dense_1 (Dense)              (None, 20, 10)            110       
_________________________________________________________________
tf_op_layer_Reshape_1 (Tenso [(5, None, None)]         0         
=================================================================
Total params: 110
Trainable params: 110
Non-trainable params: 0

但是你不应该fit用来训练这样的网络,因为你会得到一个错误。

其他一些选项是:

  1. 使用虚拟 1 批次维度。
from tensorflow.keras import layers
from tensorflow.keras import models
import tensorflow as tf
import numpy as np


init_batch_sz = 10 # let's assume initial batch size is 10
ip1 = layers.Input((10, 20,10)) 
dense = layers.Dense(10)(ip1)
res = tf.reshape(dense, (-1, init_batch_sz//2, 40, 10)) # you need to make some calculations here to get the correct output_shape

model = models.Model(ip1, res)
model.summary()

x = np.random.randn(1, 10, 20, 10) # dummy 1 batch
y = np.random.randn(1, 5, 40, 10) # dummy 1 batch

model.compile('adam', 'mse')
model.fit(x, y)
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 10, 20, 10)]      0         
_________________________________________________________________
dense_1 (Dense)              (None, 10, 20, 10)        110       
_________________________________________________________________
tf_op_layer_Reshape_1 (Tenso [(None, 5, 40, 10)]       0         
=================================================================
Total params: 110
Trainable params: 110
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 1ms/step - loss: 1.9959

<tensorflow.python.keras.callbacks.History at 0x7f600d0eb630>

推荐阅读