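python - scikit-learn: What does `gen_batches()` do, and how does it work?
Question
I am reading the scikit-learn source code.
This line uses the function gen_batches(), and I am trying to understand how it works.
I searched the documentation for this function but found nothing.
I also tried this small snippet: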
from sklearn.utils import gen_batches

slices = gen_batches(3, 5)
for sl in slices:
    print(sl.start == 0)
    print(sl)
but that did not tell me anything either.
What is the function gen_batches() for? How does it work?
Answer
From the sklearn source:
def gen_batches(n, batch_size, min_batch_size=0):
    """Generator to create slices containing batch_size elements, from 0 to n.

    The last slice may contain less than batch_size elements, when batch_size
    does not divide n.

    Parameters
    ----------
    n : int
    batch_size : int
        Number of elements in each batch
    min_batch_size : int, default=0
        Minimum batch size to produce.

    Yields
    ------
    slice of batch_size elements

    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """
So essentially, it is a utility for generating batches. For example,
>>> import numpy as np
>>> X = np.random.random((10,3))
>>> X
array([[0.66955147, 0.10954688, 0.41856203],
       [0.23409169, 0.20154919, 0.49110055],
       [0.40495903, 0.66112904, 0.32610395],
       [0.22084787, 0.47966598, 0.10281514],
       [0.75948756, 0.11656251, 0.56470397],
       [0.90018421, 0.13771094, 0.22860183],
       [0.12720045, 0.58558546, 0.32475034],
       [0.21623059, 0.04038225, 0.03538428],
       [0.11403724, 0.8097086 , 0.9633516 ],
       [0.85671638, 0.44873045, 0.39033928]])
>>> from sklearn.utils import gen_batches
>>> slices = gen_batches(10,2)
>>> for s in slices:
...     print(X[s])
...
[[0.66955147 0.10954688 0.41856203]
 [0.23409169 0.20154919 0.49110055]]
[[0.40495903 0.66112904 0.32610395]
 [0.22084787 0.47966598 0.10281514]]
[[0.75948756 0.11656251 0.56470397]
 [0.90018421 0.13771094 0.22860183]]
[[0.12720045 0.58558546 0.32475034]
 [0.21623059 0.04038225 0.03538428]]
[[0.11403724 0.8097086  0.9633516 ]
 [0.85671638 0.44873045 0.39033928]]
When the batch size does not divide n evenly, the last batch contains fewer elements.
>>> slices = gen_batches(10,8)
>>> for s in slices:
...     print(X[s])
...
[[0.66955147 0.10954688 0.41856203]
 [0.23409169 0.20154919 0.49110055]
 [0.40495903 0.66112904 0.32610395]
 [0.22084787 0.47966598 0.10281514]
 [0.75948756 0.11656251 0.56470397]
 [0.90018421 0.13771094 0.22860183]
 [0.12720045 0.58558546 0.32475034]
 [0.21623059 0.04038225 0.03538428]]
[[0.11403724 0.8097086  0.9633516 ]
 [0.85671638 0.44873045 0.39033928]]
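What comes out of the generator is an ordinary Python slice object, so nothing here is specific to NumPy. The sketch below uses a plain list as a hypothetical stand-in for X (the name `rows` is an assumption for illustration) and hard-codes the two slices that the gen_batches(10, 8) call above corresponds to, which makes the uneven 8 + 2 split easy to see.

```python
# Plain-Python illustration; `rows` is a hypothetical stand-in for the array X.
rows = list(range(10))

# The two slices matching the gen_batches(10, 8) output shown above.
batches = [slice(0, 8), slice(8, 10)]

for s in batches:
    # A slice object exposes start/stop/step and can index any sequence.
    print(s.start, s.stop, rows[s])

# Number of rows in each batch: one full batch, then the shorter remainder.
batch_lengths = [len(rows[s]) for s in batches]
print(batch_lengths)
```

Because the batches are only index ranges, the same slices can be reused on several arrays of equal length, for example a feature matrix and its label vector.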