python-3.x - 如何解决 JAX/Python 中的 ValueError `vector::reserve`?
问题描述
编辑:这里的 GitHub 问题:https ://github.com/google/jax/issues/5190
我正在尝试使用 jit 优化以下功能:
@partial(jit, static_argnums=(0, 1,))
def coocurrence_helper(pairs: np.array, label_map: Dict) -> lil_matrix:
uniques = lil_matrix(np.zeros((len(label_map), len(label_map))).astype("int32"))
for item in pairs:
if item[0]!=item[1]:
uniques[label_map[item[0]], label_map[item[1]]] += 1
return uniques
上面的例程在这里使用:
def _get_pairwise_frequencies(
data: pd.DataFrame, crosstab=False
) -> pd.DataFrame:
values = data.stack()
values.index = values.index.droplevel(1)
values.name = "vals"
values = optimize(values.to_frame())
pair = optimize(values.join(values, rsuffix="_2"))
label_map = dict()
for lbl, each in enumerate(values.vals.unique()):
label_map[each] = lbl
if not crosstab:
freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
return ((freq / freq.sum(1).ravel()).astype(np.float32))
else:
freq = pd.crosstab(pair["vals"], pair["vals_2"])
self.index = freq.index
return csr_matrix((freq / freq.sum(1)).astype(np.float32))
但我收到以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-42-f8e638fc2bb6> in <module>
----> 1 _get_pairwise_frequencies(data)
<ipython-input-30-43adeb39c76c> in _get_pairwise_frequencies(data, crosstab)
25 label_map[each] = lbl
26 if not crosstab:
---> 27 freq = coocurrence_helper(pairs = pair.values, label_map=label_map)
28 return csr_matrix((freq / freq.sum(1).ravel()).astype(np.float32))
29 else:
~/anaconda3/envs/tensorflow2_p36/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
369 return cache_miss(*args, **kwargs)[0] # probably won't return
370 else:
--> 371 return cpp_jitted_f(*args, **kwargs)
372 f_jitted._cpp_jitted_f = cpp_jitted_f
373
ValueError: vector::reserve
这里问题的根源是什么?不使用static_argnums
错误消息是
RuntimeError: Invalid argument: Unknown NumPy type O size 8
具有相同的回溯。
解决方案
问题是您返回的scipy.sparse.lil_matrix
不是有效的 JAX 类型。JAXjit
装饰器不能用作任意 Python 代码的编译器;它旨在优化 JAX 数组上的操作序列。
在这种情况下,最好的方法可能是@partial(jit, ...)
从你的函数中删除装饰器;如果你想在这里使用 JAX jit 编译,你首先必须重写你的代码以避免scipy.sparse
矩阵并使用 JAX 数组。
推荐阅读
- android - Proto DataStore - 嵌套类型
- wso2 - 布尔类型属性在 WSO2 EI 6.1.1 中不起作用
- python - 如何使用 python pyodbc 检查表中是否存在列?
- php - 如何在不符合 $callback 参数顺序的类方法上使用 array_walk
- c# - 使用 lambda c# 将数据分组返回列表
- ms-access - 在 MS Access、VBA 中验证未绑定文本控件中的输入
- python - Django - 将多行序列化为列表
- multithreading - 处理 std::thread c++ 中的泄漏
- post - 捕获带有乱码正文的 HTTP Post 请求
- javascript - 如何映射对象数组列表以显示在日历中?