python - 类型错误:call() 缺少 1 个必需的位置参数:'state_c'
问题描述
我正在尝试注意实现解码器功能,但我收到错误消息:
TypeError: call() missing 1 required positional argument: 'state_c'"
class Decoder(tf.keras.Model):
def __init__(self,out_vocab_size, embedding_dim, output_length, dec_units ,score_fun ,att_units):
#Intialize necessary variables and create an object from the class onestepdecoder
super(Decoder, self).__init__()
self.vocab_size = out_vocab_size
self.embedding_dim = embedding_dim
self.out_length = output_length
self.dec_units = dec_units
self.score_fun = score_fun
self.att_units = att_units
self.stepdec = OneStepDecoder(self.vocab_size, self.embedding_dim,self.out_length, self.dec_units,self.score_fun, self.att_units)
def call(self, input_to_decoder,encoder_output,decoder_hidden_state,state_c ):
all_outputs= tf.TensorArray(tf.float32, size = input_to_decoder.shape[1], name="output_arrays")
for timestep in range(input_to_decoder.shape[1]):
output, input_state = self.stepdec(input_to_decoder[:,timestep:timestep+1],state_c, encoder_output)
all_outputs = all_outputs.write(timestep, output)
all_outputs = tf.transpose(all_outputs.stack(), [1, 0, 2])
return all_outputs
def grader_decoder(score_fun):
out_vocab_size = 13
embedding_dim = 12
input_length = 10
output_length = 11
dec_units = 16
att_units = 16
batch_size = 32
target_sentences = tf.random.uniform(shape=(batch_size,output_length),maxval=10,minval=0,dtype=tf.int32)
encoder_output = tf.random.uniform(shape=[batch_size,input_length,dec_units])
state_h = tf.random.uniform(shape=[batch_size,dec_units])
state_c = tf.random.uniform(shape=[batch_size,dec_units])
decoder = Decoder(out_vocab_size, embedding_dim, output_length, dec_units ,score_fun ,att_units)
output = decoder(target_sentences,encoder_output, state_h, state_c)
assert(output.shape==(batch_size,output_length,out_vocab_size))
return True
print(grader_decoder('dot'))
print(grader_decoder('general'))
print(grader_decoder('concat'))
任何建议或解决方案将不胜感激。
解决方案
推荐阅读
- angular - 如何在 Angular HTTP post 方法标头中插入用户输入?
- android - 按三星bixby按钮发送tcp包
- java - 将字符串值存储到 JSON 类型的字段中
- c++ - 在 DirectX 11 中从 GPU 读回顶点缓冲区(并获取顶点)
- node.js - readXlsxFile在节点js中乱序读取excel表
- python - 如何将 Telegram API 中的 contacts.getLocated() 与 Telethon 一起使用?
- javascript - 获取用户位置并显示数据
- java - 当小部件大小更改时,SWT GridLayout 列不会调整大小
- c# - gRPC 和域驱动设计 - 在哪里放置 proto 文件(域层与其他地方)?
- javascript - 我需要构建一个 onScroll 函数来检测屏幕上最大的 html 元素。基于数组中的 IDS