python - ValueError: Shape must be rank 1 but is rank 2 when doing tf.einsum('i,j->ij',u ,j)
Problem description
I have coded this model with tf.keras:
import tensorflow as tf
from tensorflow import einsum
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Flatten , Dot
from tensorflow.keras.layers import Embedding, Multiply, Dense, Input
from tensorflow.keras import Model
from tensorflow.keras.layers import concatenate
from tensorflow.keras.models import load_model
num_items = 1250
num_users = 1453
emb_size = 32
input_userID = Input(shape=[1], name='user_ID')
input_itemID = Input(shape=[1], name='item_ID')
user_emb_GMF = Embedding(num_users, emb_size, name='user_emb_GMF')(input_userID)
item_emb_GMF = Embedding(num_items, emb_size, name='item_emb_GMF')(input_itemID)
flat_u_GMF = Flatten()(user_emb_GMF)
flat_i_GMF = Flatten()(item_emb_GMF)
interraction_map = einsum('i,j->ij',flat_u_GMF ,flat_i_GMF) # output[i,j] = u[i]*v[j]
layer = Dense(16, activation='relu', name='hidden_layer' )(interraction_map)
out = Dense(1,activation='sigmoid',name='output')(layer)
oncf_model = Model([input_userID, input_itemID], out)
tf.keras.utils.plot_model(oncf_model, show_shapes=True)
Basically, I want to get the outer product of user_emb_GMF and item_emb_GMF (which is a matrix), but I get this error:
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs, op_def)
1811 try:
-> 1812 c_op = pywrap_tf_session.TF_FinishOperation(op_desc)
1813 except errors.InvalidArgumentError as e:
InvalidArgumentError: Shape must be rank 1 but is rank 2
for 0th input and equation: i,j->ij for '{{node Einsum_2}} = Einsum[N=2, T=DT_FLOAT, equation="i,j->ij"](flatten_10/Reshape, flatten_11/Reshape)' with input shapes: [?,32], [?,32].
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs, op_def)
1813 except errors.InvalidArgumentError as e:
1814 # Convert to ValueError for backwards compatibility.
-> 1815 raise ValueError(str(e))
1816
1817 return c_op
ValueError: Shape must be rank 1 but is rank 2
for 0th input and equation: i,j->ij for '{{node Einsum_2}} = Einsum[N=2, T=DT_FLOAT, equation="i,j->ij"](flatten_10/Reshape, flatten_11/Reshape)' with input shapes: [?,32], [?,32].
I would like to know how to fix this.
Solution
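The error occurs because the equation 'i,j->ij' expects two rank-1 tensors, but inside a Keras model every tensor carries a leading batch dimension, so flat_u_GMF and flat_i_GMF have shape (None, 32), i.e. rank 2.
If the desired output of interraction_map has shape (num_batch, emb_size, emb_size, 1),
you can simply use the Keras Dot layer and then add a dimension.
This way the embeddings don't need to be flattened: each Embedding output has shape (num_batch, 1, emb_size), and Dot(axes=1) contracts that length-1 axis, which leaves the per-example outer product.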
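NumPy's np.einsum uses the same subscript notation as tf.einsum, so the shape problem can be reproduced outside a model. This is a minimal sketch, assuming a batch of 4 random vectors in place of the real embeddings; the batched equation it uses, 'bi,bj->bij', is what tf.einsum would need for inputs of shape (None, 32).

```python
import numpy as np

batch, emb_size = 4, 32
rng = np.random.default_rng(0)
# Stand-ins (assumption: random data) for the flattened embeddings,
# which inside Keras have shape (batch, emb_size), i.e. rank 2.
u = rng.normal(size=(batch, emb_size))
v = rng.normal(size=(batch, emb_size))

# 'i,j->ij' declares rank-1 operands, so rank-2 inputs are rejected.
try:
    np.einsum('i,j->ij', u, v)
    rank_error = False
except ValueError:
    rank_error = True

# Adding a batch subscript 'b' gives one outer product per example:
# outer[b, i, j] = u[b, i] * v[b, j]
outer = np.einsum('bi,bj->bij', u, v)

# Same result via broadcasting, as a cross-check.
expected = u[:, :, None] * v[:, None, :]
print(rank_error, outer.shape, np.allclose(outer, expected))
```

The subscript letters are arbitrary; what matters is that every operand gets exactly as many subscripts as it has dimensions, including the batch dimension.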
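import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Embedding, Dot, Conv2D, GlobalMaxPool2D, Dense

num_items = 1250
num_users = 1453
emb_size = 32
input_userID = Input(shape=[1], name='user_ID')
input_itemID = Input(shape=[1], name='item_ID')
# Embedding outputs keep shape (batch, 1, emb_size); no Flatten needed
user_emb_GMF = Embedding(num_users, emb_size, name='user_emb_GMF')(input_userID)
item_emb_GMF = Embedding(num_items, emb_size, name='item_emb_GMF')(input_itemID)
# Dot(axes=1) contracts the length-1 axis -> (batch, emb_size, emb_size);
# expand_dims adds a channel axis -> (batch, emb_size, emb_size, 1)
interraction_map = tf.expand_dims(Dot(axes=1)([user_emb_GMF, item_emb_GMF]), -1)
conv = Conv2D(32, 2, activation='relu', padding="SAME")(interraction_map)
pool = GlobalMaxPool2D()(conv)
layer = Dense(16, activation='relu', name='hidden_layer')(pool)
out = Dense(1, activation='sigmoid', name='output')(layer)
oncf_model = Model([input_userID, input_itemID], out)
oncf_model.summary()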
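Why Dot(axes=1) yields an outer product here can be checked in NumPy. This sketch assumes random arrays in place of the trained embeddings and uses np.einsum to spell out the batch-wise contraction that the Keras Dot layer performs along axis 1 of both inputs; np.expand_dims plays the role of tf.expand_dims.

```python
import numpy as np

batch, emb_size = 4, 32
rng = np.random.default_rng(1)
# Stand-ins (assumption: random data) for the unflattened Embedding outputs
# of Input(shape=[1]), which have shape (batch, 1, emb_size).
user_emb = rng.normal(size=(batch, 1, emb_size))
item_emb = rng.normal(size=(batch, 1, emb_size))

# Dot(axes=1) sums over axis 1 of both inputs (the length-1 axis here):
# dot_axes1[b, i, j] = sum_k user_emb[b, k, i] * item_emb[b, k, j]
dot_axes1 = np.einsum('bki,bkj->bij', user_emb, item_emb)

# With k ranging over a single index, that sum is just the outer product.
outer = user_emb[:, 0, :, None] * item_emb[:, 0, None, :]

# Adding a trailing channel axis gives the 4-D input Conv2D expects.
interaction_map = np.expand_dims(dot_axes1, -1)
print(np.allclose(dot_axes1, outer), interaction_map.shape)
```

Because the contracted axis has length 1, no information is summed away; the layer simply rearranges the two vectors into an emb_size x emb_size matrix per example.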