tensorly - 如何从 tensorly 设计对 partial_tucker 函数的测试?
问题描述
我尝试设计一个测试,以验证partial_tucker
tensorly 中的函数是否按我预期的那样工作。换句话说,我想为partial_tucker
函数设计一个输入及其相关的预期输出。
所以,我试图做的是取一个初始随机张量A
(4阶),手动计算它的“低秩”塔克分解,然后重建与初始张量相同形状的张量,比如说A_tilde
。我认为A_tilde
张量是初始张量的“低阶近似” A
。我对么?
然后我想partial_tucker
使用该张量上的函数,A_tilde
我希望结果与我手动计算的 tucker 分解相同。事实并非如此,所以我猜我手工制作的塔克分解是错误的。如果是这样,为什么?
import tensorly
import numpy as np
h, w, c, f = 3, 3, 64, 128
c_prim, f_prim = 16, 32
base_tensor = np.random.rand(h, w, c, f)
# compute tucker decomposition by hand using higher order svd describred here: https://www.alexejgossmann.com/tensor_decomposition_tucker/.
lst_fac = []
for k in [2, 3]:
mod_k_unfold = tensorly.base.unfold(base_tensor, k)
U, _, _ = np.linalg.svd(mod_k_unfold)
lst_fac.append(U)
real_in_fac, real_out_fac = lst_fac[0], lst_fac[1]
real_core = multi_mode_dot(base_tensor, [real_in_fac.T, real_out_fac.T], modes=(2,3))
del base_tensor # no need of it anymore
# what i call the "low rank tucker decomposition"
real_core = real_core[:,:,:c_prim,:f_prim]
real_in_fac = real_in_fac[:, :c_prim]
real_out_fac = real_out_fac[:, :f_prim]
# low rank approximation
base_tensor_low_rank = multi_mode_dot(real_core, [real_in_fac, real_out_fac], modes=(2,3))
in_rank, out_rank = c_prim, f_prim
core_tilde, (in_fac_tilde, out_fac_tilde) = partial_tucker(base_tensor_low_rank, modes=(2, 3), ranks=(in_rank, out_rank), init='svd')
base_tensor_tilde = multi_mode_dot(core_tilde, [in_fac_tilde, out_fac_tilde], modes=(2,3))
assert np.allclose(base_tensor_tilde, base_tensor_low_rank) # this is OK
assert np.allclose(in_fac_tilde, real_in_fac) # this fails
请注意,我试图计算in_fac_tilde.T @ real_in_fac
它是否是身份或类似的东西,我注意到只有第一列在两个矩阵中是共线的,并且与所有其他矩阵正交。
解决方案
您在这里隐含地做了很多假设:您假设,例如,您可以只修剪秩-R 分解以获得秩-(R-1) 分解。这通常是不正确的。另外,请注意,您使用的 Tucker 分解不仅仅是高阶 SVD (HO-SVD)。相反,HO-SVD 用于初始化,然后是高阶正交迭代 (HOOI)。
您还假设低等级分解对于任何给定等级都是唯一的,这将允许您直接比较分解的因素。情况也并非如此(即使在矩阵情况下,并且具有诸如正交性之类的强约束,您仍然会有符号不确定性)。
相反,您可以例如检查相对重建误差。我建议你看看 TensorLy 中的测试。如果您从张量开始,有很多很好的参考资料。例如,科尔达和巴德的开创性作品;特别是对于 Tucker,De Lathauwer 等人的工作(例如,关于张量的最佳低秩近似)等。
推荐阅读
- angular - 加载时未显示 Three.js 纹理
- sql-server - 如果员工没有完成每月工作时间,如何扣除工资?
- url-rewriting - 里程碑的 Azure URL 重写问题
- java - “实际参数列表和形式参数列表的长度不同”
- laravel - laravel 操作数组没有 for 循环以获得更好的性能和优化
- mysql - Mysql加入3个表没有行合并
- angular - 重绘谷歌图表而不刷新页面
- node.js - 如何将模块发送到控制器外部的导入 ES6 类
- python - 如何优化 Pandas DataFrame 速度?
- ruby-on-rails - 如何将响应格式从 fast_jsonapi 格式更改为 AMS gem 响应格式