lua - 从 python 到 lua 的翻译错误。我的错误在哪里?
问题描述
在询问并删除了几个问题后,我不知道代码中的错误在哪里。预期的输出是:
[0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.6759289682539686, 0.31500873015873027, 0.12156230158730162, 0.5246873015873018, 0.5989928571428574, 0.060103968253968264]
但我的lua代码只生成:
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
3.2328
我真的不知道我在哪里犯了错误。或许你们能看出来。我尝试了一些更改,但始终没有什么不同。我也不明白为什么每一行都有相同的数字。
有效的python代码:
from math import prod
from fractions import Fraction
def bitstrings(n) :
"""Return all possible bitstrings of length n"""
if n == 0 :
yield []
return
else :
for b in [0,1] :
for x in bitstrings(n-1) :
yield [b] + x
def prob_selected(weights, num_selected = 5) :
# P(n generated, including e)*P(e of n selected | n generated including e)
# i.e. Sum_n (n generated, including e) * #num_selections / #generated
# num_selected = how many will be drawn out of the hat (at most)
n = len(weights)
final_probability = [0] * n
for bits in bitstrings(n) :
num_generated = sum(bits)
prob_generated = prod([w if b else (1-w) for (w,b) in zip(weights, bits)])
for i in range(n) :
if bits[i] :
final_probability[i] += prob_generated * min(num_selected, num_generated) / num_generated
return final_probability
print(prob_selected([1, 1, 1, 1, 1,
0.5, 0.2, 0.8, 0.9, 0.1]))
我的lua代码:
-- python len()
function tablelength(T)
local count = 0
for _ in pairs(T) do count = count + 1 end
return count
end
-- python sum()
table.reduce = function (list, fn)
local acc
for k, v in ipairs(list) do
if 1 == k then
acc = v
else
acc = fn(acc, v)
end
end
return acc
end
globalArr = {}
function generateBitstrings (n, arr, i)
if i == n then
table.insert(globalArr, {table.unpack(arr)})
return
end
arr[i] = 0
generateBitstrings(n, arr, i + 1)
arr[i] = 1
generateBitstrings(n, arr, i + 1)
end
function prob_selected (weights, num_selected)
local n = tablelength(weights)
final_probability = {}
for i=1, n do
final_probability[i] = 0
end
globalArr = {}
generateBitstrings(n + 1, {}, 1)
for ibots, bits in ipairs(globalArr) do
num_generated = table.reduce(
bits,
function(a, b)
return a + b
end
)
prob_generated = 1
bitsLength = tablelength(bits)
for i=1,bitsLength do
if bits[i] then
prob_generated = prob_generated * weights[i]
else
prob_generated = prob_generated * 1 - weights[i]
end
end
for i=1,n do
if bits[i] == 1 then
final_probability[i] = final_probability[i] + (prob_generated * math.min(num_selected, num_generated) / num_generated)
end
end
end
return final_probability
end
for i, value in ipairs(prob_selected({1, 1, 1, 1, 1,0.5, 0.2, 0.8, 0.9, 0.1}, 5)) do
print(value)
end
解决方案
- 正如@Egor Skriptunoff 所说,
if bits[i]
在Lua 中不会像你期望的那样工作:0
不是假的;只有false
和nil
是。你需要if bits[i] == 1
. - 你忘记了 中的括号
prob_generated = prob_generated * (1 - weights[i])
, - 我添加了几个
local
,虽然这并不重要, - 我在
globalArr
本地制作;你可能想重命名它, - 我使测试输出更紧凑。
-- python len()
local function tablelength(T)
local count = 0
for _ in pairs(T) do count = count + 1 end
return count
end
-- python sum()
table.reduce = function (list, fn)
local acc
for k, v in ipairs(list) do
if 1 == k then
acc = v
else
acc = fn(acc, v)
end
end
return acc
end
local function generateBitstrings (global_arr, n, arr, i)
if i == n then
table.insert(global_arr, {table.unpack(arr)})
return
end
arr[i] = 0
generateBitstrings(global_arr, n, arr, i + 1)
arr[i] = 1
generateBitstrings(global_arr, n, arr, i + 1)
end
local function prob_selected (weights, num_selected)
local n = tablelength(weights)
local final_probability = {}
for i=1, n do
final_probability[i] = 0
end
local globalArr = {}
generateBitstrings(globalArr, n + 1, {}, 1)
for ibots, bits in ipairs(globalArr) do
local num_generated = table.reduce(
bits,
function(a, b)
return a + b
end
)
local prob_generated = 1
local bitsLength = tablelength(bits)
for i=1,bitsLength do
if bits[i] == 1 then
prob_generated = prob_generated * weights[i]
else
prob_generated = prob_generated * (1 - weights[i])
end
end
for i=1,n do
if bits[i] == 1 then
final_probability[i] = final_probability[i] + (prob_generated * math.min(num_selected, num_generated) / num_generated)
end
end
end
return final_probability
end
print (table.concat (prob_selected({1, 1, 1, 1, 1,0.5, 0.2, 0.8, 0.9, 0.1}, 5), ', '))
推荐阅读
- html - 我可以重复表格标题并将页面分成列吗?
- python - 涉及 `format()` 点符号时的 Python 格式化
- reactjs - 边框颜色在 Tailwind CSS 中不起作用
- javascript - 使用 axios 使用客户端凭据获取 Keycloak access_token 在 NextJS 应用程序中引发 400 错误
- html - 如何从IE页面复制文本并将其粘贴到excel中
- project-reactor - 为助焊剂使用者注册不同的线程
- excel - Excel Sumifs 多个标准
- ios - 从模态呈现的视图控制器导航到根视图控制器
- sql - 选择子查询 - 基于值
- postgresql - VSCode 我忘记了我的 postgresql 扩展用户名和密码