这是我的矩阵分解代码。我正在尝试用 TPU 运行迭代函数中的部分计算,但得到了下面的回溯错误。update_xu 和 update_yi 主要是矩阵乘法和矩阵求逆,因此我尝试把它们放在 strategy.scope() 中运行。最终似乎出现了某种内存泄漏之类的问题。如果我犯了明显的错误,请见谅。

def loss(R, X, Y, C, lmda_x, lmda_y):
  """Return the weighted squared-error loss plus the L2 regularisation term.

  The residual is ``R - X^T Y`` (``tf.matmul`` with ``transpose_a=True``). It is
  squared element-wise, scaled element-wise by ``C``, summed, and divided by
  the number of entries of the residual (rows * columns), not by the sum of
  the weights.

  Args:
    R: 2-D tensor being approximated.
    X: factor tensor whose transpose is multiplied by ``Y``.
    Y: second factor tensor.
    C: weights multiplied element-wise into the squared residual.
    lmda_x: multiplier applied to the sum of squares of ``X``.
    lmda_y: multiplier applied to the sum of squares of ``Y``.

  Returns:
    Scalar tensor: regularisation term + weighted squared error / entry count.
  """
  residual = R - tf.matmul(X, Y, transpose_a=True)
  weighted_sq_error = C * (residual * residual)
  n_entries = residual.shape[0] * residual.shape[1]
  reg = tf.math.reduce_sum(X * X * lmda_x) + tf.math.reduce_sum(Y * Y * lmda_y)
  return reg + tf.math.reduce_sum(weighted_sq_error) / n_entries

def update_xu(Ru, Y, Cuser, lmda):
  """Return the updated factor vector of one user while Y is held constant.

  Computes ``(Y Cuser Y^T + lmda I)^-1 Y Cuser Ru``.

  Args:
    Ru: ratings of one user; reshaped here to a column vector of length
      ``Y.shape[1]``.
    Y: k x m matrix of item factors.
    Cuser: m x m weight matrix for this user.
    lmda: regularisation strength added on the diagonal before inversion.

  Returns:
    1-D tensor of length k (``Y.shape[0]``).

  Raises:
    tf.errors.InvalidArgumentError: ``tf.linalg.inv`` reports
      "Input is not invertible" when ``Y Cuser Y^T + lmda I`` is singular
      (for example, possibly when ``lmda`` is 0 and the weights are zero).
  """
  n_factors, n_items = Y.shape[0], Y.shape[1]
  ratings_col = tf.reshape(Ru, shape=[n_items, 1])
  weighted_ratings = Cuser @ ratings_col
  gram = Y @ Cuser @ tf.transpose(Y) + lmda * tf.eye(n_factors)
  solution = tf.linalg.inv(gram) @ Y @ weighted_ratings
  return tf.reshape(solution, [n_factors])

def update_yi(Ri, X, Citem, lmda):
  """Return the updated factor vector of one item while X is held constant.

  Computes ``(X Citem X^T + lmda I)^-1 X Citem Ri``.

  Args:
    Ri: ratings of one item; reshaped here to a column vector of length
      ``X.shape[1]``.
    X: k x n matrix of user factors.
    Citem: n x n weight matrix for this item.
    lmda: regularisation strength added on the diagonal before inversion.

  Returns:
    1-D tensor of length k (``X.shape[0]``).

  Raises:
    tf.errors.InvalidArgumentError: ``tf.linalg.inv`` reports
      "Input is not invertible" when ``X Citem X^T + lmda I`` is singular
      (for example, possibly when ``lmda`` is 0 and the weights are zero).
  """
  n_factors, n_users = X.shape[0], X.shape[1]
  ratings_col = tf.reshape(Ri, shape=[n_users, 1])
  weighted_ratings = Citem @ ratings_col
  gram = X @ Citem @ tf.transpose(X) + lmda * tf.eye(n_factors)
  solution = tf.linalg.inv(gram) @ X @ weighted_ratings
  return tf.reshape(solution, [n_factors])

def iterate(R, X, Y, C, lmda_x, lmda_y, epochs):
  """Run `epochs` alternating least-squares sweeps and return the updated X, Y.

  Each sweep first recomputes every column of X from the rows of R and C with Y
  fixed (``update_xu``), then recomputes every column of Y from the columns of
  R and C with the new X fixed (``update_yi``), so that R is approximated by
  ``X^T Y``.

  Args:
    R: 2-D tensor to factorise; its rows are mapped over for the X update.
    X: k x (rows of R) factor matrix; returned unchanged if ``epochs`` is 0.
    Y: k x (columns of R) factor matrix.
    C: weights with the same shape as R; each mapped row/column is turned into
      a diagonal weight matrix with ``tf.linalg.diag``.
    lmda_x: regularisation strength passed to ``update_xu``.
    lmda_y: regularisation strength passed to ``update_yi``.
    epochs: number of sweeps; passed to ``range``.

  Returns:
    Tuple ``(X, Y)`` after the last sweep.
  """
  # NOTE(review): entering strategy.scope() alone may not place these eager ops
  # on the TPU; the tf.distribute documentation describes strategy.run for
  # per-replica computation -- confirm this does what is intended.
  with strategy.scope():
    # R and C never change inside the loop, so transpose them once instead of
    # transposing back and forth on every sweep.
    R_t, C_t = tf.transpose(R), tf.transpose(C)
    for _ in range(epochs):
      Xtt = tf.vectorized_map(
          lambda x: update_xu(x[0], Y, tf.linalg.diag(x[1]), lmda_x), (R, C))
      X = tf.transpose(Xtt)
      Ytt = tf.vectorized_map(
          lambda x: update_yi(x[0], X, tf.linalg.diag(x[1]), lmda_y), (R_t, C_t))
      Y = tf.transpose(Ytt)
  return X, Y

def init_weights(R):
  """Build a random weight matrix with the same shape as R.

  Entries of R whose absolute value exceeds the module-level ``eps`` are
  treated as present. Each present entry gets the number of present entries in
  its column, multiplied by a standard-normal sample; absent entries get 0.
  A single standard-normal sample is then added to every entry.

  Args:
    R: 2-D tensor with a static shape.

  Returns:
    Tensor of the same shape as R.
  """
  # NOTE(review): the normal samples make some weights negative and the shared
  # offset can make them near zero; this may be related to the
  # "Input is not invertible" failure in the update steps -- confirm that
  # signed weights are intended.
  present = tf.where(tf.math.abs(R) > eps, 1., 0.)
  # keepdims gives shape [1, n_cols]; multiplying by `present` broadcasts the
  # per-column count onto every row and zeroes the absent entries.
  per_column_count = tf.math.reduce_sum(present, axis=0, keepdims=True)
  cnt = per_column_count * present
  cnt = tf.random.normal(R.shape, mean=0.0, stddev=1.0) * cnt
  cnt = cnt + tf.random.normal([1], mean=0.0, stddev=1.0)
  return cnt

def train(R, lmda_x, lmda_y, epochs, embd):
  """Factorise R in several rounds, plot the loss history, and return X, Y.

  X and Y are drawn from a standard normal before the first round; every round
  then continues from the factors of the previous one. After each round the
  loss on R and on the module-level ``test_data`` is recorded against the
  cumulative number of sweeps, and both curves are plotted at the end.

  Args:
    R: 2-D training tensor.
    lmda_x: regularisation strength for X.
    lmda_y: regularisation strength for Y.
    epochs: iterable of sweep counts, one per round (e.g. ``[5, 100]``).
    embd: number of latent factors (rows of X and Y).

  Returns:
    Tuple ``(X, Y)``; both are all-zero if ``epochs`` is empty.
  """
  loss_train, loss_test, total = 0., 0., 0
  loss_train_list, loss_test_list, total_epochs = [], [], []
  X, Y = tf.zeros([embd, R.shape[0]]), tf.zeros([embd, R.shape[1]])
  C = init_weights(R)
  first_round = True
  for epoch in epochs:
    if first_round:
      # Random start; later rounds must continue from the current factors.
      X = tf.random.normal([embd, R.shape[0]], mean=0.0, stddev=1.0)
      Y = tf.random.normal([embd, R.shape[1]], mean=0.0, stddev=1.0)
      first_round = False
    # Exactly one call per round (the original ran the first round twice).
    X, Y = iterate(R, X, Y, C, lmda_x, lmda_y, epoch)
    total += int(epoch)
    loss_train = loss(R, X, Y, C, lmda_x, lmda_y)
    # test_data is read from module scope; it is not a parameter of train.
    loss_test = loss(test_data, X, Y, C, lmda_x, lmda_y)
    # Record the history so the plot below actually has points to draw.
    total_epochs.append(total)
    loss_train_list.append(float(loss_train))
    loss_test_list.append(float(loss_test))
  print(loss_train, loss_test)
  plt.plot(total_epochs, loss_train_list, label = "Training", linewidth = 5)
  plt.plot(total_epochs, loss_test_list, label = "Test", linewidth = 1)
  plt.xticks(fontsize = 10)
  plt.title(str(embd)+ ', ' +str(lmda_x)  + ',' + str(lmda_y))
  plt.yticks(fontsize = 10)
  plt.xlim(0, 200)
  plt.ylim(0, 1000000000)
  plt.xlabel('iterations', fontsize=30)
  plt.ylabel('MSE', fontsize=30)
  plt.legend(loc='best', fontsize=20)
  return X, Y

def grid_search(epochs, embds, lmdas_x, lmdas_y, train_data, test_data):
  """Train one model for every (lmda_x, lmda_y, embd) combination.

  Each hyper-parameter is converted to a tensor before being passed to
  ``train``, and a new figure is opened per combination so the loss curves
  drawn by ``train`` do not pile up on the same axes.

  Args:
    epochs: sequence of sweep counts per training round, passed to ``train``.
    embds: iterable of latent-factor counts to try.
    lmdas_x: iterable of regularisation strengths for X.
    lmdas_y: iterable of regularisation strengths for Y.
    train_data: 2-D training tensor passed to ``train`` as ``R``.
    test_data: currently unused here; ``train`` reads the module-level
      ``test_data`` instead.

  Returns:
    None. The factors returned by ``train`` are discarded.
  """
  # Convert under separate names: rebinding `epochs` or the loop variables
  # inside the innermost loop hides which value each iteration really uses.
  epochs_t = tf.convert_to_tensor(epochs, dtype=tf.int64)
  for lmda_x in lmdas_x:
    lmda_x_t = tf.convert_to_tensor(lmda_x, dtype=tf.float32)
    for lmda_y in lmdas_y:
      lmda_y_t = tf.convert_to_tensor(lmda_y, dtype=tf.float32)
      for embd in embds:
        embd_t = tf.convert_to_tensor(embd, dtype=tf.int64)
        plt.figure(figsize=(10, 10))
        train(train_data, lmda_x_t, lmda_y_t, epochs_t, embd_t)


WARNING:tensorflow:11 out of the last 11 calls to <function pfor.<locals>.f at 0x7f4dd38f7b90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function pfor.<locals>.f at 0x7f4dd38f7b90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
(943, 1)
WARNING:tensorflow:11 out of the last 11 calls to <function pfor.<locals>.f at 0x7f4dd38f7560> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:11 out of the last 11 calls to <function pfor.<locals>.f at 0x7f4dd38f7560> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
(1682, 1)
(943, 1)
Object was never used (type <class 'tensorflow.python.ops.tensor_array_ops.TensorArray'>):
<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7f4e018c73d0>
If you want to mark it as used call its "mark_used()" method.
It was originally created here:
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2778, in while_loop
    return result  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2726, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/map_fn.py", line 507, in compute
    return (i + 1, tas)  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/map_fn.py", line 505, in <listcomp>
    ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/tf_should_use.py", line 249, in wrapped
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-58-a2922620d8dd> in <module>()
      7 epochs = [5, 100] #2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]#, 588, 299, 300, 200]
      8 #X, Y, C = train(train_data.shape[0], train_data.shape[1], train_data, 0, 0, 2, 2)
----> 9 grid_search(epochs, embds, lmdas, lmdas, train_data, test_data)

3 frames
<ipython-input-57-8cc0dfdb85de> in grid_search(epochs, embds, lmdas_x, lmdas_y, train_data, test_data)
      5           plt.figure(figsize = (10,10))
      6           lmda_x , lmda_y , epochs, embd = tf.convert_to_tensor(lmda_x, dtype = tf.float32), tf.convert_to_tensor(lmda_y, dtype =tf.float32), tf.convert_to_tensor(epochs, dtype = tf.int64), tf.convert_to_tensor(embd, dtype =tf.int64)
----> 7           X, Y = train(train_data, lmda_x, lmda_y, epochs, embd)

<ipython-input-56-d70632663530> in train(R, lmda_x, lmda_y, epochs, embd)
     88       flag = True
     89     else:
---> 90       X, Y = iterate(train_data, X, Y, C, lmda_x, lmda_y, epoch)
     91     total += epoch
     92     loss_train = loss(train_data, X, Y, C, lmda_x, lmda_y)

<ipython-input-56-d70632663530> in iterate(R, X, Y, C, lmda_x, lmda_y, epochs)
     50       Xtt = tf.vectorized_map(lambda x: update_xu(x[0], Y, tf.linalg.diag(x[1]), lmda_x), (R, C))
     51       #Xtt = tf.map_fn(lambda x: update_xu(x[0], Y, tf.linalg.diag(x[1]), lmda_x), (R, C), dtype = tf.TensorSpec([Y.shape[0]], dtype = tf.float32), parallel_iterations=6)
---> 52       print(Xtt.shape)
     53       X = tf.transpose(Xtt)
     54       R, C = tf.transpose(R), tf.transpose(C)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in shape(self)
   1173         # `_tensor_shape` is declared and defined in the definition of
   1174         # `EagerTensor`, in C.
-> 1175         self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
   1176       except core._NotOkStatusException as e:
   1177         six.raise_from(core._status_to_exception(e.code, e.message), None)

InvalidArgumentError: {{function_node __inference_f_5094764}} Input is not invertible.
     [[{{node loop_body/MatrixInverse/pfor/MatrixInverse}}]]

