python - 稍微编辑的 tensorflow 官方示例无法正常运行
问题描述
我正在通过分布式张量流研究 word2vec。出于兼容的原因,只需将官方 word2vec 稍微编辑为 Model 有点编码拱门。
代码片段如下:
def build():
self.global_step = tf.train.get_or_create_global_step()
with tf.variable_scope("weights", partitioner=partitioner):
self.embeddings = tf.get_variable(name="embeddings", shape=(self.vocab_size, self.embedding_size), initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
self.nce_weights = tf.get_variable(name="nce_weights", shape=(self.vocab_size, self.embedding_size), initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(self.embedding_size)))
self.bias = tf.get_variable(name="bias", shape=(self.vocab_size), initializer=tf.zeros_initializer())
self.embeded = tf.nn.embedding_lookup(self.embeddings, inputs, partition_strategy='div')
print("lables: ", self.labels)
self.loss = tf.reduce_mean(
tf.nn.nce_loss(
weights = self.nce_weights,
biases = self.bias,
labels = self.labels,
inputs = self.embeded,
num_sampled = self.num_sampled,
num_classes = self.vocab_size,
partition_strategy="div"
)
)
self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step)
# evaluate
normized = tf.sqrt(tf.reduce_sum(tf.square(self.embeddings), 1, keepdims=True))
normallized_embeddings = self.embeddings / normized
valid_data = np.r_[1:5]
self.valid_size = len(valid_data)
evaluate_examples = tf.constant(valid_data)
valid_embeddings = tf.nn.embedding_lookup(normallized_embeddings, evaluate_examples)
self.similarity = tf.matmul(valid_embeddings, normallized_embeddings, transpose_b=True)
火车方式:
def train(args):
loss, _, global_step, embs = session.run([self.loss, self.optimizer, self.global_step, self.embeddings])
print(embs)
训练:
def main():
model = Word2vec(args)
model.build() # call the method above to build the graph
tf.global_variables_initializer()
with tf.Session() as sess:
while num_step < upperboud:
model.train(sess)
我在训练的时候打印了评估结果,发现一直没有变化,但是nce_weights在变化。并且 global_step 和 local_step 正在增加。不知道哪里错了,谁能帮忙指出?谢谢
解决方案
推荐阅读
- python - sagemath:如何找到正确的导入库名称来编写要从 sage 加载的 python prog
- powershell - “Get-Help”不显示内容
- python-3.x - 查找重复行并将相应数据移动到与原始行相邻的位置
- number-formatting - 如何在 Apache Superset 中以千万和十万为单位显示数字
- google-cloud-platform - API 网关请求 API 限制 Google Cloud
- java - 如何将相机预览放在扩展 InputMethodService 的类中?
- r - 如何在 r 中使用“arfima”包来拟合 ARIMA(1,1,1) 和 ARMA(1,1)
- go - 使用 Go Get 安装后无法导入包?
- python - Python 随机库:从帕累托分布模拟(使用形状和比例参数)
- python - 如何为轮廓图扩展 matplotlib 颜色条?