z3 - Z3Py:如何以最有效的方式实现和约束位向量 LUT 的输入
问题描述
我尝试使用 Z3Py 作为 SMT Solver 为具有内部状态条件的 DES 类型算法生成密码/密钥/明文对。求解器运行了很长时间,还没有找到解决方案。如何改进代码以显着加快执行时间?
让我先解释一下算法、代码和问题。
应该用 SMT 求解器建模的算法是一个 16 轮的 DES 类型加密算法,它由排列、位移、异或和替换框组成。最后一个函数实际上是一小部分密码的查找表,它将 6 位输入映射到 4 位输出。每轮包括从 30 位输入向量中提取 5 个 6 位向量,应用 LUT,并将 5 个 4 位向量连接到 20 位。任务是在每一轮中生成仅使用这些 LUT 中的某些 LUT 的密码/密钥/明文对。
下面是一个完整的可执行示例,它删除了所有不重要的位操作以专注于问题陈述。
棘手的部分(我不确定的地方)是 LUT 部分,它应该是一个查找表。LUT 被实现为一个函数:
LUT = [Function('LUT_%s' %i, BitVecSort(6), BitVecSort(4)) for i in range(5)]
无效条目在我的python列表中用-1标记,所有不是-1的值都添加到函数中,如代码示例所示# defining the LUT mapping
对于每一轮 k,计算每个单独 LUT 的输入。添加了一个约束,因此 LUT 输入应该不等于所有无效条目。这在 下的代码示例中进行了描述#Adding Invalid LUT inputs to constraints
。求解器有大约 2560 个这样的约束。
M
问题:求解器在合理的时间内执行并找到 10-11 轮的配对(在下面的代码示例中)。对于 12 轮,求解器无法在 2 小时内找到结果。该示例的更详细分析见表。
米 | 时间(秒) |
---|---|
6 | 0.3 |
7 | 0.8 |
8 | 3 |
9 | 4.6 |
10 | 66 |
11 | 177 |
问题:是否有更有效的方式来描述问题陈述,例如 LUT 的实现方式或约束的描述方式?
谢谢你的支持。
import time
from z3 import *
# Crypto algo LUT and taps declaration
LUTn = [[14, -1, 13, 1, -1, -1, 11, 8, 3, 10, -1, 12, 5, -1, 0, 7, 0, 15, 7, 4, -1, 2, 13, 1, -1, 6, 12, 11, -1, -1, -1, 8, -1, -1, 14, -1, -1, 6, -1, -1, 15, -1, 9, -1, 3, -1, -1, -1, 15, 12, -1, -1, -1, -1, 1, 7, 5, 11, 3, 14, -1, 0, -1, -1],
[15, -1, 8, 14, -1, 11, 3, 4, 9, 7, 2, 13, -1, 0, 5, 10, 3, -1, -1, 7, 15, -1, -1, -1, 120, 1, 10, 6, 9, -1, 5, -1, 0, 14, -1, -1, -1, 4, -1, 1, 5, -1, -1, -1, -1, -1, -1, 15, 13, 8, 10, -1, 3, 15, 4, -1, -1, -1, -1, -1, -1, 5, -1, -1],
[10, -1, 9, 14, 6, 3, 15, 5, 1, 13, 12, -1, 11, 4, -1, 8, -1, -1, 0, 9, -1, 4, 6, -1, -1, -1, 5, -1, 12, 11, 15, 1, 13, -1, 4, 9, 8, -1, 3, 0, -1, -1, 2, -1, -1, 10, -1, 7, 1, -1, -1, -1, 6, 9, 8, 7, 4, -1, 14, 3, 11, 5, 2, 12],
[7, -1, -1, -1, -1, 6, 9, -1, 1, -1, -1, 5, -1, 12, 4, -1, 13, 8, -1, -1, 6, -1, -1, 3, 4, 7, 2, 12, 1, 10, -1, 9, -1, 6, 9, 0, 12, 11, 7, 13, -1, 1, 3, 14, 5, -1, 8, 4, 3, 15, -1, 6, -1, 1, -1, 8, 9, -1, -1, -1, -1, 7, 2, 14],
[-1, -1, 4, 1, -1, 10, 11, 6, 8, -1, 3, -1, 13, -1, -1, 9, -1, 11, 2, 12, -1, 7, 13, 1, -1, 0, 15, 10, 3, 9, -1, 6, 4, -1, 1, 11, -1, -1, 7, 8, -1, 9, -1, 5, 6, 3, -1, 14, -1, 8, 12, -1, 1, 14, 2, 13, -1, -1, 0, 9, 10, 4, 5, -1]]
m_rounds = 16
# Z3 Declarations
key = [BitVec("key_%s" %i ,30) for i in range(m_rounds+2)]
cipher_right = [BitVec("cipher_right_%s" % i,20) for i in range(m_rounds+3)]
cipher_left = [BitVec("cipher_left_%s" % i,20) for i in range(m_rounds+2)]
cipher_expanded = [BitVec("cipher_expanded_%s" % i,30) for i in range(m_rounds+1)]
lut_in = [BitVec("lut_in_%s" % i,30) for i in range(m_rounds+1)]
lut_out = [BitVec("lut_out_%s" % i,20) for i in range(m_rounds+1)]
lutX_in = [[BitVec("lutX_in_%s_%s" % (i, j),6) for i in range(m_rounds+1)] for j in range (5)]
# Uninterpreted function to model the LUT
LUT = [Function('LUT_%s' %i, BitVecSort(6), BitVecSort(4)) for i in range(5)]
# Adding constraints to the solver
slv = Solver()
# defining the LUT mapping
for k in range(5):
for lut_in_s, lut_val_s in enumerate(LUTn[k]):
if(lut_val_s != -1):
slv.add(LUT[k](lut_in_s) == lut_val_s)
# Constraining the M rounds of the algorithm
M = 11
for rnd in range(m_rounds, m_rounds-M, -1):
# Adding partial steps of cryptographic algorithm necessary for minimal example
slv.add(key[rnd] == Concat(Extract(0,0,key[rnd+1]),Extract(28,0,LShR(key[rnd+1],1))))
slv.add(cipher_expanded[rnd] == Concat(Extract(9,0,cipher_right[rnd]),cipher_right[rnd]))
slv.add(lut_in[rnd] == key[rnd]^cipher_expanded[rnd])
for k in range(5):
slv.add(lutX_in[k][rnd] == Concat(Extract(21+2*k, 20+2*k ,lut_in[rnd]),Extract(3+4*k, 4*k, lut_in[rnd])))
slv.add(cipher_right[rnd-1] == cipher_left[rnd]^lut_out[rnd])
slv.add(cipher_left[rnd-1] == cipher_right[rnd])
# Adding LUT evaluation
slv.add(lut_out[rnd] == Concat([LUT[k](lutX_in[k][rnd]) for k in range(5)]))
# Adding Invalid LUT inputs to constraints
for k in range(5):
lut_inval_in = [lut_in_s for lut_in_s, lut_val_s in enumerate(LUTn[k]) if lut_val_s == -1]
for lut_inval_in_s in lut_inval_in:
slv.add(lutX_in[k][rnd] != lut_inval_in_s)
t0 = time.process_time()
if(slv.check() == sat):
print('Solution found')
else:
print('No Solution found')
elapsed_time = time.process_time() - t0
print("Elapsed Time: " + str(elapsed_time) + " s")
解决方案
推荐阅读
- javascript - 移动到其他页面时停止 setTimeout
- nginx - 使用托管在 Google Kubernetes Engine 上的 Nginx RTMP 服务器的 HTTP 实时流式传输 (HLS)
- python - 优化 json.load() 以减少 Python 中的内存使用和时间
- linux - Alpine Linux 映像上的 maven 和 jdk8
- sql - 怎么按case和order呢?
- sql - 如何构造复杂的 django 查询语句?
- multithreading - 既然 U32 已经实现了同步,为什么还要在 Rust 中使用 AtomicU32?
- python - 具有离散数据的高斯混合模型
- c# - 使用 OneDrive API 的正确方法是什么?
- sql - 使用 CTE 计算累积和