Generator does not iterate over all files

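Problem description

I am training a model with data from multiple .csv files. My code reads these files, but the model still trains on only one of them. The relevant part of my code is: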

def get_data(datasets_path):
    ''' 
    Returns the dataframes.
    '''
    full_path = datasets_path + "*.csv"
    for data_fname in glob.glob(full_path):
        df = pd.read_csv(data_fname)
        processed_df = __preprocessor(df)
        scaler = MinMaxScaler()
        transformed_df = scaler.fit_transform(processed_df)
        return transformed_df


def batch_generator(X, batch_size=16, shuffle=False):
    '''
    Return a random sample from X.
    '''
    count = 0
    while True:
        if shuffle:
            idx = np.random.randint(0, X.shape[0], batch_size)
            data = X[idx]
        else:
            indices = list(n for n in range(X.shape[0]))
            data = X[indices[count*batch_size : (count+1)*batch_size]]
            count +=1
        yield (data, data)

data = get_data(path_to_datasets)
x_train, x_test = train_test_split(data, test_size=0.2, random_state=42, shuffle=False)

x_train = np.expand_dims(x_train, axis=1)
x_test = np.expand_dims(x_test, axis=1)

train_gen = batch_generator(x_train, batch_size=32)
valid_gen = batch_generator(x_test, batch_size=32)

I then define a simple model and train it with:

model.fit_generator(
    generator=train_gen,
    epochs=1,
    steps_per_epoch=x_train.shape[0] // 32,
    validation_data=valid_gen,
    validation_steps=x_test.shape[0] // 32)

The problem is that the model seems to train on a single .csv file rather than on all of them, and I don't understand why.

Tags: python, dataframe, keras, generator

Solution


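The problem is the return statement inside the for loop. A return ends the function immediately, so get_data exits after processing the first file and never reaches the rest. Use yield instead, which turns get_data into a generator that produces one result per file. The caller then has to iterate over get_data(...) rather than treat its result as a single array.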

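import glob

import pandas as pd
from sklearn.preprocessing import MinMaxScaler


def get_data(datasets_path):
    '''
    Yields one scaled array per .csv file.
    '''
    full_path = datasets_path + "*.csv"
    for data_fname in glob.glob(full_path):
        df = pd.read_csv(data_fname)
        processed_df = __preprocessor(df)  # helper from the question, not shown
        scaler = MinMaxScaler()  # fitted separately on each file
        transformed_df = scaler.fit_transform(processed_df)
        # yield instead of return, so the loop continues to the next file
        yield transformed_df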
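
To make the difference concrete, here is a minimal sketch that uses only the standard library. FAKE_FILES, get_data_return and get_data_yield are hypothetical stand-ins for the real .csv files and for the two versions of get_data. Plain lists take the place of the scaled arrays, so this illustrates the control flow only, not the actual pipeline.

```python
from itertools import chain

# Hypothetical stand-in for the CSV files: file name -> rows.
FAKE_FILES = {
    "a.csv": [[0.0], [1.0]],
    "b.csv": [[2.0], [3.0]],
    "c.csv": [[4.0], [5.0]],
}


def get_data_return(files):
    """Original pattern: `return` ends the function during the first iteration."""
    for name in sorted(files):
        return files[name]


def get_data_yield(files):
    """Fixed pattern: `yield` produces one result per file."""
    for name in sorted(files):
        yield files[name]


# Only the rows of the first file come back.
first_only = get_data_return(FAKE_FILES)

# A generator must be consumed; here every file's rows are joined into one list.
all_rows = list(chain.from_iterable(get_data_yield(FAKE_FILES)))

print(len(first_only))  # 2
print(len(all_rows))    # 6
```

With the real code, the rest of the script still expects a single array in data. One option, assuming every file has the same columns after preprocessing, is to join the yielded arrays with np.concatenate(list(get_data(path_to_datasets)), axis=0) before calling train_test_split. Because a new MinMaxScaler is fitted inside the loop, each file is scaled to its own minimum and maximum. Whether that is what you want depends on the data.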
