scikit-learn: What does `gen_batches()` do, and how does it work?

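Question

I am reading the scikit-learn source code.

One line uses the function gen_batches(), and I am trying to understand how it works.

I searched the documentation for this function but found nothing.

I also tried this small snippet: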

from sklearn.utils import gen_batches
slices = gen_batches(3,5)
for sl in slices:
    print(sl.start==0)
    print(sl)

That did not tell me much either.

What is gen_batches() for, and how does it work?

Tags: python, scikit-learn

Answer


From the sklearn source:

def gen_batches(n, batch_size, min_batch_size=0):
    """Generator to create slices containing batch_size elements, from 0 to n.
    The last slice may contain less than batch_size elements, when batch_size
    does not divide n.
    Parameters
    ----------
    n : int
    batch_size : int
        Number of element in each batch
    min_batch_size : int, default=0
        Minimum batch size to produce.
    Yields
    ------
    slice of batch_size elements
    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """

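So in essence it is a utility for generating batches: it yields slice objects that you then use to index your data. For example: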

>>> import numpy as np
>>> X = np.random.random((10,3))
>>> X
array([[0.66955147, 0.10954688, 0.41856203],
       [0.23409169, 0.20154919, 0.49110055],
       [0.40495903, 0.66112904, 0.32610395],
       [0.22084787, 0.47966598, 0.10281514],
       [0.75948756, 0.11656251, 0.56470397],
       [0.90018421, 0.13771094, 0.22860183],
       [0.12720045, 0.58558546, 0.32475034],
       [0.21623059, 0.04038225, 0.03538428],
       [0.11403724, 0.8097086 , 0.9633516 ],
       [0.85671638, 0.44873045, 0.39033928]])
>>> from sklearn.utils import gen_batches
>>> slices = gen_batches(10,2)
>>> for s in slices:
...     print(X[s])
... 
[[0.66955147 0.10954688 0.41856203]
 [0.23409169 0.20154919 0.49110055]]
[[0.40495903 0.66112904 0.32610395]
 [0.22084787 0.47966598 0.10281514]]
[[0.75948756 0.11656251 0.56470397]
 [0.90018421 0.13771094 0.22860183]]
[[0.12720045 0.58558546 0.32475034]
 [0.21623059 0.04038225 0.03538428]]
[[0.11403724 0.8097086  0.9633516 ]
 [0.85671638 0.44873045 0.39033928]]
>>> 

When the batch size does not evenly divide n, the last batch contains fewer elements:

>>> slices = gen_batches(10,8)
>>> for s in slices:
...     print(X[s])
... 
[[0.66955147 0.10954688 0.41856203]
 [0.23409169 0.20154919 0.49110055]
 [0.40495903 0.66112904 0.32610395]
 [0.22084787 0.47966598 0.10281514]
 [0.75948756 0.11656251 0.56470397]
 [0.90018421 0.13771094 0.22860183]
 [0.12720045 0.58558546 0.32475034]
 [0.21623059 0.04038225 0.03538428]]
[[0.11403724 0.8097086  0.9633516 ]
 [0.85671638 0.44873045 0.39033928]]
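
This also explains why the snippet in the question seemed to show nothing useful. With n=3 and batch_size=5 the batch size is larger than n, so, by the same rule as the `gen_batches(2, 3)` docstring example, a single slice covering everything is produced. The sketch below builds that slice by hand (so it runs without scikit-learn) to show what a slice object is and how it is applied; the list `data` is a made-up stand-in for real samples.

```python
# Hypothetical stand-in for a dataset with n = 3 samples.
data = ["a", "b", "c"]

# Per the docstring, gen_batches(3, 5) would yield this single slice,
# because batch_size (5) is larger than n (3).
sl = slice(0, 3)

print(sl)                  # slice(0, 3, None)
print(sl.start, sl.stop)   # 0 3
print(data[sl])            # ['a', 'b', 'c']
```

Printing the slice itself only shows the index range; the batch is obtained by indexing the data with it, as in `X[s]` above.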
