Can objects be used with Google's Jax machine learning library?

Problem description

I am trying to write a DCGAN using Google's Jax machine learning library. To do this I created objects for the discriminator and the generator, but when I test the discriminator I get the following error:

    TypeError: Argument '<__main__.Discriminator object at 0x7fdfa5c6ffd0>' of type <class '__main__.Discriminator'> is not a valid JAX type

I looked at the examples on the Jax GitHub page and, as far as I can tell, none of them use objects, which makes me suspect that it may simply not be possible to use objects in Jax. If that is the case, I don't really understand why not. Is this something that will be implemented in the future, or am I just naively missing something?
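
For what it's worth, I can reproduce what looks like the same failure without any of my network code. The names in this little sketch are placeholders I made up, and the exact wording of the message may vary between Jax versions, but passing any plain Python object into a jit-compiled function fails the same way for me:

import jax

class Empty:
    """A plain Python object with nothing Jax can trace."""
    pass

@jax.jit
def takes_an_object(obj, x):
    # obj is never even used; jit still has to work out how to trace it
    return x + 1.0

takes_an_object(Empty(), 1.0)
# -> TypeError: Argument '<__main__.Empty object at ...>' of type
#    <class '__main__.Empty'> is not a valid JAX type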

Here is my Discriminator object:

import jax.numpy as np
from jax import grad, jit, random, vmap

# create_conv_layer, conv_forward, batch_normalization, leaky_relu and
# sigmoid are my own helpers (omitted here).

class Discriminator():
    def __init__(self):
        self.step_size = 0.0001
        self.image_shape = (256, 256, 3)
        self.params = []
        num_layers = 6
        num_filters = 64
        filter_size = 4
        # first layer: 3 input channels -> 64 filters
        self.params.append(create_conv_layer(3,
                                             num_filters,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))
        # hidden layers double the filter count each time: 64 -> 128 -> ... -> 2048
        for l in range(1, num_layers):
            self.params.append(create_conv_layer(64 * 2 ** (l - 1),
                                                 64 * 2 ** l,
                                                 filter_size,
                                                 filter_size,
                                                 random.PRNGKey(0)))
        # final layer collapses the last feature map down to a single channel
        self.params.append(create_conv_layer(64 * 2 ** (num_layers - 1),
                                             1,
                                             filter_size,
                                             filter_size,
                                             random.PRNGKey(0)))

    def predict(self, params, image):
        # forward pass for a single image
        activations = image
        for w, b in params[:-1]:
            outputs = conv_forward(activations, w, b, stride=2)
            outputs = batch_normalization(outputs)
            activations = leaky_relu(outputs)
        final_w, final_b = params[-1]
        return sigmoid(conv_forward(activations, final_w, final_b))

    def batched_predict(self, params, images):
        # map predict over the leading (batch) axis of images only;
        # the params pytree is shared across the batch (in_axes=None)
        return vmap(self.predict, in_axes=(None, 0))(params, images)

    def loss(self, params, images, targets):
        preds = self.batched_predict(params, images)
        return -np.sum(preds * targets)

    def accuracy(self, images, targets):
        preds = self.batched_predict(self.params, images)
        predicted_class = np.round(np.ravel(preds))
        return np.mean(predicted_class == targets)

    @jit
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return [(w - self.step_size * dw, b - self.step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

And this is where I update the parameters:

import time

# train_images, train_labels, test_images, test_labels and
# simple_data_generator come from my data loading code (omitted here).

num_epochs = 5
batch_size = 64
steps_per_epoch = train_images.shape[0] // batch_size
discrim = Discriminator()
params = discrim.params

print("lets-a-go!")
for epoch in range(num_epochs):
    start_time = time.time()
    for step in range(steps_per_epoch):
        x, y = simple_data_generator(batch_size)
        params = discrim.update(params, x, y)
    # keep the object's copy in sync so accuracy() sees the updated parameters
    discrim.params = params
    epoch_time = time.time() - start_time

    train_acc = discrim.accuracy(train_images, train_labels)
    test_acc = discrim.accuracy(test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Tags: python, machine-learning, google-jax

Solution
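
Jax has nothing against classes as such; what it is strict about is the arguments of the functions you hand to its transformations. jit, grad and vmap expect every traced argument to be an array or a pytree (a nested tuple/list/dict) of arrays. When you decorate a method with @jit, the bound self is passed in as the first traced argument, and a plain Discriminator instance is not something Jax knows how to turn into an abstract value, which is exactly the "is not a valid JAX type" error you are seeing. The examples on the Jax GitHub page sidestep this by jitting pure functions that take only the params pytree.

One way to keep your class structure is to tell jit to treat self as a static argument. The following is a self-contained toy sketch of that pattern; ToyModel and its loss are made-up stand-ins for your Discriminator, not code from your question:

from functools import partial

import jax.numpy as jnp
from jax import grad, jit

class ToyModel:
    """A stripped-down stand-in for the Discriminator in the question."""

    def __init__(self):
        self.step_size = 0.0001
        self.params = jnp.zeros(3)      # pretend parameters

    def loss(self, params, x, y):
        return jnp.sum((jnp.dot(x, params) - y) ** 2)

    # static_argnums=0 marks `self` as static, so jit no longer tries to
    # trace the object itself; only params, x and y are traced.
    @partial(jit, static_argnums=0)
    def update(self, params, x, y):
        grads = grad(self.loss)(params, x, y)
        return params - self.step_size * grads

model = ToyModel()
x = jnp.ones((4, 3))
y = jnp.zeros(4)
new_params = model.update(model.params, x, y)   # no "not a valid JAX type" error

Two caveats with this pattern: static arguments are compared by hash, so every instance gets its own compiled version of update, and any attributes read inside (step_size here) are baked in at trace time. Simply removing @jit from update, or moving the jitted code into a standalone function that takes only params, x and y, would also make the error go away.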


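Alternatively, you can teach Jax how to flatten a class of your own into arrays plus static metadata by registering it as a pytree; instances then become valid arguments to jit, grad and vmap. This is a rough, generic sketch (LinearLayer is illustrative, not your discriminator), but the same idea would apply to a class holding your conv parameters:

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class LinearLayer:
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

    # how to take the object apart: (array children, static aux data)
    def tree_flatten(self):
        return (self.w, self.b), None

    # how to rebuild it from those pieces (used inside traced code)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

@jax.jit
def apply(layer, x):
    # `layer` is now a pytree, so jit is happy to receive it
    return layer(x)

layer = LinearLayer(jnp.ones((3, 2)), jnp.zeros(2))
print(apply(layer, jnp.ones((4, 3))).shape)    # (4, 2)

If you would rather not write this plumbing yourself, higher-level libraries built on Jax, such as Flax and Haiku, exist largely to let you express models as module-like objects while keeping the parameters in the pytree form that jit and grad expect.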