首页 > 解决方案 > 在 Heteroskedastic gpflow 示例(多潜在 GP)中使用 VGP 而不是 SVGP 时出错

问题描述

我试图理解为什么在 GPflow 中的异方差回归示例 ( https://gpflow.readthedocs.io/en/develop/notebooks/advanced/heteroskedastic.html ) 中尝试用 VGP 替换 SVGP 时出现 ValueError。

以下是我所做的更改:

核和似然与示例相同。


# Reproduction script: build a VGP model (instead of the SVGP used in the
# heteroskedastic example) and train it with alternating NaturalGradient and
# Adam steps. X, Y, kernel and likelihood are defined outside this snippet.
data = (X, Y)
model = gpf.models.VGP(
    # The training data is passed to the model at construction time.
    data = data,
    kernel=kernel,
    likelihood=likelihood,
    # The SVGP-only argument from the original example is left disabled here.
    #inducing_variable=inducing_variable,
    # One latent GP per latent dimension reported by the likelihood.
    num_latent_gps=likelihood.latent_dim,
)

# Loss closure shared by both optimisers below.
loss_fn = model.training_loss_closure()

# Mark the variational parameters as non-trainable before collecting
# model.trainable_variables for Adam; they are handed to NaturalGradient
# instead (presumably so the two optimisers do not update the same variables).
gpf.utilities.set_trainable(model.q_mu, False)
gpf.utilities.set_trainable(model.q_sqrt, False)

# NaturalGradient is given a list of (q_mu, q_sqrt) pairs.
variational_vars = [(model.q_mu, model.q_sqrt)]
natgrad_opt = gpf.optimizers.NaturalGradient(gamma=0.1)

# Read after the set_trainable calls above, so the order of these statements
# matters.
adam_vars = model.trainable_variables
adam_opt = tf.optimizers.Adam(0.01)

@tf.function
def optimisation_step():
    """Run one NaturalGradient minimise call, then one Adam minimise call.

    Both calls minimise the same ``loss_fn``; NaturalGradient is applied to
    ``variational_vars`` and Adam to ``adam_vars``.
    """
    natgrad_opt.minimize(loss_fn, variational_vars)
    adam_opt.minimize(loss_fn, adam_vars)

# Run the combined step a fixed number of times.
# NOTE(review): the traceback accompanying this snippet shows the first call
# to optimisation_step() failing while tf.function traces the function.
epochs = 100
for epoch in range(0, epochs):
    optimisation_step()

执行优化步骤时出现了如下错误:


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_406484/3662007586.py in <module>
      3 
      4 for epoch in range(1, epochs + 1):
----> 5     optimisation_step()
      6 
      7 

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    887 
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
    890 
    891       new_tracing_count = self.experimental_get_tracing_count()

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    761     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    762     self._concrete_stateful_fn = (
--> 763         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    764             *args, **kwds))
    765 

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3048       args, kwargs = None, None
   3049     with self._lock:
-> 3050       graph_function, _ = self._maybe_define_function(args, kwargs)
   3051     return graph_function
   3052 

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3442 
   3443           self._function_cache.missed.add(call_context_key)
-> 3444           graph_function = self._create_graph_function(args, kwargs)
   3445           self._function_cache.primary[cache_key] = graph_function
   3446 

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3277     arg_names = base_arg_names + missing_arg_names
   3278     graph_function = ConcreteFunction(
-> 3279         func_graph_module.func_graph_from_py_func(
   3280             self._name,
   3281             self._python_function,

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    997         _, original_func = tf_decorator.unwrap(python_func)
    998 
--> 999       func_outputs = python_func(*func_args, **func_kwargs)
   1000 
   1001       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    670         # the function a weak reference to itself to avoid a reference cycle.
    671         with OptionalXlaContext(compile_with_xla):
--> 672           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    673         return out
    674 

~/miniconda3/envs/tensorflow/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    984           except Exception as e:  # pylint:disable=broad-except
    985             if hasattr(e, "ag_error_metadata"):
--> 986               raise e.ag_error_metadata.to_exception(e)
    987             else:
    988               raise

ValueError: in user code:


    ValueError: Dimensions must be equal, but are 2 and 1001 for '{{node add_2}} = AddV2[T=DT_DOUBLE](diag, mul_1)' with input shapes: [1001,2,2], [1001,1001].

这是一个 bug,还是似然(likelihood)与模型不兼容,或者是我遗漏了什么?一种变通方法是把诱导变量设为训练数据并使用 SVGP,但这会使训练变得异常缓慢……

标签: python tensorflow gpflow

解决方案


推荐阅读