python - Python中用于非平方成本矩阵的匈牙利算法
问题描述
我想在非方形 numpy 数组上使用 python 中的匈牙利分配算法。
我的输入矩阵X
如下所示:
X = np.array([[0.26, 0.64, 0.16, 0.46, 0.5 , 0.63, 0.29],
[0.49, 0.12, 0.61, 0.28, 0.74, 0.54, 0.25],
[0.22, 0.44, 0.25, 0.76, 0.28, 0.49, 0.89],
[0.56, 0.13, 0.45, 0.6 , 0.53, 0.56, 0.05],
[0.66, 0.24, 0.61, 0.21, 0.47, 0.31, 0.35],
[0.4 , 0.85, 0.45, 0.14, 0.26, 0.29, 0.24]])
X
所需的结果是排序的矩阵,例如X_desired_output
:
X_desired_output = np.array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16],
[0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61],
[[0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25],
[[0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45],
[[0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61],
[[0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])
在这里,我想最大化成本而不是最小化,因此算法的输入在理论上要么是,要么是1-X
简单X
的。
我发现https://software.clapper.org/munkres/导致:
from munkres import Munkres
m = Munkres()
indices = m.compute(-X)
indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]
# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]
我怎样才能使用这些来排序X
?jj
仅包含 6 个元素,而不是原始的 7 列X
。
我正在寻找实际排序的矩阵。
解决方案
在花了几个小时研究它之后,我找到了一个解决方案。问题是由于X.shape[1] > X.shape[0]
某些列根本没有分配,这导致了问题。
该文件指出
“Munkres 算法假设成本矩阵是方形的。但是,如果您首先用 0 值填充它以使其成为方形,则可以使用矩形矩阵。该模块会自动填充矩形成本矩阵以使其成为方形。”
from munkres import Munkres
m = Munkres()
indices = m.compute(-X)
indices
[(0, 5), (1, 4), (2, 6), (3, 3), (4, 0), (5, 1)]
# getting the indices in list format
ii = [i for (i,j) in indices]
jj = [j for (i,j) in indices]
# re-order matrix
X_=X[:,jj] # re-order columns
X_=X_[ii,:] # re-order rows
# HERE IS THE TRICK: since the X is not diagonal, some columns are not assigned to the rows !
not_assigned_columns = X[:, [not_assigned for not_assigned in np.arange(X.shape[1]).tolist() if not_assigned not in jj]].reshape(-1,1)
X_desired = np.concatenate((X_, not_assigned_columns), axis=1)
print(X_desired)
array([[0.63, 0.5 , 0.29, 0.46, 0.26, 0.64, 0.16],
[0.54, 0.74, 0.25, 0.28, 0.49, 0.12, 0.61],
[0.49, 0.28, 0.89, 0.76, 0.22, 0.44, 0.25],
[0.56, 0.53, 0.05, 0.6 , 0.56, 0.13, 0.45],
[0.31, 0.47, 0.35, 0.21, 0.66, 0.24, 0.61],
[0.29, 0.26, 0.24, 0.14, 0.4 , 0.85, 0.45]])
推荐阅读
- spring-boot - 创建 Spring Boot 项目和 Maven
- php - 使用 XDebug 进行调试时,如何定义 PHP 对象在 VSCode 中的显示方式?
- python - 'numpy.float64' 类型的对象没有 len():我该如何解决这个问题?
- oculus - 如何在 Oculus Quest 浏览器中自动打开链接?
- r - 以向量化方式基于数据子集定义多行
- firebase - 从子集合 firestore + FLUTTER 中检索信息
- php - Laravel 关系不适用于 leftJoin
- regex - 带有可选下划线的下划线分隔表名的正则表达式
- hibernate - 如何在 hbm 文件中引用 java.time.LocalDate?
- javascript - 有没有办法通过节点运行 scrollreveal.js?