python - 具有未知 batch_size 的 Keras repeat_elements
问题描述
(?,61,80)
我有一个函数,我需要用 2D 张量的大小来做 Keras batch_dot 的大小(40,61)
。维度?
用于自定义层中的批量大小。在使用 Kerasrepeat_elements
时,我们需要指定批量大小以使其成为(batch_size, 40,61)
. 但是,repeat_elements
不适用于?
批量大小。
代码是
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=batch_size,axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
这M
是 size 的二维张量(40,61)
。BatchM
应该给出(batch_size,40,61)
并且Ash1
是大小(?,61,80)
。
编辑1:
A= Input(shape=(61,80))
M= K.variable(np.random.rand(40,61))
n=1
import tensorflow as tf
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
此返回错误显示:
Traceback (most recent call last)
File "<ipython-input-7-edc5ef31181b>", line 3, in <module>
BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)
File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in repeat_elements
x_rep = [s for s in splits for _ in range(rep)]
File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in <listcomp>
x_rep = [s for s in splits for _ in range(rep)]
TypeError: 'Tensor' object cannot be interpreted as an integer
解决方案
事实上,你不需要repeat_elements
用未知的batch_size。您可以将K.dot()
andK.permute_dimensions
直接用于相同的目的。
def customer_dot(a,b):
a = K.permute_dimensions(a, (0, 2, 1)) # x = (?,80,61)
b = K.permute_dimensions(b, (1, 0)) # kernel = (61,40)
ab_dot = K.permute_dimensions(K.dot(a, b), (0, 2, 1)) # ab_dot = (?,40,80)
return ab_dot
A = Input(shape=(61,80))
M = K.variable(np.random.rand(40,61))
result = customer_dot(A,M)
print(result.shape)
# print
(?, 40, 80)
并且你可以通过下面的例子看到结果和你的代码操作的结果是一样的。
# print
A = K.constant(np.random.rand(3,2,4))
M = K.constant(np.random.rand(5,2))
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=K.int_shape(A)[0],axis=0)
out1 = K.batch_dot(BatchM,A,axes=[2,1])
print(K.eval(out1))
result = customer_dot(A,M)
print(K.eval(result))
[[[0.07588554 0.19896106 0.4122516 0.16694324]
[0.02837059 0.07994501 0.15250334 0.05631477]
[0.02922964 0.03180532 0.17185953 0.11346529]
[0.24399586 0.64474815 1.3240533 0.53126353]
[0.06582426 0.0952256 0.38014278 0.22963922]]
[[0.05856805 0.31629622 0.37190455 0.15167782]
[0.02006819 0.12145159 0.1384899 0.0497717 ]
[0.03729554 0.09602766 0.14768752 0.11432388]
[0.18666261 1.0198846 1.1952925 0.481425 ]
[0.07623056 0.2298356 0.33025196 0.22802524]]
[[0.29545793 0.27023914 0.14775626 0.22487558]
[0.10839225 0.10083499 0.05140937 0.07595014]
[0.13047284 0.10567644 0.08779343 0.15208915]
[0.9481214 0.868726 0.47162086 0.7157058 ]
[0.28504598 0.23714545 0.18145116 0.30803293]]]
[[[0.07588554 0.19896106 0.4122516 0.16694324]
[0.02837059 0.07994501 0.15250334 0.05631477]
[0.02922964 0.03180532 0.17185953 0.11346529]
[0.24399586 0.64474815 1.3240533 0.53126353]
[0.06582426 0.0952256 0.38014278 0.22963922]]
[[0.05856805 0.31629622 0.37190455 0.15167782]
[0.02006819 0.12145159 0.1384899 0.0497717 ]
[0.03729554 0.09602766 0.14768752 0.11432388]
[0.18666261 1.0198846 1.1952925 0.481425 ]
[0.07623056 0.2298356 0.33025196 0.22802524]]
[[0.29545793 0.27023914 0.14775626 0.22487558]
[0.10839225 0.10083499 0.05140937 0.07595014]
[0.13047284 0.10567644 0.08779343 0.15208915]
[0.9481214 0.868726 0.47162086 0.7157058 ]
[0.28504598 0.23714545 0.18145116 0.30803293]]]
推荐阅读
- r - 在 R 中处理 XML 数据库 - 如何处理丢失的节点
- excel - 闪烁屏幕 VBA 宏
- discord.py - 如何让我的不和谐机器人说出我说的话然后删除我的消息
- javascript - babel编译后“未捕获的语法错误:无法在模块外使用导入语句”
- javascript - 访问和打印嵌套的 JSON 对象
- c# - 一个事件多个消费者(只有一个消费者工作) Rabbitmq
- python - FileNotFoundError: [WinError 2] 系统找不到指定的文件:'demo.html'
- visual-studio-code - Visual Studio 代码 - Elixir 格式化程序不工作,尝试使用更漂亮的代替
- python - 从 Excel 单元格中删除特定数字字符向后计数
- lua - 如何调试scrapy-splash冻结的位置?