python - Tensorflow:如何在约束下最小化
问题描述
我正面临一个受约束、等式和不等式约束的数值优化问题。看起来这项任务的一切都在 tensorflow 中,阅读诸如https://www.tensorflow.org/api_docs/python/tf/contrib/constrained_optimization之类的文档。
虽然我错过了一个最小的工作示例。我进行了广泛的谷歌搜索,但没有任何结果。谁能和我分享一些有用的资源?最好在急切模式下运行。
编辑:
我现在找到了https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/constrained_optimization
我仍然欢迎任何其他资源。
解决方案
您可以使用适用于 TF > 1.4 的TFCO 。
这是我们要最小化的具体示例:
(x - 2) ^ 2 + y
英石
- x + y = 1
- x > 0
- y > 0
import tensorflow as tf
# Use the GitHub version of TFCO
# !pip install git+https://github.com/google-research/tensorflow_constrained_optimization
import tensorflow_constrained_optimization as tfco
class SampleProblem(tfco.ConstrainedMinimizationProblem):
def __init__(self, loss_fn, weights):
self._loss_fn = loss_fn
self._weights = weights
@property
def num_constraints(self):
return 4
def objective(self):
return loss_fn()
def constraints(self):
x, y = self._weights
sum_weights = x + y
lt_or_eq_one = sum_weights - 1
gt_or_eq_one = 1 - sum_weights
constraints = tf.stack([lt_or_eq_one, gt_or_eq_one, -x, -y])
return constraints
x = tf.Variable(0.0, dtype=tf.float32, name='x')
y = tf.Variable(0.0, dtype=tf.float32, name='y')
def loss_fn():
return (x - 2) ** 2 + y
problem = SampleProblem(loss_fn, [x, y])
optimizer = tfco.LagrangianOptimizer(
optimizer=tf.optimizers.Adagrad(learning_rate=0.1),
num_constraints=problem.num_constraints
)
var_list = [x, y] + problem.trainable_variables + optimizer.trainable_variables()
for i in range(10000):
optimizer.minimize(problem, var_list=var_list)
if i % 1000 == 0:
print(f'step = {i}')
print(f'loss = {loss_fn()}')
print(f'constraint = {(x + y).numpy()}')
print(f'x = {x.numpy()}, y = {y.numpy()}')
推荐阅读
- makefile - 如何在 makefile 中使用 g++ 编译位于不同目录中的源文件和头文件?
- visual-studio - 如何使代码抑制消息始终以英文显示?
- mathematical-optimization - N个产品与M个特征的最优组合
- angular - 为什么我在使用 Microsoft Graph REST API 使用 Angular 调用此端点时会收到 403 禁止响应?
- r - R,data.table - 为多个表创建新列
- mpesa - Safaricom 仪表板未显示所有测试凭据
- java - 执行 SonarScanner.MSBuild.exe end 命令后执行 SonarScanner 期间出错
- javascript - Javascript如何生成时隙的动态列表
- javascript - 有没有办法在 for 中定义一个钩子?
- grafana - InfluxQL 数据源连接错误“xxxxxx....” 找不到数据库