python - 在 GPflow 异方差(Heteroskedastic)示例(多潜在 GP)中用 VGP 代替 SVGP 时出错
问题描述
我想弄清楚:在 GPflow 的异方差回归示例 ( https://gpflow.readthedocs.io/en/develop/notebooks/advanced/heteroskedastic.html ) 中,为什么把 SVGP 替换为 VGP 后会出现 ValueError。
以下是我所做的更改:
- 使用 `model = gpf.models.VGP(...)`(原示例为 `gpf.models.SVGP(...)`)
- 使用 `loss_fn = model.training_loss_closure()`,而不是 `loss_fn = model.training_loss_closure(data)`(因为数据已在构造 VGP 时传入)
核函数和似然函数与示例中的相同。
# Reproduction: the GPflow heteroskedastic-regression example with SVGP
# swapped for VGP. X, Y, kernel and likelihood are assumed to be defined
# exactly as in that example (not shown here) -- TODO confirm.
data = (X, Y)
model = gpf.models.VGP(
    data=data,
    kernel=kernel,
    likelihood=likelihood,
    # inducing_variable=inducing_variable,  # SVGP-only argument, not passed to VGP
    num_latent_gps=likelihood.latent_dim,
)
# The data was handed to the VGP constructor above, so the closure is built
# without arguments (the SVGP version uses model.training_loss_closure(data)).
loss_fn = model.training_loss_closure()

# q_mu / q_sqrt are updated by the natural-gradient optimizer only, so mark
# them non-trainable to keep them out of Adam's variable list.
gpf.utilities.set_trainable(model.q_mu, False)
gpf.utilities.set_trainable(model.q_sqrt, False)
variational_vars = [(model.q_mu, model.q_sqrt)]
natgrad_opt = gpf.optimizers.NaturalGradient(gamma=0.1)

# Collected after the set_trainable(..., False) calls, so q_mu/q_sqrt are excluded.
adam_vars = model.trainable_variables
adam_opt = tf.optimizers.Adam(0.01)


@tf.function
def optimisation_step():
    """Run one natural-gradient step on (q_mu, q_sqrt), then one Adam step."""
    natgrad_opt.minimize(loss_fn, variational_vars)
    adam_opt.minimize(loss_fn, adam_vars)


epochs = 100
for epoch in range(0, epochs):
    # NOTE(review): the first call traces the tf.function; that trace is where
    # the ValueError is raised (AddV2 of shapes [1001,2,2] and [1001,1001]).
    # Presumably the example's kernel is multi-output, so K(X) comes back as a
    # per-point [N, P, P] block while VGP adds an [N, N] jitter matrix --
    # verify against the GPflow VGP source.
    optimisation_step()
运行优化步骤时出现以下错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_406484/3662007586.py in <module>
3
4 for epoch in range(1, epochs + 1):
----> 5 optimisation_step()
6
7
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
887
888 with OptionalXlaContext(self._jit_compile):
--> 889 result = self._call(*args, **kwds)
890
891 new_tracing_count = self.experimental_get_tracing_count()
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
931 # This is the first call of __call__, so we have to initialize.
932 initializers = []
--> 933 self._initialize(args, kwds, add_initializers_to=initializers)
934 finally:
935 # At this point we know that the initialization is complete (or less
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
761 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
762 self._concrete_stateful_fn = (
--> 763 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
764 *args, **kwds))
765
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
3048 args, kwargs = None, None
3049 with self._lock:
-> 3050 graph_function, _ = self._maybe_define_function(args, kwargs)
3051 return graph_function
3052
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
673 return out
674
~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
984 except Exception as e: # pylint:disable=broad-except
985 if hasattr(e, "ag_error_metadata"):
--> 986 raise e.ag_error_metadata.to_exception(e)
987 else:
988 raise
ValueError: in user code:
ValueError: Dimensions must be equal, but are 2 and 1001 for '{{node add_2}} = AddV2[T=DT_DOUBLE](diag, mul_1)' with input shapes: [1001,2,2], [1001,1001].
这是一个 bug,还是该似然与 VGP 模型不兼容,又或者是我遗漏了什么?一种变通方法是令诱导变量等于训练数据并使用 SVGP,但这会让训练变得异常缓慢……
解决方案
推荐阅读
- javascript - 哪种方式最适合 React Native 中的条件渲染?
- python - 尝试在 Windows10 上使用 Python 创建 Virtualenv 时出现 AssertionError
- c++ - C++ 对泛型类成员的未定义引用
- algorithm - 计算给定代码的时间复杂度的问题
- android - Searchview 过滤器不适用于包含已解析 JSON 的列表,但适用于非 JSON 列表
- php - PDO:使用子字符串作为标准来获取行数
- python - 在 EMR 中运行 Jupyter notebook 时没有名为“pyspark”的模块
- python - 使用 Python 将 Excel DATE(不是日期时间)数据插入 SQL Server
- java - Maven发布插件配置及示例
- java - 如何在 swagger 模型中创建嵌套对象