首页 > 解决方案 > symengine 中的 Argmax 或等效替代方案

问题描述

我正在对非线性系统网络进行简单的模拟。特别是我有 N 个节点,每个节点由 m 个单元组成。每个单元的输出函数既取决于它的活动,也取决于同一节点中其他单元的活动。

我实现的模拟是在 scipy + jitcode 中。

我实现的第一个版本是根据 softmax 分布,因此我实现了这个简单的函数来计算每个单元的输出。

def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
    sum_hc = 0
    for unit in node:
        sum_hc += symengine.exp(unit * G)
    for unit in node:
        act.append(symengine.exp(unit * G)/sum_hc)
return act

现在,我想用一个简单的函数替换上面的函数,对于每个节点,为活动度最高的单元输出 1,在其他单元中输出 0。长话短说,对于每个节点,只有一个单元输出 1。

我现在面临的主要问题是如何使用 symengine 执行此操作,以便 jitcode 可以使用它。我在下面实现的功能由于明显的原因不起作用。我猜 if 条件不是很有象征意义。

def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
    max_act = symengine.Max(*node)
    for unit in node:
        if unit >= max_act:
            act.append(1)
        else:
            act.append(0)           
return act

我没有找到任何 symengine.argmax() 函数或任何智能替代解决方案。你有什么建议吗?

更新

def max_activation(activities):
    act = []

for hc in activities:
    sum_hc = 0
    max_act = symengine.Max(*hc)
    for mc in hc:
        act.append(symengine.GreaterThan(mc, max_act))
    print(act)
return act

测试这个功能:

    max_activation([[y(1), y(2)], [y(3), y(4)]])

我得到以下有希望的输出。一旦我有一些测试,我会更新。

[max(y(2), y(1)) <= y(1), max(y(2), y(1)) <= y(2)]

[max(y(4), y(3)) <= y(3), max(y(4), y(3)) <= y(4)]

标签: maxsimulationjitargmaxsymengine

解决方案


推荐阅读