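python-3.x - Applying an operation to specific columns of a numpy array
Problem description
I want to apply feature normalisation to a numpy array. Normally this is trivial with NumPy broadcasting; for example, one would do something like this: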
train_mean = train.mean(axis=0)
train_std = train.std(axis=0)
train = (train - train_mean) / train_std
val = (val - train_mean) / train_std
test = (test - train_mean) / train_std
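The broadcasting above works because `train_mean` and `train_std` have shape `(n_features,)`, which NumPy stretches across every row of an `(n_samples, n_features)` array. Here is a minimal sketch that checks this on made-up random data. The sizes and seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: 10 samples, 9 features
train = rng.uniform(low=0, high=100, size=(10, 9))

train_mean = train.mean(axis=0)  # shape (9,)
train_std = train.std(axis=0)    # shape (9,), population std (ddof=0)

# (10, 9) - (9,) broadcasts the per-column statistics over every row
train_norm = (train - train_mean) / train_std

print(train_mean.shape, train_norm.shape)  # (9,) (10, 9)
# Every column now has mean ~0 and std ~1 (up to floating-point error)
print(np.allclose(train_norm.mean(axis=0), 0), np.allclose(train_norm.std(axis=0), 1))  # True True
```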
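However, my numpy array has 9 columns, so train_mean and train_std both have shape (9,).
I only want to apply the normalisation to specific columns of the array, and I have their indices in a dictionary: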
column_indices
{'blind angle': 0,
'fully open': 1,
'ibn': 2,
'idh': 3,
'altitude': 4,
'azimuth_sin': 5,
'azimuth_cos': 6,
'dgp': 7,
'ill': 8}
I have listed the columns I want to normalise:
FEATURE_NORM_COLS = ['blind angle', 'ibn', 'idh', 'altitude']
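I only want to normalise these columns, using their indices together with the matching entries of train_mean and train_std (which are indexed the same way as my data).
What is the best way to do this?
I did the following, which seems to give the desired result, but it looks cumbersome. Is there a better way to do this?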
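If you map the names to integer positions once, NumPy's integer-array ("fancy") indexing can pick the same columns from both the data and the per-column statistics. The sketch below uses the dictionary and list above. The `train_mean` values are placeholders, not real statistics.

```python
import numpy as np

column_indices = {'blind angle': 0, 'fully open': 1, 'ibn': 2, 'idh': 3,
                  'altitude': 4, 'azimuth_sin': 5, 'azimuth_cos': 6,
                  'dgp': 7, 'ill': 8}
FEATURE_NORM_COLS = ['blind angle', 'ibn', 'idh', 'altitude']

# Translate column names to positions, keeping the order of FEATURE_NORM_COLS
indices = [column_indices[name] for name in FEATURE_NORM_COLS]
print(indices)  # [0, 2, 3, 4]

# Placeholder statistics: one value per column, as mean(axis=0) would return
train_mean = np.arange(9, dtype=float) * 10
# Selects the entries for columns 0, 2, 3 and 4, i.e. 0.0, 20.0, 30.0, 40.0
print(train_mean[indices])
```

This only works if `train_mean` is a NumPy array. A plain Python list cannot be indexed with a list of positions and raises a `TypeError`.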
for name in FEATURE_NORM_COLS:
    train[:, column_indices[name]] = (train[:, column_indices[name]] - train_mean[column_indices[name]]) / train_std[column_indices[name]]
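The loop can be collapsed into a single assignment. Reading `arr[:, indices]` returns a copy, but assigning to `arr[:, indices] = ...` writes back into the selected columns in place. Here is a minimal check that the two forms agree. The shape, seed and index list are assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)
data = rng.uniform(0, 100, size=(10, 9))
indices = [0, 2, 3, 4]  # positions of the columns to normalise (assumed)

mean = data.mean(axis=0)
std = data.std(axis=0)

# Version 1: per-column loop, as in the question
looped = data.copy()
for i in indices:
    looped[:, i] = (looped[:, i] - mean[i]) / std[i]

# Version 2: one fancy-indexed assignment over all selected columns
vectorised = data.copy()
vectorised[:, indices] = (vectorised[:, indices] - mean[indices]) / std[indices]

print(np.allclose(looped, vectorised))  # True
```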
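Update
I followed an approach similar to the one suggested in the comments. I think it is more elegant, and it avoids looping over every column of the dataset.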
def normalise(dataset, col_indices=COLUMN_INDICES, norm_cols=NORM_COLS,
              train_mean=TRAIN_MEAN, train_std=TRAIN_STD):
    """
    Returns normalised features with a mean of zero and a std of 1.

    The formula is (train - train_mean) / train_std, but we index by
    column indices since we don't want to normalise all columns.
    Note: dataset is modified in place.

    Args:
        dataset: numpy array to normalise
        col_indices -> dict: the indices of cols in dataset
        norm_cols -> list: columns to be normalised
        train_mean -> np.ndarray: means of train set columns
        train_std -> np.ndarray: stds of train set columns
    """
    # map column names to integer positions
    indices = [col_indices[col] for col in norm_cols]
    # fancy-indexed assignment writes the normalised columns back in place
    dataset[:, indices] = (dataset[:, indices] - train_mean[indices]) / train_std[indices]
    return dataset
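There are two things to watch with `normalise` as written. First, it modifies `dataset` in place. Second, if `dataset` has an integer dtype, the normalised values are truncated when they are written back. The sketch below is a hypothetical non-mutating variant, `normalise_copy`, that avoids both by working on a float copy. Its name and signature are illustrative and do not come from the original.

```python
import numpy as np

def normalise_copy(dataset, indices, train_mean, train_std):
    """Return a float copy of `dataset` with only the `indices` columns standardised."""
    out = np.array(dataset, dtype=float)        # copy, so the caller's array is untouched
    mean = np.asarray(train_mean, dtype=float)  # accept lists as well as arrays
    std = np.asarray(train_std, dtype=float)
    out[:, indices] = (out[:, indices] - mean[indices]) / std[indices]
    return out

# Usage on a small integer array: column 0 is standardised, column 1 is left as-is
x = np.array([[1, 10], [3, 20], [5, 30]])
result = normalise_copy(x, [0], x.mean(axis=0), x.std(axis=0))
print(result[:, 0])
print(x[0, 0])  # 1 (input unchanged)
```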
Solution
Would this be a better approach?
import numpy as np
from sklearn import preprocessing
np.set_printoptions(suppress=True, linewidth=1000, precision=3)
np.random.seed(5)
train = np.array([np.random.uniform(low=0, high=100, size=10),
np.random.uniform(low=0, high=30, size=10),
np.random.uniform(low=0, high=70, size=10),
np.random.uniform(low=0, high=20, size=10),
np.random.uniform(low=0, high=90, size=10),
np.random.uniform(low=0, high=50, size=10),
np.random.uniform(low=0, high=30, size=10),
np.random.uniform(low=0, high=80, size=10),
np.random.uniform(low=0, high=90, size=10)]).T
column_indices = {'blind angle': 0,
'fully open': 1,
'ibn': 2,
'idh': 3,
'altitude': 4,
'azimuth_sin': 5,
'azimuth_cos': 6,
'dgp': 7,
'ill': 8}
FEATURE_NORM_COLS = ['blind angle', 'ibn', 'idh', 'altitude']
indices = [column_indices[c] for c in FEATURE_NORM_COLS]
print('TRAIN\n', train, '\n')
scaler = preprocessing.StandardScaler().fit(train)
train[:,indices] = scaler.transform(train)[:, indices]
print('PARTIALLY SCALED\n', train)
If needed, you can reuse the fitted scaler on your validation and test sets (see the documentation).
Output:
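One way to apply the training statistics to the validation and test sets is to fit `StandardScaler` on the selected training columns only, then reuse it for every split. This is a sketch under assumptions: `train`, `val` and `test` are random stand-ins, `indices` is hard-coded, and `apply_scaler` is a hypothetical helper.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(5)
# Random stand-ins for the real splits (9 columns each)
train = rng.uniform(0, 100, size=(10, 9))
val = rng.uniform(0, 100, size=(4, 9))
test = rng.uniform(0, 100, size=(4, 9))
indices = [0, 2, 3, 4]  # 'blind angle', 'ibn', 'idh', 'altitude'

# Fit only on the training columns that should be normalised
scaler = StandardScaler().fit(train[:, indices])

def apply_scaler(data):
    """Return a copy of `data` with the selected columns scaled using train statistics."""
    out = data.copy()
    out[:, indices] = scaler.transform(data[:, indices])
    return out

train_s, val_s, test_s = (apply_scaler(a) for a in (train, val, test))
print(np.allclose(train_s[:, indices].mean(axis=0), 0))  # True
print(np.array_equal(val_s[:, 1], val[:, 1]))  # True: unselected columns untouched
```

Fitting on the subset yields the same per-column `mean_` and `scale_` as fitting on the full array. It also means `transform` only ever sees the 4 columns it was fitted on.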
TRAIN
[[22.199 2.422 41.995 0.486 23.319 38.543 19.061 4.091 84.919]
[87.073 22.153 18.607 4.091 72.225 24.247 24.357 15.093 10.052]
[20.672 13.239 19.928 13.997 78.343 1.456 27.8 29.238 75.92 ]
[91.861 4.749 17.751 15.59 83.047 4.326 27.379 19.543 31.143]
[48.841 26.398 22.929 0.459 0.199 5.573 24.744 63.607 9.074]
[61.174 8.223 10.092 11.553 42.254 12.562 2.826 28.168 34.507]
[76.591 12.427 11.593 0.033 88.332 48.246 10.831 51.11 45.932]
[51.842 8.882 67.475 10.309 35.905 31.588 1.065 39.473 86.499]
[29.68 18.864 67.216 12.796 73.236 40.833 16.391 46.68 33.436]
[18.772 17.395 13.189 19.712 49.181 28.304 23.884 75.144 1.113]]
PARTIALLY SCALED
[[-1.085 2.422 0.618 -1.247 -1.132 38.543 19.061 4.091 84.919]
[ 1.371 22.153 -0.501 -0.713 0.638 24.247 24.357 15.093 10.052]
[-1.143 13.239 -0.437 0.755 0.859 1.456 27.8 29.238 75.92 ]
[ 1.552 4.749 -0.542 0.991 1.029 4.326 27.379 19.543 31.143]
[-0.077 26.398 -0.294 -1.251 -1.969 5.573 24.744 63.607 9.074]
[ 0.39 8.223 -0.908 0.393 -0.447 12.562 2.826 28.168 34.507]
[ 0.974 12.427 -0.836 -1.314 1.22 48.246 10.831 51.11 45.932]
[ 0.037 8.882 1.836 0.208 -0.677 31.588 1.065 39.473 86.499]
[-0.802 18.864 1.824 0.577 0.674 40.833 16.391 46.68 33.436]
[-1.215 17.395 -0.76 1.601 -0.196 28.304 23.884 75.144 1.113]]