python - 生成器不会遍历所有文件
问题描述
我正在用来自多个.csv
文件的数据训练一个模型,我发现我的代码读取了这些文件,但模型仍然在一个文件上训练。我的代码的相关部分是:
def get_data(datasets_path):
'''
Returns the dataframes.
'''
full_path = datasets_path + "*.csv"
for data_fname in glob.glob(full_path):
df = pd.read_csv(data_fname)
processed_df = __preprocessor(df)
scaler = MinMaxScaler()
transformed_df = scaler.fit_transform(processed_df)
return transformed_df
def batch_generator(X, batch_size=16, shuffle=False):
'''
Return a random sample from X.
'''
count = 0
while True:
if shuffle:
idx = np.random.randint(0, X.shape[0], batch_size)
data = X[idx]
else:
indices = list(n for n in range(X.shape[0]))
data = X[indices[count*batch_size : (count+1)*batch_size]]
count +=1
yield (data, data)
和
data = get_data(path_to_datasets)
x_train, x_test = train_test_split(data, test_size=0.2, random_state=42, shuffle=False)
x_train = np.expand_dims(x_train, axis=1)
x_test = np.expand_dims(x_test, axis=1)
train_gen = batch_generator(x_train, batch_size=32)
valid_gen = batch_generator(x_test, batch_size=32)
然后我定义一个简单的模型并用
model.fit_generator(
generator=train_gen,
epochs=1,
steps_per_epoch=x_train.shape[0] // 32,
validation_data=valid_gen,
validation_steps=x_test.shape[0] // 32)
问题是这似乎是从一个.csv
文件中训练出来的,而不是通过所有文件,我不明白为什么。
解决方案
问题是你在 for 循环中的 return 语句。处理单个文件后,该get_data
方法将中断循环。尝试使用 yield 来获取迭代器。
def get_data(datasets_path):
'''
Returns the dataframes.
'''
full_path = datasets_path + "*.csv"
for data_fname in glob.glob(full_path):
df = pd.read_csv(data_fname)
processed_df = __preprocessor(df)
scaler = MinMaxScaler()
transformed_df = scaler.fit_transform(processed_df)
yield transformed_df
推荐阅读
- javascript - “未定义”将出现在我的条形图中 react-chartjs
- r - 如何在 R 的每次迭代中打印 GA 算法中使用的函数的值?
- r - 有没有一种简单的方法可以在 R 中标记异常值?
- python-3.x - 如何使用其中一行中可用的 TIMESTAMP 值在数据框中为所有行设置日期
- boost - 如何全局指定 boost::asio::streambuf bufferstrm 大小
- mapbox - Mapbox GL聚类:缩写标签文本中的累积值
- angular - Angular Search Pipe,西里尔文支持
- sql - 用作查询变量以返回结果的函数
- c++ - 如何读取文件并比较用户输入以查看它是否匹配?C++
- java - java中带有“卡片”的可滚动JPanel