首页 > 解决方案 > 火炬:最小填充张量,使得 num 元素可被 x 整除

问题描述

假设我有一个t任意 ndim 的张量,我想填充(用零)它,以便 a)我在填充后引入尽可能少的 # 元素 b),(t.numel() % x) == 0

是否有比找到最大维度并将其增加 1 直到满足条件 (b) 更好的算法?

也许工作代码:

def pad_minimally(t, x):
    largest_dim = np.argmax(t.shape)
    buffer_shape = list(t.shape)
    new_t = t.clone()
    print(t.shape)
    for n_to_add in range(x):
        if new_t.numel() % x == 0:
            break
        buffer_shape[largest_dim] = n_to_add
        new_buffer = torch.zeros(*buffer_shape)
        new_t = torch.cat([t, new_buffer], axis=largest_dim)
    assert new_t.numel() % x == 0
    return new_t
assert pad_minimally(torch.rand(3,1), 7).shape == (7,1)
assert pad_minimally(torch.rand(3,2), 7).shape == (7,2)
assert pad_minimally(torch.rand(3,2, 6), 7).shape == (3,2,7)

标签: pythonpytorch

解决方案


首先,简单地在最大维度上加一个直到numel可被整除x并不适用于所有情况。例如,如果 is 的形状,t那么(3, 2)我们x = 9希望将其填充t为 be (3, 3), not (9, 2)

更令人担忧的是,不能保证只需要填充一个维度。例如,如果t具有形状(13, 17, 25)x = 8则最佳填充t将是(14, 18, 26)(13, 18, 28)

将其提炼到数学中,问题就变成了

给定正整数,s[1], ..., s[D]找到在可被所有整除的约束条件下q[1], ..., q[D]最小化的正整数。prod(q[i], i=1 to D)prod(q[i], i=1 to D)xq[i] >= s[i]i=1 to D

尽管我不是特别精通非线性整数编程,但我无法开发出有效的解决方案(请参阅更新以获得更有效的解决方案)。也许存在解决这个问题的有效方法。如果是这样,我想它将涉及和/或更好的记忆的主要x因素q。也就是说,可以使用穷举搜索来解决问题,前提是xD(ie len(t.shape)) 足够小(否则算法可能会运行很长时间)。

我提出的蛮力搜索算法迭代x大于或等于的每个倍数,t.numel()并使用深度优先搜索来查看该倍数是否存在填充。一旦找到有效的填充,算法就会结束。这个算法的python代码是:

import numpy as np

def search(shape, target_numel, memory):
    numel = np.prod(shape)
    if numel == target_numel:
        return True
    elif numel < target_numel:
        for idx in range(len(shape)):
            shape[idx] += 1
            if tuple(shape) not in memory:
                if search(shape, target_numel, memory):
                    return True
                memory.add(tuple(s for s in shape))
            shape[idx] -= 1
    return False

def minimal_shape(shape, target_multiple):
    shape = [s for s in shape]
    target_numel = target_multiple * int(np.ceil(max(1, np.prod(shape)) / target_multiple))
    while not search(shape, target_numel, set()):
        target_numel += target_multiple
    return shape

一旦你有了最小的形状,这个pad_minimal函数就可以非常简洁地实现为

def pad_minimally(t, x):
    new_shape = minimal_shape(t.shape, x)
    new_t = t.new_zeros(new_shape)
    new_t[[slice(0, s) for s in t.shape]] = t
    return new_t

我不确定这是否足以满足您的需求。希望其他人可以提供更有效的版本。


一些测试用例minimal_shape

assert minimal_shape([2, 2], 9) == [3, 3]
assert minimal_shape([2, 8], 6) == [2, 9]
assert minimal_shape([13, 17, 25], 8) in [[14, 18, 26], [13, 18, 28]]
assert minimal_shape([5, 13, 19], 6) == [5, 14, 21]

更新

我在 CS.SE 上询问了这个算法。根据我在那里收到的答案以及对该问题的后续更新,以下是minimal_shape.

from functools import reduce
from operator import mul
from copy import deepcopy

def prod(x):
    return reduce(mul, x, 1)

def argsort(x, reverse=False):
    return sorted(range(len(x)), key=lambda idx: x[idx], reverse=reverse)

def divisors(v):
    """ does not include 1 """
    d = {v} if v > 1 else set()
    for n in range(2, int(v**0.5) + 1):
        if v % n == 0:
            d.add(n)
            d.add(v // n)
    return d

def update_memory(b, c_rem, memory):
    tuple_m = tuple(b + [c_rem])
    if tuple_m in memory:
        return False
    memory.add(tuple_m)
    return True

def dfs(a, b, c, c_rem, memory, p_best=float('inf'), b_best=None):
    ab = [ai + bi for ai, bi in zip(a, b)]
    p = prod(ab)
    if p >= p_best:
        return p_best, b_best
    elif p % c == 0:
        return p, deepcopy(b)

    dc = divisors(c_rem)
    for i in argsort(ab):
        for d in dc:
            db = (d - ab[i]) % d
            b[i] += db
            if update_memory(b, c_rem // d, memory):
                p_best, b_best = dfs(a, b, c, c_rem // d, memory, p_best, b_best)
            b[i] -= db

    return p_best, b_best

def minimal_shape(shape, target_multiple):
    a = list(shape)
    b = [0 for _ in range(len(a))]
    c = target_multiple
    _, b = dfs(a, b, c, c, set())
    return [ai + bi for a, b in zip(a, b)]

推荐阅读