首页 > 解决方案 > 训练 TensorFlow 模型时在不同的数据集上多次调用 model.fit

问题描述

我一直在使用 TensorFlow 和 Keras 开发 AlphaGo Zero 克隆。它似乎工作正常,但我不确定我用来训练我的模型的代码是否有意义:

    def train_from_game_log(self, game_log):
        half_point = len(game_log['x'])//2
        x = np.array(game_log['x'][half_point:])
        y0 = np.array(game_log['y'][0][half_point:])
        y1 = np.array(game_log['y'][1][half_point:])
        self.model.fit(x, [y0, y1], shuffle=True, batch_size=64, epochs=4)

        x = tf.image.rot90(x, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size, self.input_board_size, 1))
        y0 = tf.image.rot90(y0, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size*self.input_board_size))
        self.model.fit(x, [y0, y1], shuffle=True, batch_size=64, epochs=4)

        x = tf.image.rot90(x, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size, self.input_board_size, 1))
        y0 = tf.image.rot90(y0, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size*self.input_board_size))
        self.model.fit(x, [y0, y1], shuffle=True, batch_size=64, epochs=4)
        
        x = tf.image.rot90(x, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size, self.input_board_size, 1))
        y0 = tf.image.rot90(y0, k=1)
        y0 = np.reshape(y0, (-1,self.input_board_size*self.input_board_size))
        self.model.fit(x, [y0, y1], shuffle=True, batch_size=64, epochs=4)

我正在旋转我的数据集并fit()在模型上再次运行。

像这样训练它有意义吗?这会使模型偏向最后一次旋转吗?我应该做一个发电机吗?

标签: pythontensorflowmachine-learningkerasreinforcement-learning

解决方案


推荐阅读