Numpy: Reshaped array behaves unexpectedly

Problem description

I'm trying to reshape a numpy array and then reshape it back again, but I can't get the result I want. My data starts with shape (n_vertices, n_time, n_dimensions). I then transform it to shape (n_time, n_vertices * n_dimensions):

import numpy as np

X = np.load('dance.npy')

n_vertices, n_time, n_dims = X.shape    

X = X.reshape(n_time, n_vertices * n_dims)

By visualizing the data, I can see that the transformation above does not appear to distort the internal values:

import mpl_toolkits.mplot3d.axes3d as p3
from mpl_toolkits.mplot3d.art3d import juggle_axes
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
import matplotlib

matplotlib.rcParams['animation.embed_limit'] = 2**128

def update_points(time, points, df):
  points._offsets3d = juggle_axes(df[:,time,0], df[:,time,1], df[:,time,2], 'z')

def get_plot(df, lim=1, frames=200, duration=45, time_axis=1, reshape=False):
  if reshape: df = df.reshape(n_vertices, df.shape[time_axis], n_dims)
  fig = plt.figure()
  ax = fig.add_subplot(projection='3d')
  ax.set_xlim(-lim, lim)
  ax.set_ylim(-lim, lim)
  ax.set_zlim(-lim, lim)
  points = ax.scatter(df[:,0,0], df[:,0,1], df[:,0,2], depthshade=False) # x,y,z vals
  return animation.FuncAnimation(fig, update_points, frames, interval=duration, fargs=(points, df), blit=False ).to_jshtml()

HTML(get_plot(X, frames=200, time_axis=0, reshape=True))

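This shows the data in motion (the vertices are a dancer's body parts, and the visualization looks like a human body). All of this works fine. However, when I try to visualize only the first 10 time slices of the data, the resulting plot does not show the first few frames of the visualization above. In fact, the figure is not human-shaped at all: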

HTML(get_plot(X[:20], frames=10, time_axis=0, reshape=True))

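Can anyone help me understand why this slicing operation does not match the first few time frames of X? Any suggestions or observations would be very helpful.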

Tags: python, arrays, numpy, matplotlib, indexing

Solution


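It turns out that my reshape operations were not manipulating my array the way I thought they were. The following functions reshape my original array X into flattened form (with two axes) and then correctly restore it to unflattened form (with three axes). I added comments and tests to make sure everything behaves as expected: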

from math import floor

def flatten(df, run_tests=True):
  '''
  df is a numpy array with the following three axes:
    df.shape[0] = the index of a vertex
    df.shape[1] = the index of a time stamp
    df.shape[2] = the index of a dimension (x, y, z)
  So df[1][0][2] is the value for the 1st vertex (0-based) at time 0 in dimension 2 (z).
  To flatten this array means to push the data into shape:
    flattened.shape[0] = time index
    flattened.shape[1] = (vertex_index * 3) + dimension_index
  So flattened[1][3] is dimension 0 (x) of vertex 1 (0-based) at time 1.
  '''
  if run_tests:
    assert df.shape == X.shape and np.all(df == X)

  # reshape df such that flattened.shape = time, [x0, y0, z0, x1, y1, z1, ... xn-1, yn-1, zn-1]
  flattened = df.swapaxes(0, 1).reshape( (df.shape[1], df.shape[0] * df.shape[2]), order='C' )

  if run_tests: # switch to false to skip tests
    for idx, i in enumerate(df):
      for jdx, j in enumerate(df[idx]):
        for kdx, k in enumerate(df[idx][jdx]):
          assert flattened[jdx][ (idx*df.shape[2]) + kdx ] == df[idx][jdx][kdx]

  return flattened
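
The `swapaxes` call is the step the original reshape was missing. As a minimal sketch on a tiny made-up array (the sizes and the names `naive`, `by_time`, and `frame0` are illustrative assumptions, not part of the original data), a plain `reshape` only re-chunks the values in memory order, so its rows are not time frames. Moving the time axis to the front first gives one frame per row:

```python
import numpy as np

# Toy stand-in for the dance data: 2 vertices, 3 time steps, 2 dimensions.
n_vertices, n_time, n_dims = 2, 3, 2
X = np.arange(n_vertices * n_time * n_dims).reshape(n_vertices, n_time, n_dims)

# Plain reshape: re-chunks the flat memory order, so row 0 holds
# vertex 0 at times 0 and 1 rather than every vertex at time 0.
naive = X.reshape(n_time, n_vertices * n_dims)

# Swap the vertex and time axes first, then reshape: row t is frame t.
by_time = X.swapaxes(0, 1).reshape(n_time, n_vertices * n_dims)

# Reference: every vertex at time 0, laid out as x0, y0, x1, y1.
frame0 = X[:, 0, :].ravel()

print("naive row 0:  ", naive[0])
print("by_time row 0:", by_time[0])
print("frame 0:      ", frame0)
```

Only `by_time[0]` matches `frame0`, which is why slicing rows of the plainly reshaped array did not select time frames.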

And to unflatten the flattened data:

def unflatten(df, run_tests=True):
  '''
  df is a numpy array with the following two axes:
    df.shape[0] = time index
    df.shape[1] = [vertex_index*3] + dimension_vertex

  To unflatten this dataframe will mean to push the data into shape:
    unflattened.shape[0] = the index of a vertex
    unflattened.shape[1] = the index of a time stamp
    unflattened.shape[2] = the index of a dimension (x, y, z)

  So df[2][4] == unflattened[1][2][1]
  '''
  if run_tests:
    assert (len(df.shape) == 2) and (df.shape[1] == X.shape[0] * X.shape[2])

  unflattened = np.zeros(( X.shape[0], df.shape[0], X.shape[2] ))

  for idx, i in enumerate(df):
    for jdx, j in enumerate(df[idx]):
      kdx = floor(jdx / 3)
      ldx = jdx % 3
      unflattened[kdx][idx][ldx] = df[idx][jdx]

  if run_tests: # set to false to skip tests
    for idx, i in enumerate(unflattened):
      for jdx, j in enumerate(unflattened[idx]):
        for kdx, k in enumerate(unflattened[idx][jdx]):
          assert( unflattened[idx][jdx][kdx] == X[idx][jdx][kdx] )

  return unflattened
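
The element-by-element loops above can be slow on long recordings. As a hedged alternative, the same index mapping can likely be expressed with `reshape` and `swapaxes` alone. The helper names `flatten_fast` and `unflatten_fast` and the random test array are assumptions for illustration and are not part of the original answer:

```python
import numpy as np

def flatten_fast(arr):
    """(n_vertices, n_time, n_dims) -> (n_time, n_vertices * n_dims)."""
    n_vertices, n_time, n_dims = arr.shape
    # Put time first so each row of the result is one complete frame.
    return arr.swapaxes(0, 1).reshape(n_time, n_vertices * n_dims)

def unflatten_fast(flat, n_dims=3):
    """(n_time, n_vertices * n_dims) -> (n_vertices, n_time, n_dims)."""
    n_time = flat.shape[0]
    # Split each row back into (vertex, dimension), then restore axis order.
    return flat.reshape(n_time, -1, n_dims).swapaxes(0, 1)

# Hypothetical data standing in for dance.npy: 5 vertices, 7 frames, xyz.
rng = np.random.default_rng(0)
demo = rng.random((5, 7, 3))

flat_demo = flatten_fast(demo)
restored = unflatten_fast(flat_demo)
print(flat_demo.shape, restored.shape)
```

Because `unflatten_fast` reads the number of frames from its input, it should also accept a time-sliced array such as `flat_demo[:2]`.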

Then visualize:

import mpl_toolkits.mplot3d.axes3d as p3
from mpl_toolkits.mplot3d.art3d import juggle_axes
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib import animation
import matplotlib

# raise the size limit (in MB) for animations embedded as HTML
matplotlib.rcParams['animation.embed_limit'] = 2**128

def update_points(time, points, df):
  points._offsets3d = juggle_axes(df[:,time,0], df[:,time,1], df[:,time,2], 'z')

def get_plot(df, lim=1, frames=200, duration=45):
  if len(df.shape) == 2: df = unflatten(df)
  fig = plt.figure()
  ax = fig.add_subplot(projection='3d')
  ax.set_xlim(-lim, lim)
  ax.set_ylim(-lim, lim)
  ax.set_zlim(-lim, lim)
  points = ax.scatter(df[:,0,0], df[:,0,1], df[:,0,2], depthshade=False) # x,y,z vals
  return animation.FuncAnimation(fig,
    update_points,
    frames,
    interval=duration,
    fargs=(points, df),
    blit=False  
  ).to_jshtml()


This lets me slice the time axis without any problems:

flat = flatten(X)
unflat = unflatten(flat)

HTML(get_plot(unflat, frames=200))
HTML(get_plot(flat[:20], frames=20))
HTML(get_plot(unflat[:,:20,:], frames=20))
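
Why slicing behaves now, and why the full-length animation in the question looked fine, can be checked on a small made-up array (the shapes and variable names below are illustrative assumptions). A plain reshape there and back is the identity on the full array, which hides the problem. Once rows are sliced, only the frame-per-row layout still matches a slice of the original time axis:

```python
import numpy as np

n_vertices, n_time, n_dims, k = 4, 6, 3, 2
X = np.arange(n_vertices * n_time * n_dims, dtype=float)
X = X.reshape(n_vertices, n_time, n_dims)
expected = X[:, :k, :]  # first k frames of every vertex

# Layout from the answer: time axis first, one frame per row.
flat = X.swapaxes(0, 1).reshape(n_time, n_vertices * n_dims)
good = flat[:k].reshape(k, n_vertices, n_dims).swapaxes(0, 1)

# Layout from the question: plain reshape there and back.
naive = X.reshape(n_time, n_vertices * n_dims)
full_round_trip = naive.reshape(n_vertices, n_time, n_dims)
bad = naive[:k].reshape(n_vertices, k, n_dims)

print("full round trip intact:  ", np.array_equal(full_round_trip, X))
print("answer layout, sliced:   ", np.array_equal(good, expected))
print("question layout, sliced: ", np.array_equal(bad, expected))
```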
