A generator instead of a collection as input to `model.fit()`? (Training data too large for memory)

Problem description

The fit function of a tflearn model can be passed training and test data like this:

model = tflearn.DNN(nn)

model.fit({'input': X_train},
          {'targets': Y_train},
          n_epoch=10,
          validation_set=(
              {'input': X_test},
              {'targets': Y_test}
          ))

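where nn is the model definition. But what if collections such as X_train are too large to hold in memory?

In my case, the data are sparse binary vectors (each cell is 1 or 0). I compress each one into a list of the indices of its non-zero cells plus an integer giving the vector's dimensionality, which is enough to reconstruct the original vector.

The collection of compressed vectors fits in memory; the collection of full vectors does not. So I tried passing generators in place of X_train and the other collections (which now hold compressed vectors), so that the full vectors are reconstructed on the fly. However, model.fit requires its inputs to support len(). So I defined a custom feeder class as follows: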
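As a sketch of that encoding, assuming the full vector is available as a list or NumPy array at compression time (the helper name `compress_vector` is hypothetical and does not appear in the original code):

```python
import numpy as np


def compress_vector(vec):
    """Hypothetical helper: encode a 0/1 vector as (non-zero indices, dimensionality)."""
    vec = np.asarray(vec)
    # Positions of all cells that are not zero, as plain Python ints
    non_zero_indices = np.flatnonzero(vec).tolist()
    return non_zero_indices, vec.shape[0]


compressed = compress_vector([0, 1, 0, 0, 1])
```

Here `compressed` holds the index list and the dimensionality, which are the two arguments the reconstruction step needs.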

class Feeder:
    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        # Generator method: each iteration starts from the beginning
        # and converts the compressed items on the fly
        for item in self.data:
            yield self.convert(item)

which I use like this:

import numpy as np

def reconstruct_vector(non_zero_indices, dimensionality):
    """
    returns a vector of zeros and ones reconstructed from a sparse vector
    and a dimensions value
    """
    vec = np.zeros(dimensionality)
    for i in non_zero_indices:
        vec[i] = 1
    return vec

item_to_input_vector = lambda item : reconstruct_vector(item[0], item[1])
item_to_target_vector = lambda item : np.array([1,0]) if item else np.array([0,1])

model.fit({'input': Feeder(X_train, item_to_input_vector)},
          {'targets': Feeder(Y_train, item_to_target_vector)},
          n_epoch=10,
          validation_set=(
              {'input': Feeder(X_test, item_to_input_vector)},
              {'targets': Feeder(Y_test, item_to_target_vector)}
          ))

But this does not work either; I get a cryptic error:

Exception in thread Thread-3:
Traceback (most recent call last):
  File "/usr/lib64/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib64/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/utils.py", line 187, in slice_array
    return X[start]
TypeError: only integer scalar arrays can be converted to a scalar index

So, what is the right way to solve this?

Tags: numpy, tensorflow, memory, tflearn

Solution
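The traceback ends at `return X[start]` in `utils.slice_array`, reached through `retrieve_data(batch_ids)`. This suggests that tflearn indexes into each input with a batch of ids rather than iterating over it. Under that assumption, a wrapper that implements `__getitem__` in addition to `__len__` may be enough. The following is a minimal sketch, not a verified fix: the class name `LazyArray` is hypothetical, and whether `model.fit` needs further array attributes (for example `shape`) has not been checked here.

```python
import numpy as np


def reconstruct_vector(non_zero_indices, dimensionality):
    """Rebuild a dense 0/1 vector from its non-zero indices."""
    vec = np.zeros(dimensionality)
    vec[np.asarray(non_zero_indices, dtype=int)] = 1
    return vec


class LazyArray:
    """Hypothetical wrapper: keeps compressed items, converts them when indexed."""

    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Single integer index: return one converted item
        if isinstance(index, (int, np.integer)):
            return self.convert(self.data[index])
        # Slice: expand it to explicit positions
        if isinstance(index, slice):
            index = range(*index.indices(len(self.data)))
        # List or array of ids (a batch): convert only those items
        return np.stack([self.convert(self.data[i]) for i in index])


# Toy data: three compressed vectors of dimensionality 4
X_train = [([0, 2], 4), ([1], 4), ([], 4)]
X_lazy = LazyArray(X_train, lambda item: reconstruct_vector(item[0], item[1]))

# Indexing with an array of ids, as the traceback suggests tflearn does
batch = X_lazy[np.array([2, 0])]
```

If the assumption holds, `LazyArray(X_train, item_to_input_vector)` would take the place of `Feeder(X_train, item_to_input_vector)` in the `model.fit` call, and only one batch of full vectors would exist in memory at a time.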

