首页 > 解决方案 > 将整个数据帧(不仅是数据和标签)传递给 Tensorflow 的 model.fit 的工作原理 - 即如何使用未显式调用的类函数

问题描述

我正在看这个教程:如何使用 TF-Hub 解决 Kaggle 上的问题

我有一个问题理解MyModel

class MyModel(tf.keras.Model):
  def __init__(self, hub_url):
    super().__init__()
    self.hub_url = hub_url
    self.embed = hub.load(self.hub_url).signatures['default']
    self.sequential = tf.keras.Sequential([
      tf.keras.layers.Dense(500),
      tf.keras.layers.Dense(100),
      tf.keras.layers.Dense(5),
    ])

  def call(self, inputs):
    phrases = inputs['Phrase'][:,0]
    embedding = 5*self.embed(phrases)['default']
    return self.sequential(embedding)

  def get_config(self):
    return {"hub_url":self.hub_url}

我关心callget_config功能。模型定义和编译:

model = MyModel("https://tfhub.dev/google/nnlm-en-dim128/1")
model.compile(
    loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.optimizers.Adam(), 
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])

当我们将数据传递给 时model.fit,我们传递的是一个完整的字典dict(train_df)

history = model.fit(x=dict(train_df), y=train_df['Sentiment'],
          validation_data=(dict(validation_df), validation_df['Sentiment']),
          epochs = 25)

为了处理这个,似乎call是需要的。这一步什么时候完成?

我习惯于将数据传递给train,例如这里这里所做的。我很好奇为什么在上面的例子中传递整个数据框是有效的,我认为这与理解函数的调用方式有关。labelmodel.fitcall

标签: pythonpython-3.xclasstensorflowkeras

解决方案


推荐阅读