python - Astropy:将 FITS 表拆分为训练和测试集
问题描述
我有一个 FITS 表,我正在用 astropy 操作。我想将表随机拆分为训练和测试数据以创建两个新的 FITS 表。
我首先想到使用该scikit-learn
函数test_train_split
,但后来我不得不将我的数据来回转换为numpy.array
.
到目前为止,我已经data
从 FITS 文件中读取了 astropy.table.Table 并尝试了以下操作
training_fraction = 0.5
n = len(data)
indexes = random.sample(range(n), k=int(n*training_fraction))
testing_sample = data[indexes]
training_sample = ?
但是,我不知道如何获取索引不在的所有行indexes
。也许有更好的方法来做到这一点?如何获得我的表的随机分区?
我表中的样本碰巧每个都有一个唯一的 ID,它是一个介于 1 和 len(data) 之间的整数。所以我想,我可以做到
indexes = random.sample(range(1, n+1), k=int(n*training_fraction))
testing_sample = data[data['ID'] in indexes]
training_sample = data[data['ID'] not in indexes]
但是第一行提出了ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
解决方案
您提到使用 scikit-learn 中的现有train_test_split
路由。如果这是您使用 scikit-learn的唯一目的,那将是矫枉过正。但是,如果您已经将它用于任务的其他部分,您也可以这样做。Astropy Tables 一开始就已经由 Numpy 数组支持,所以你不需要“来回转换你的数据”。
由于表的'ID'
列对表中的行进行索引,因此将其正式设置为表的索引会很有用,以便 ID 值可用于索引表中的行(独立于它们的实际位置索引)。例如:
>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
... 'ID': [1, 3, 5, 6, 7, 9],
... 'a': np.random.random(6),
... 'b': np.random.random(6)
... })
>>> t
<Table length=6>
ID a b
int64 float64 float64
----- ------------------- -------------------
1 0.7285295918917892 0.6180944983953155
3 0.9273855839237182 0.28085439237508925
5 0.8677312765220222 0.5996267567496841
6 0.06182255608446752 0.6604620336092745
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
然后设置'ID'
为表的索引:
>>> t.add_index('ID')
用于train_test_split
根据需要对 ID 进行分区:
>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
ID a b
int64 float64 float64
----- ------------------- ------------------
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
5 0.8677312765220222 0.5996267567496841
1 0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
ID a b
int64 float64 float64
----- ------------------- -------------------
6 0.06182255608446752 0.6604620336092745
3 0.9273855839237182 0.28085439237508925
(笔记:
>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
astropy.table.column.BaseColumn,
astropy.table._column_mixins._ColumnGetitemShim,
numpy.ndarray,
object)
)
对于它的价值,因为它可能会帮助您在将来更轻松地找到此类问题的答案,这将有助于更抽象地考虑您正在尝试做的事情(似乎您已经在这样做,但是您的问题的措辞否则建议):表中的列只是 Numpy 数组——一旦它采用这种形式,它们是从 FITS 文件中读取的就无关紧要了。你所做的也与 Astropy 没有直接关系。问题就变成了如何随机划分一个 Numpy 数组。
您可以找到这个问题的通用答案,例如,在这个问题中。但是使用现有的专用实用程序也很好,train_test_split
如果你有它的话。
推荐阅读
- java - 如何仅使用 Java (android studio) 创建具有特定布局的卡片视图
- ios - 如何在应用启动时修复 Unity IOS Facebook 错误
- c# - 为什么 MongoDB 在 Unity 中给我错误?
- debugging - Qt 5.10 我可以像在 Octave 中那样在调试期间(手动)更改变量吗?
- performance - 在 UML 用例图的系统边界内移动参与者
- scroll - AMP:是否可以在用户滚动时突出显示页面菜单?
- c++ - 如何使用 Bazel 构建这个简单的示例?
- arduino - Beacon NRF52832 从组装好的 SHT30 传感器读取温度/湿度值
- php - PHP Table - 如何在从数据库中提取的每一行表上添加不同的超链接
- c# - 事件重入问题。同一事件同时运行