Python - storing large arrays in a variable

Question

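I have a DataFrame that stores (or tries to store) five array fields, each a 1489*2048 image, plus a classification. I have tried JSON and other formats and am currently using pandas DataFrames. When I process the files to extract the image arrays and put them into the DataFrame, I get truncated data. The source data are FITS files: I am putting the HDU 0 data from five bands of astronomical images into one array so that I can do some machine learning on it.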

import numpy as np
import pandas as pd
#from array import *
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.visualization import astropy_mpl_style
plt.style.use(astropy_mpl_style)
import os
os.chdir('d:\\project\\masters')


# table of definitions for classification
# 1 is a star
# 2 is a galaxy
# 3 is a quasar

classification = 1

with open('input.txt') as f:
    i=0
    while True:
        line = f.readline()
        if not line:
            break
        line = line.strip()
        band_u = fits.open(line, memmap=True)

        line = f.readline()
        if not line:
            break
        line = line.strip()
        band_g = fits.open(line, memmap=True)

        line = f.readline()
        if not line:
            break
        line = line.strip()
        band_r = fits.open(line, memmap=True)

        line = f.readline()
        if not line:
            break
        line = line.strip()
        band_i = fits.open(line, memmap=True)

        line = f.readline()
        if not line:
            break
        line = line.strip() 
        band_z = fits.open(line, memmap=True)     

        data1 = band_u[0].data
        data2 = band_g[0].data
        data3 = band_r[0].data
        data4 = band_i[0].data 
        data5 = band_z[0].data
        #
        #
        #
        my_array = np.array([data1, data2, data3, data4, data5, classification])
        df = pd.DataFrame(my_array)
        df.to_csv(r'pandas.txt', header=None, index=None, sep='\t', mode='a')
        # np.save('data.npy' , my_array)
        print(i)
        i += 1 # to track progress
    f.close()

Sample output

"[[-0.0244751   0.01791382 -0.00328064 ... -0.01081848  0.06750488
   0.01052856]
 [-0.01739502  0.01791382 -0.01739502 ... -0.02505493  0.01763916
  -0.00370789]
 [-0.03155518 -0.0244751  -0.01739502 ...  0.07458496 -0.01081848
   0.01052856]

I need to get rid of the truncation. Any ideas?

Tags: python, dataframe, variables, fits

Answer


I think I may see your problem. Each band's data is already a NumPy array, but classification is just an integer. Yet here you do:

my_array = np.array([data1, data2, data3, data4, data5, classification])

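This creates a NumPy "array" of mixed data types: arrays and an integer. The result is an array of type "object" (NumPy arrays are supposed to be homogeneous in the data type of their items, but if you try to make one from a heterogeneous list you just get an "object" array). To illustrate the difference: if you pass a list of same-sized arrays to np.array(), it stacks them into a two-dimensional array: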

>>> np.array([np.arange(10), np.arange(10), np.arange(10)])                                    
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

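However, if you append an int to the end of this list, you get an array of dtype object, because the input has mixed data types with no "obvious" conversion between them. (On NumPy 1.24 and later this call raises a ValueError unless you pass dtype=object explicitly.)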

>>> a2 = np.array([np.arange(10), np.arange(10), np.arange(10), 1])                            
>>> a2                                                                                         
array([array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
       array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
       array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), 1], dtype=object)
>>> a2.dtype                                                                                   
dtype('O')

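Then, when you try to instantiate a pandas DataFrame from it, pandas does not really know what to do with it. Or at least it does nothing like what you want: it just creates a DataFrame with a single column of type "object" ("object" here means generic Python objects, whose types are not necessarily homogeneous):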

>>> df = pd.DataFrame(a2)                                                                      
>>> df                                                                                         
                                0
0  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
2  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
3                               1
>>> df.dtypes                                                                                  
0    object
dtype: object
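
As a hedged illustration of where the "..." in the question's output likely comes from: when an object column is written with to_csv, each cell is converted to its string form, and NumPy abbreviates the string form of any array larger than its print threshold (1000 elements by default). The sketch uses a synthetic 2000-element array rather than real FITS data, and all variable names are made up for the example.

```python
import io

import numpy as np
import pandas as pd

# Synthetic stand-in for one band; 2000 elements exceeds NumPy's
# default print threshold of 1000 elements.
big = np.arange(2000, dtype=float)

# Build the object array explicitly so this works on any NumPy version.
cells = np.empty(2, dtype=object)
cells[0] = big
cells[1] = 1  # the classification

df = pd.DataFrame(cells)

# Write to an in-memory buffer instead of a file on disk.
buf = io.StringIO()
df.to_csv(buf, header=None, index=None, sep='\t')
csv_text = buf.getvalue()

# The array cell was written as its abbreviated string form.
print('...' in csv_text)

# Raising the threshold for a single conversion yields the full text.
full_text = np.array2string(big, threshold=big.size + 1)
print('...' in full_text)
```

Printing the full text this way only treats the symptom. Storing each band as a numeric column, or using a binary format, avoids having to parse strings back into arrays.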

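I think what you want is to build a DataFrame with one column per band plus a classification column. For that there is no reason to pass your existing arrays to np.array() as in the line I highlighted. You can pass the columns straight to pd.DataFrame(). You should also fill the initial classification into an array of the same length as the other arrays. In the snippets that follow, band_r and band_g stand for the per-band pixel arrays flattened to one dimension (for example data3.ravel()), since each DataFrame column must be one-dimensional. For example: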

df = pd.DataFrame(np.column_stack((band_r, band_g, np.full(band_r.shape, classification))))

You can also create a DataFrame with named columns, for example:

df = pd.DataFrame({'band_r': band_r, 'band_g': band_g,
                   'classification': np.full(band_r.shape, classification)})
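
A minimal sketch of the named-column approach on synthetic data, assuming each 2-D image is flattened so that every pixel becomes one row. The tiny 3x4 arrays and the names flat_r and flat_g are illustrative stand-ins, not the asker's real data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Tiny synthetic images standing in for two 1489x2048 bands.
band_r = rng.normal(size=(3, 4))
band_g = rng.normal(size=(3, 4))
classification = 1

# A DataFrame column must be 1-D, so flatten each image first.
flat_r = band_r.ravel()
flat_g = band_g.ravel()

df = pd.DataFrame({
    'band_r': flat_r,
    'band_g': flat_g,
    # one classification value per pixel row
    'classification': np.full(flat_r.shape, classification),
})

print(df.shape)
print(df.dtypes)
```

With this layout every column has a numeric dtype, so writing the frame to a text file produces plain numbers rather than abbreviated array strings.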

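For what it's worth, here is one way I would write the code. I generalized it to work with any number of bands: instead of reading them as a plain list of filenames, you can read a JSON file that maps band names to the filenames to read from. This may be overkill for your application, of course. This is untested: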

from contextlib import ExitStack

import numpy as np
import pandas as pd
from astropy.io import fits

def bands_to_dataframe(bands, initial_classification=1):
    """bands is a dict mapping band names to filenames"""

    columns = {}

    with ExitStack() as stack:
        for band_name, filename in bands.items():
            hdulist = stack.enter_context(fits.open(filename))
            # flatten the 2-D image: each DataFrame column must be 1-D
            columns[band_name] = hdulist[0].data.ravel()

        # use the shape of the first band data array to determine the
        # correct shape of the classification column; this assumes all
        # the arrays are the same size
        first_band = next(iter(columns.values()))
        columns['classification'] = np.full(first_band.shape, initial_classification)

        return pd.DataFrame(columns)
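
The ExitStack pattern used above can be seen in isolation in the following sketch. It assumes nothing about FITS: plain text files in a temporary directory stand in for the band files, and the names paths, handles and contents are hypothetical.

```python
import os
import tempfile
from contextlib import ExitStack

with tempfile.TemporaryDirectory() as tmpdir:
    # Hypothetical stand-ins for the band files.
    paths = {}
    for name in ('u', 'g', 'r'):
        path = os.path.join(tmpdir, f'band_{name}.txt')
        with open(path, 'w') as fobj:
            fobj.write(f'data for band {name}\n')
        paths[name] = path

    handles = {}
    contents = {}
    with ExitStack() as stack:
        for name, path in paths.items():
            # enter_context registers the file so the stack closes it later
            handles[name] = stack.enter_context(open(path))

        # every registered file is open at this point
        all_open = all(not h.closed for h in handles.values())
        for name, handle in handles.items():
            contents[name] = handle.read().strip()

    # leaving the with block closed every registered file
    all_closed = all(h.closed for h in handles.values())

print(all_open, all_closed)
```

The benefit is that the number of files does not need to be known when the code is written, yet all of them are closed together, even if an exception is raised partway through.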

Tying it together, given a JSON file such as:

{
    "band_r": "path/to/band_r.fits",
    "band_g": "path/to/band_g.fits",
    "band_i": "path/to/band_i.fits",
    ... and so on ...
}

you can do:

import json
import os


def band_files_to_csv(bands_filename, output_filename=None):
    """Read band filenames from a JSON file structured as above and
    write their data to a TSV file.
    """

    if output_filename is None:
        base_filename, _ = os.path.splitext(bands_filename)
        output_filename = base_filename + '.tsv'

    with open(bands_filename) as fobj:
        bands = json.load(fobj)

    df = bands_to_dataframe(bands)
    df.to_csv(output_filename, header=None, index=None, sep='\t', mode='a')

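Something like that. I am not sure why you want to save this data to CSV/TSV; depending on your application it may not be the most efficient format (for example, you might want to consider a binary format). But use whatever works for your application.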
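
As one possible reading of "a binary format", here is a hedged sketch that stacks the bands into a single array and stores it with NumPy's own .npz format. The file name sample.npz, the key names, and the tiny synthetic images are assumptions made for the example; other binary formats would work as well.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for the five band images (tiny, for illustration).
bands = {name: rng.normal(size=(4, 6)).astype(np.float32) for name in 'ugriz'}
classification = 1

# Stack into one array of shape (5, height, width).
cube = np.stack([bands[name] for name in 'ugriz'])

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'sample.npz')

    # Store the image cube and its label together, losslessly.
    np.savez_compressed(path, cube=cube, classification=classification)

    # Read everything back while the archive is open.
    with np.load(path) as archive:
        restored_cube = archive['cube']
        restored_class = int(archive['classification'])

print(restored_cube.shape, restored_class)
```

Unlike a text file, this round trip keeps every pixel value and the array shape exactly, so nothing has to be parsed or reshaped before training.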

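Update: if you always use the same sequence of bands (u, g, r, i, z) and just want to read a list of files without using JSON, you could also do the following (mostly the same):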

DEFAULT_BANDS = ('u', 'g', 'r', 'i', 'z')


def bands_to_dataframe(bands, band_names=DEFAULT_BANDS,
                       initial_classification=1):
    """bands is a list of filenames"""

    if len(bands) != len(band_names):
        raise ValueError(
           f'number of filenames ({len(bands)}) does not match the '
           f'number of band names: {band_names}')

    columns = {}

    with ExitStack() as stack:
        for band_name, filename in zip(band_names, bands):
            hdulist = stack.enter_context(fits.open(filename))
            # flatten the 2-D image: each DataFrame column must be 1-D
            columns['band_' + band_name] = hdulist[0].data.ravel()

        # ... the rest is the same as previous version ...

Then:

import os
import pathlib
from contextlib import nullcontext


def bands_file_to_csv(filename_or_obj, band_names=DEFAULT_BANDS,
                      output_filename_or_obj=None):
    """Here filename_or_obj can be a filename, or an already open
    file-like object."""

    if isinstance(filename_or_obj, (str, pathlib.Path)):
        input_file = open(filename_or_obj)
    else:
        input_file = nullcontext(filename_or_obj)

    if output_filename_or_obj is None:
        if isinstance(filename_or_obj, (str, pathlib.Path)):
            base_filename, _ = os.path.splitext(filename_or_obj)
            output_filename = base_filename + '.tsv'
            output_ctx = open(output_filename, 'a')
        else:
            raise ValueError(
                'output_filename_or_obj is required if input filename '
                'is not a string or path')
    else:
        if isinstance(output_filename_or_obj, (str, pathlib.Path)):
            output_ctx = open(output_filename_or_obj, 'a')
        else:
            output_ctx = nullcontext(output_filename_or_obj)

    # read all non-empty lines from the file; input_file is a context
    # manager (an open file or a nullcontext wrapper), so enter it first
    # this uses the Python 3.8 walrus operator to avoid
    # calling l.strip() twice, but you could do this other
    # ways for older Python versions
    with input_file as fobj:
        bands = [ll for l in fobj if (ll := l.strip())]

    df = bands_to_dataframe(bands, band_names=band_names)

    with output_ctx as output_file:
        df.to_csv(output_file, header=None, index=None, sep='\t')
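
The "filename or open file object" idea can be tried out without any FITS data. The sketch below is a hypothetical helper, read_nonempty_lines, that uses nullcontext in the same way and also shows one way to skip blank lines without the walrus operator.

```python
import io
import pathlib
from contextlib import nullcontext


def read_nonempty_lines(filename_or_obj):
    """Hypothetical helper: accept a path or an already open file."""
    if isinstance(filename_or_obj, (str, pathlib.Path)):
        ctx = open(filename_or_obj)         # closed when the with block exits
    else:
        ctx = nullcontext(filename_or_obj)  # left open for the caller

    with ctx as fobj:
        # Same result as the walrus version, for Python older than 3.8.
        stripped = (line.strip() for line in fobj)
        return [line for line in stripped if line]


# An in-memory file stands in for the list of band filenames.
source = io.StringIO('u.fits\n\n  g.fits  \nr.fits\n')
names = read_nonempty_lines(source)
still_open = not source.closed

print(names, still_open)
```

Because nullcontext does nothing on exit, a file object supplied by the caller stays open, while a file the helper opened itself is closed as soon as the lines have been read.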

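Implementing this as a function without hard-coding the input filename is very useful. It is more reusable and helps you write better scripts (for example, taking the input filename as an argument). You can also write a script that reads the list of filenames from standard input, for example: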

# myscript.py
import sys
from contextlib import nullcontext

# output to stdout by default
output_file = sys.stdout

if len(sys.argv) > 1:
    bands_file = sys.argv[1]
    ctx = open(bands_file)
    # If a second argument is provided it can be the output filename,
    # otherwise by default it outputs to stdout
    if len(sys.argv) == 3:
        # open for appending, as bands_file_to_csv does for filenames
        output_file = open(sys.argv[2], 'a')
else:
    # If no filename is given as an argument, read from stdin
    # and output to stdout
    ctx = nullcontext(sys.stdin)

with ctx as input_file:
    bands_file_to_csv(input_file, output_filename_or_obj=output_file)

You can call this script in a few different ways. Given an input filename:

$ ./myscript.py inputs.txt > bands.tsv

it writes its output to bands.tsv (without > bands.tsv it would just write to the screen, which you probably do not want for large files).

Or you can pass an output filename, for example:

$ ./myscript.py inputs.txt bands.tsv

Finally, you can pipe the input to standard input and write the output to a file:

$ cat inputs.txt | ./myscript.py > bands.tsv

This way you can build pipelines of scripts.
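
If the script grows, the manual sys.argv handling could be swapped for argparse from the standard library. This is only a sketch of that alternative, not part of the answer's script; the function name parse_args and the help texts are made up for the example.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical CLI: optional input and output paths."""
    parser = argparse.ArgumentParser(
        description='Write band data listed in a file to TSV.')
    # Both positionals are optional; None means "use stdin/stdout".
    parser.add_argument('input', nargs='?', default=None,
                        help='file listing band filenames (default: stdin)')
    parser.add_argument('output', nargs='?', default=None,
                        help='output TSV filename (default: stdout)')
    return parser.parse_args(argv)


# Passing explicit lists avoids touching the real command line.
no_args = parse_args([])
both_args = parse_args(['inputs.txt', 'bands.tsv'])

print(no_args.input, no_args.output)
print(both_args.input, both_args.output)
```

The three invocation styles shown above map onto this directly: no arguments, an input file only, or an input file plus an output file.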

