首页 > 技术文章 > GDAL常用代码

skypanxh 2021-09-01 14:39 原文

GDAL常用代码

1.导入数据

from osgeo import gdal
import numpy as np

def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's NoData value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        A (rows, cols) array for a single-band raster, a
        (bands, rows, cols) array otherwise, or None (after printing a
        message) when the file cannot be opened.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        print(filename + " can't be opened!")
        return None
    nb = dataset.RasterCount

    bands = []
    for i in range(1, nb + 1):
        band = dataset.GetRasterBand(i)
        background = band.GetNoDataValue()
        data = band.ReadAsArray().astype(np.float32)
        if background is not None:
            if np.isnan(background):
                # NaN never compares equal to itself, so `==` would miss it.
                data[np.isnan(data)] = 0
            else:
                data[data == background] = 0
        bands.append(data)
    data = np.stack(bands, 0)
    if nb == 1:
        # Single-band rasters are returned as a plain 2-D array.
        data = data[0, :, :]
    return data

或者

import xarray as xr
# Alternative loader: open the raster with xarray and keep index 0 of the
# first axis (presumably the band axis) as a NumPy array.
# NOTE(review): xarray.open_rasterio was deprecated and later removed from
# xarray — confirm the installed xarray version still provides it.
arr = xr.open_rasterio("路径").data[0,:,:]

2.写出数据

def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a georeferenced GeoTIFF.

    Args:
        im_data: Array shaped (rows, cols) or (bands, rows, cols).
        im_width: Unused; kept for backward compatibility (the width is
            taken from ``im_data.shape``).
        im_height: Unused; kept for backward compatibility.
        im_bands: Unused; kept for backward compatibility.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # The array is the source of truth for the raster size: a 2-D array is
    # always one band, whatever band count the caller read from a source file.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands, (im_height, im_width) = 1, im_data.shape
        im_data = im_data[None, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim)

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16 (negatives would wrap).
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    dataset.SetProjection(im_proj)  # write the projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset

# Usage example: copy the size and georeferencing of an existing raster and
# pass them to WriteTiff together with the array to be written.
# NOTE(review): `path` and `arr` are not defined in this snippet — presumably
# the reference raster path and the array loaded earlier; confirm before running.
raster = gdal.Open(path)
im_width = raster.RasterXSize # number of columns of the raster grid
im_height = raster.RasterYSize # number of rows of the raster grid
im_bands = raster.RasterCount # number of bands
im_geotrans = raster.GetGeoTransform()# affine geotransform
im_proj = raster.GetProjection()# projection
ResultPath = "路径"
WriteTiff(arr, im_width, im_height, im_bands, im_geotrans, im_proj, ResultPath) 

或者

def WriteTiff(im_data,inputdir, path):
    """Write a NumPy array to a GeoTIFF, copying georeferencing from a raster.

    Args:
        im_data: Array shaped (rows, cols) or (bands, rows, cols).
        inputdir: Path of the reference raster whose geotransform and
            projection are copied to the output.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the reference raster cannot be opened or the
            output cannot be created.
    """
    # The array is the source of truth for the raster size: a 2-D array is
    # always one band, whatever the band count of the reference raster.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands, (im_height, im_width) = 1, im_data.shape
        im_data = im_data[None, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim)

    raster = gdal.Open(inputdir)
    if raster is None:
        raise RuntimeError("GDAL could not open " + str(inputdir))
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16 (negatives would wrap).
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    dataset.SetProjection(im_proj)  # write the projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset
# Usage example. NOTE(review): im_data, inputdir and path are not defined in
# this snippet — presumably the array, the reference raster and the output path.
WriteTiff(im_data,inputdir, path)

或者

def WriteTiff(im_data,inputdir, path):
    """Write a NumPy array to a GeoTIFF, copying georeferencing from a raster.

    When the reference raster has no projection, WGS84 is written instead.

    Args:
        im_data: Array shaped (rows, cols) or (bands, rows, cols).
        inputdir: Path of the reference raster whose geotransform and
            projection are copied to the output.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the reference raster cannot be opened or the
            output cannot be created.
    """
    # The array is the source of truth for the raster size: a 2-D array is
    # always one band, whatever the band count of the reference raster.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands, (im_height, im_width) = 1, im_data.shape
        im_data = im_data[None, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim)

    raster = gdal.Open(inputdir)
    if raster is None:
        raise RuntimeError("GDAL could not open " + str(inputdir))
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16 (negatives would wrap).
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    if im_proj == "":
        # No coordinate system on the reference raster: fall back to WGS84.
        # Imported here because the snippet only imports gdal from osgeo.
        from osgeo import osr
        sr = osr.SpatialReference()
        sr.SetWellKnownGeogCS('WGS84')
        dataset.SetProjection(sr.ExportToWkt())
    else:
        dataset.SetProjection(im_proj)  # write the projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset

3.影像拼接

把要合并的多个tif分别放在path路径下的文件夹file1、file2、…、filen中,每个文件夹下的文件数量和文件名都相同。

from osgeo import gdal
import math

def GetExtent(in_fn):
    """Return the extent of a raster as (min_x, max_y, max_x, min_y).

    The extent is derived from the geotransform origin and pixel sizes
    multiplied by the raster dimensions; the rotation terms are not used.
    """
    dataset = gdal.Open(in_fn)
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    right = origin_x + dataset.RasterXSize * pixel_w
    bottom = origin_y + dataset.RasterYSize * pixel_h
    # Release the dataset handle before returning.
    dataset = None
    return origin_x, origin_y, right, bottom

def mosaic(in_files,output_name,arr_files):
    """Mosaic band 1 of every raster in a folder into one GeoTIFF.

    The output grid uses the pixel size, projection and band data type of the
    first raster and covers the union of all input extents. Rasters are
    written in sorted file-name order, so later files overwrite earlier ones
    where they overlap.

    Args:
        in_files: Folder containing the rasters to merge.
        output_name: Output file name. A relative name is created inside
            ``in_files`` (the location the previous chdir-based version
            used); an absolute path is used as given.
        arr_files: Unused; kept for backward compatibility.

    Raises:
        ValueError: If the folder contains no input files.
    """
    # Imported here because this snippet's import list does not include os.
    import os

    # Build full paths instead of os.chdir(), which changed the working
    # directory of the whole process as a side effect.
    folder = in_files
    out_path = os.path.join(folder, output_name)
    out_abs = os.path.abspath(out_path)
    paths = [os.path.join(folder, name) for name in sorted(os.listdir(folder))]
    # Never treat an output file left over from a previous run as an input.
    paths = [p for p in paths if os.path.abspath(p) != out_abs]
    if not paths:
        raise ValueError("no input rasters found in " + str(folder))

    # Union of the extents of all rasters to be mosaicked.
    min_x, max_y, max_x, min_y = GetExtent(paths[0])
    for in_fn in paths[1:]:
        minx, maxy, maxx, miny = GetExtent(in_fn)
        min_x = min(min_x, minx)
        min_y = min(min_y, miny)
        max_x = max(max_x, maxx)
        max_y = max(max_y, maxy)

    # Size of the mosaic in pixels, using the first raster's pixel size.
    in_ds = gdal.Open(paths[0])
    geotrans = list(in_ds.GetGeoTransform())
    width = geotrans[1]
    height = geotrans[5]
    columns = math.ceil((max_x - min_x) / width)
    rows = math.ceil((max_y - min_y) / (-height))
    in_band = in_ds.GetRasterBand(1)

    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(out_path, columns, rows, 1, in_band.DataType)
    out_ds.SetProjection(in_ds.GetProjection())
    geotrans[0] = min_x
    geotrans[3] = max_y
    out_ds.SetGeoTransform(geotrans)
    out_band = out_ds.GetRasterBand(1)

    # Write the rasters one by one.
    for in_fn in paths:
        in_ds = gdal.Open(in_fn)
        # in_ds is the source raster, out_ds the destination raster.
        trans = gdal.Transformer(in_ds, out_ds, [])
        # Pixel (0, 0) of in_ds expressed in out_ds pixel coordinates.
        success, xyz = trans.TransformPoint(False, 0, 0)
        # Round rather than truncate: floating-point error can yield e.g.
        # 99.9999 for an offset that is really 100.
        x = int(round(xyz[0]))
        y = int(round(xyz[1]))
        data = in_ds.GetRasterBand(1).ReadAsArray()
        # x, y is the upper-left pixel offset where writing starts.
        out_band.WriteArray(data, x, y)
    del in_ds, out_band, out_ds
    
# Replace both placeholders with real values before running. They are quoted
# so the snippet parses as strings instead of raising NameError on bare
# identifiers.
in_files = "要合并的tif存放的路径"  # folder that holds the TIFFs to merge
output_name = "输出的tif名称"  # file name of the mosaicked output TIFF
# Then run: mosaic(in_files, output_name, None)

4.开闭运算去除小斑块

# -*- coding: utf-8 -*-
"""
Created on Tue Oct  5 12:53:34 2021

@author: Xhpan
"""
from osgeo import gdal 
import xarray as xr

from skimage import morphology as sm
import numpy as np
import os

def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a georeferenced GeoTIFF.

    Args:
        im_data: Array shaped (rows, cols) or (bands, rows, cols).
        im_width: Unused; kept for backward compatibility (the width is
            taken from ``im_data.shape``).
        im_height: Unused; kept for backward compatibility.
        im_bands: Unused; kept for backward compatibility.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # The array is the source of truth for the raster size: a 2-D array is
    # always one band, whatever band count the caller read from a source file.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands, (im_height, im_width) = 1, im_data.shape
        im_data = im_data[None, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim)

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16 (negatives would wrap).
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    dataset.SetProjection(im_proj)  # write the projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset
    
def getBoundary(filename,urbanID,kernel,ResultPath):
    """Clean a class mask with a morphological closing followed by an opening.

    Index 0 of the raster's first axis is turned into a 0/1 mask
    (1 where the value equals ``urbanID``), closed and then opened with
    ``kernel``, and written to ``ResultPath`` with the georeferencing of the
    input file.

    Args:
        filename: Input raster path.
        urbanID: Pixel value that marks the class to keep.
        kernel: Structuring element passed to skimage's closing/opening.
        ResultPath: Output GeoTIFF path.
    """
    # NOTE(review): xarray.open_rasterio was deprecated and later removed
    # from xarray — confirm the installed version still provides it.
    raster = xr.open_rasterio(filename).data[0,:,:]
    # Compute the mask once, before the array is overwritten in place.
    is_urban = raster == urbanID
    raster[~is_urban] = False
    raster[is_urban] = True
    img_close = sm.closing(raster, kernel)
    img_open = sm.opening(img_close, kernel)

    source = gdal.Open(filename)
    im_width = source.RasterXSize  # number of columns
    im_height = source.RasterYSize  # number of rows
    im_geotrans = source.GetGeoTransform()  # affine geotransform
    im_proj = source.GetProjection()  # projection
    # The result is a single 2-D band, so the band count is 1 even when the
    # source raster has more bands.
    WriteTiff(img_open, im_width, im_height, 1, im_geotrans, im_proj, ResultPath)
    
# 获取某目录下所有tif文件
def getTiffFileName(filepath, suffix):
    """List the files below ``filepath`` whose extension equals ``suffix``.

    The directory tree is walked recursively in sorted order, so the result
    is deterministic.

    Args:
        filepath: Root directory to search.
        suffix: Extension to match, including the dot (e.g. ".tif").

    Returns:
        A tuple ``(paths, names)``: the path of every matching file and the
        corresponding file names without extension.
    """
    L1 = []
    L2 = []
    for root, dirs, files in os.walk(filepath):
        dirs.sort()  # in-place sort fixes the order os.walk descends in
        for file in sorted(files):
            # Split the file name into stem and extension.
            (filename, extension) = os.path.splitext(file)
            if extension == suffix:
                # Join with the directory the file is in (root), not the
                # search root, so files in sub-directories get valid paths.
                L1.append(root + "/" + file)
                L2.append(filename)
    return L1, L2

# Batch-run getBoundary over every .tif in the input folder.
urbanID = 1
filepath =  r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017"
kernel = sm.octagon(2, 1)
inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")

# GDAL does not create missing parent folders, so make sure the output
# folder exists before writing into it.
outputpath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1"
os.makedirs(outputpath, exist_ok=True)

for name in inputNames:
    filename = filepath + "/" + name + ".tif"
    ResultPath = outputpath + "/" + name + ".tif"
    getBoundary(filename,urbanID,kernel,ResultPath)
    print(filename)

5.时间序列订正

# -*- coding: utf-8 -*-
"""
Created on Tue Sep 28 13:17:24 2021

@author: Xhpan
"""

import numpy as np
from osgeo import gdal

import xarray as xr
import os

def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a georeferenced GeoTIFF.

    Args:
        im_data: Array shaped (rows, cols) or (bands, rows, cols).
        im_width: Unused; kept for backward compatibility (the width is
            taken from ``im_data.shape``).
        im_height: Unused; kept for backward compatibility.
        im_bands: Unused; kept for backward compatibility.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    # The array is the source of truth for the raster size: a 2-D array is
    # always one band, whatever band count the caller read from a source file.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands, (im_height, im_width) = 1, im_data.shape
        im_data = im_data[None, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim)

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16 (negatives would wrap).
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
    dataset.SetProjection(im_proj)  # write the projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset
    
# 获取某目录下所有tif文件
def getTiffFileName(filepath, suffix):
    """List the files below ``filepath`` whose extension equals ``suffix``.

    The directory tree is walked recursively in sorted order, so the result
    is deterministic.

    Args:
        filepath: Root directory to search.
        suffix: Extension to match, including the dot (e.g. ".tif").

    Returns:
        A tuple ``(paths, names)``: the path of every matching file and the
        corresponding file names without extension.
    """
    L1 = []
    L2 = []
    for root, dirs, files in os.walk(filepath):
        dirs.sort()  # in-place sort fixes the order os.walk descends in
        for file in sorted(files):
            # Split the file name into stem and extension.
            (filename, extension) = os.path.splitext(file)
            if extension == suffix:
                # Join with the directory the file is in (root), not the
                # search root, so files in sub-directories get valid paths.
                L1.append(root + "/" + file)
                L2.append(filename)
    return L1, L2

def timeSeriesCorrection(filepath,outputpath):
    """Accumulate a time series of rasters so that a cell set once stays set.

    The rasters in ``filepath`` are processed in sorted file-name order.
    Each raster is added to the running result and cells equal to 2 are
    reset to 1; the running result is written to ``outputpath`` under the
    name of the raster just added.

    NOTE(review): this presumably expects 0/1 masks whose file names sort
    chronologically — confirm. The first raster of the series is only used
    as the starting state and is not written to ``outputpath``.

    Args:
        filepath: Folder with the input .tif files.
        outputpath: Folder for the corrected rasters (created if missing).

    Raises:
        ValueError: If ``filepath`` contains no .tif files.
    """
    os.makedirs(outputpath, exist_ok=True)
    inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")
    # The accumulation is order-dependent, so do not rely on the arbitrary
    # order in which the file system lists the files.
    inputNames = sorted(inputNames)
    if not inputNames:
        raise ValueError("no .tif files found in " + str(filepath))

    raster = gdal.Open(filepath + "/" + str(inputNames[0]) + ".tif")
    im_width = raster.RasterXSize  # number of columns
    im_height = raster.RasterYSize  # number of rows
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection

    merged = xr.open_rasterio(filepath + "/" + str(inputNames[0]) + ".tif").data[0,:,:]
    for name in inputNames[1:]:
        current = xr.open_rasterio(filepath + "/" + str(name) + ".tif").data[0,:,:]
        merged = merged + current
        merged[np.where(merged == 2)] = 1

        ResultPath = outputpath + "/" + str(name) + ".tif"
        # The result is a single 2-D band, so the band count is 1 even when
        # the source raster has more bands.
        WriteTiff(merged, im_width, im_height, 1, im_geotrans, im_proj, ResultPath)
        print(ResultPath)

# Run the time-series correction on the rasters in `filepath` and write the
# results to `outputpath`.
filepath =  r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1"
outputpath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_2"
timeSeriesCorrection(filepath,outputpath)

6.统一图像的行列号

# -*- coding: utf-8 -*-
"""
Created on Wed Oct  6 14:26:52 2021

@author: Xhpan
"""

from osgeo import gdal
import math

def GetExtent(in_fn):
    """Compute a raster's extent from its geotransform and pixel dimensions.

    Returns:
        The tuple (min_x, max_y, max_x, min_y). Only the origin and pixel
        size entries of the geotransform are used.
    """
    raster = gdal.Open(in_fn)
    gt = tuple(raster.GetGeoTransform())
    n_cols, n_rows = raster.RasterXSize, raster.RasterYSize
    left, top = gt[0], gt[3]
    right = left + n_cols * gt[1]
    bottom = top + n_rows * gt[5]
    # Release the dataset handle before returning.
    raster = None
    return left, top, right, bottom

def UnifiedLineNumber(in_fn,criterion_fn,output_name):
    """Rewrite band 1 of ``in_fn`` onto a grid sized like ``criterion_fn``.

    The output GeoTIFF takes its row/column count, pixel size, projection
    and band data type from the criterion raster, while its geotransform
    origin is replaced by the upper-left corner of ``in_fn``.

    NOTE(review): WriteArray will presumably fail if ``in_fn`` has more rows
    or columns than the criterion raster — confirm against real data.

    Args:
        in_fn: Raster to be processed.
        criterion_fn: Raster that defines the target row/column count.
        output_name: Path of the output GeoTIFF.
    """
    in_ds=gdal.Open(criterion_fn)
    geotrans=list(in_ds.GetGeoTransform())
    width=geotrans[1]
    height=geotrans[5]
    
    # Compute the row/column count of the output image from the criterion extent
    min_x,max_y,max_x,min_y = GetExtent(criterion_fn)
    columns=math.ceil((max_x-min_x)/width)
    rows=math.ceil((max_y-min_y)/(-height))
    in_band=in_ds.GetRasterBand(1)
    
    driver=gdal.GetDriverByName('GTiff')
    out_ds=driver.Create(output_name,columns,rows,1,in_band.DataType)
    out_ds.SetProjection(in_ds.GetProjection())
    
    # Anchor the new image at the upper-left corner of the source image
    min_x1,max_y1,max_x1,min_y1 = GetExtent(in_fn)
    geotrans[0]=min_x1
    geotrans[3]=max_y1
    out_ds.SetGeoTransform(geotrans)
    out_band=out_ds.GetRasterBand(1)
    # Build the inverse affine transform
    inv_geotrans=gdal.InvGeoTransform(geotrans)
    
    # From here on in_ds refers to the raster being processed, not the criterion
    in_ds=gdal.Open(in_fn)
    in_gt=in_ds.GetGeoTransform()
    # Apply the inverse affine transform to the source origin
    offset=gdal.ApplyGeoTransform(inv_geotrans,in_gt[0],in_gt[3])
    # This x, y pair is overwritten by the Transformer result a few lines below
    x,y=map(int,offset)
    # print(x,y)
    trans=gdal.Transformer(in_ds,out_ds,[])# in_ds is the source raster, out_ds the destination raster
    success,xyz=trans.TransformPoint(False,0,0)# upper-left pixel of in_ds expressed in out_ds pixel coordinates
    x,y,z=map(int,xyz)
    # print(x,y,z)
    
    data=in_ds.GetRasterBand(1).ReadAsArray()
    out_band.WriteArray(data,x,y)# x, y is the upper-left pixel offset where writing starts
    del in_ds,out_band,out_ds
    
# Usage example: replace the three placeholder strings with real paths.
in_fn = "需要处理的tif路径"  # raster to be processed
criterion_fn = "标准的tif路径"  # raster that defines the target grid size
output_name = "输出tif路径"  # output GeoTIFF
UnifiedLineNumber(in_fn,criterion_fn,output_name)

7.矢量按位置(相交和相邻)/属性选择

import os
from osgeo import ogr
from tqdm import trange

def create_shp_by_layer(shp, layer):
    """Copy ``layer`` into a new data source at ``shp``.

    An existing file at ``shp`` is deleted first. The OGR driver is taken
    from the module-level ``driver`` variable.

    Args:
        shp: Output file name.
        layer: Layer to copy (the caller passes the filtered input layer).
    """
    already_there = os.access(shp, os.F_OK)
    if already_there:
        driver.DeleteDataSource(shp)
    out_ds = driver.CreateDataSource(shp)
    out_ds.CopyLayer(layer, '')
    # print(shp)
    # print(shp)
    
def totxt(resultname,L):
    """Write every item of ``L`` to ``resultname``, one per line.

    Args:
        resultname: Path of the text file to create or overwrite.
        L: Iterable of values; each is converted with ``str``.
    """
    # The context manager closes the file even if a write fails.
    with open(resultname, "w") as f:
        f.writelines(str(name) + '\n' for name in L)


filename = 'Export_Output.shp'

# Folder for the per-feature neighbour lists (.txt).
resultpath = "result"
os.makedirs(resultpath, exist_ok=True)

# Folder for the per-feature selection shapefiles.
resultpath1 = "resultshp"
os.makedirs(resultpath1, exist_ok=True)

# For every feature of the shapefile, select the features of the same file
# that pass a spatial filter built from its geometry, save them as a
# shapefile, and write the 'cyid' values of the other selected features to a
# text file.
driver = ogr.GetDriverByName("ESRI Shapefile")
# Open the data to select from (the same file as the source)
target_shp =  filename
target_ds = ogr.Open(target_shp)
target_layer = target_ds.GetLayer(0) # first layer

source_shp = filename 
source_ds = ogr.Open(source_shp)
source_layer = source_ds.GetLayer(0)  # first layer
# Iterate over every feature
for i in trange(source_layer.GetFeatureCount()):
    source_feats = source_layer.GetFeature(i)
    source_id = source_feats.GetField('cyid') # 'cyid' field value of this polygon
    poly = source_feats.GetGeometryRef() # geometry of this polygon
       
    target_layer.SetSpatialFilter(poly) # select this polygon and everything touching it
    shp = resultpath1 + "/" + str(source_id) + ".shp"
    create_shp_by_layer(shp, target_layer)
    
    
    # Read the 'cyid' values back from the shapefile that was just written
    filter_names = []
    result_ds = ogr.Open(shp)
    result_layer = result_ds.GetLayer(0) # first layer
    for j,fea in enumerate(result_layer):
        result_feats = result_layer.GetFeature(j)
        result_id = result_feats.GetField('cyid')
        # Skip the source polygon itself; keep only its neighbours
        if result_id != source_id:
            filter_names.append(result_id)
            
            
    resultname = resultpath + "/"+ str(source_id)  +".txt"
    totxt(resultname,filter_names)
            
    # Clear the filters and rewind the layers before the next feature
    source_layer.SetSpatialFilter(None)
    source_layer.ResetReading()
    target_layer.SetSpatialFilter(None)
    target_layer.ResetReading()
    result_layer.SetSpatialFilter(None)
    result_layer.ResetReading()

8.根据矢量裁剪栅格

from osgeo import gdal
def clip_raster(in_raster, out_raster, mask_shp, dst_srs='EPSG:4326',
                dst_nodata=-9999, output_type=None):
    """Clip a raster to the outline of a vector layer with ``gdal.Warp``.

    :param in_raster: input raster path
    :param out_raster: output GeoTIFF path
    :param mask_shp: vector data source used as the cutline
    :param dst_srs: spatial reference of the output (default 'EPSG:4326')
    :param dst_nodata: NoData value written to the output (default -9999)
    :param output_type: GDAL data type of the output; ``None`` means
        ``gdal.GDT_Float64``
    :return: None
    """
    if output_type is None:
        output_type = gdal.GDT_Float64
    result = gdal.Warp(out_raster,
                       in_raster,
                       format='GTiff',
                       dstSRS=dst_srs,
                       cutlineDSName=mask_shp,
                       cropToCutline=True,  # crop to the extent of the cutline
                       dstNodata=dst_nodata,
                       outputType=output_type)
    # Drop the reference explicitly so the output is closed and flushed.
    result = None
    
# Usage example: replace the placeholder strings with real paths.
in_raster = "栅格路径"  # input raster
out_raster = r"test.tif"
mask_shp =  "矢量路径"  # clipping vector
clip_raster(in_raster, out_raster, mask_shp)

9. ArcGIS将shp按照属性字段进行分割为多个polygon矢量

from osgeo import ogr
import os
# Split a shapefile into one shapefile per feature, named after the 'ID' field.
shpfile = r"输入shp"  
resultpath = r"输出文件"
if not os.path.exists(resultpath):
	os.makedirs(resultpath)

driver = ogr.GetDriverByName("ESRI Shapefile")
ds = ogr.Open(shpfile)
layer = ds.GetLayer(0)
for i in range(layer.GetFeatureCount()):
    source_feats = layer.GetFeature(i)
    source_id = source_feats.GetField('ID')  # the result file is named after the ID field
    # NOTE(review): the filter is built without quotes, so this presumably
    # only works for numeric ID values — confirm the field type.
    layer.SetAttributeFilter("ID = {}".format(source_id))
    
    # NOTE(review): if two features share an ID the same output file is
    # created twice — confirm that IDs are unique.
    extfile = resultpath + "/" + str(source_id).zfill(2) + ".shp"
    newds = driver.CreateDataSource(extfile)
    # The new layer is created without a spatial reference or field definitions
    lyrn = newds.CreateLayer('rect', None, ogr.wkbPolygon)
    
    # Copy every feature that passes the attribute filter
    feat = layer.GetNextFeature()
    while feat is not None:
        lyrn.CreateFeature(feat)
        feat = layer.GetNextFeature()
    newds.Destroy()
    print(i)

10.分区统计(多tif批量)

import time
import geopandas as gpd
import rasterio
from rasterstats import zonal_stats
import pandas as pd

from osgeo import gdal
import numpy as np

def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's NoData value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        A (rows, cols) array for a single-band raster, a
        (bands, rows, cols) array otherwise, or None (after printing a
        message) when the file cannot be opened.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        print(filename + " can't be opened!")
        return None
    nb = dataset.RasterCount

    bands = []
    for i in range(1, nb + 1):
        band = dataset.GetRasterBand(i)
        background = band.GetNoDataValue()
        data = band.ReadAsArray().astype(np.float32)
        if background is not None:
            if np.isnan(background):
                # NaN never compares equal to itself, so `==` would miss it.
                data[np.isnan(data)] = 0
            else:
                data[data == background] = 0
        bands.append(data)
    data = np.stack(bands, 0)
    if nb == 1:
        # Single-band rasters are returned as a plain 2-D array.
        data = data[0, :, :]
    return data

# Zonal statistics: for every raster in `names`, compute the statistic in
# `stats` per polygon of the shapefile and collect the results in one table.
start = time.time()
shp_path = '../data/urbandata/shp/basin.shp'
stats = ['mean'] # ['min', 'max', 'mean', 'median', 'majority']
shp_driver = gpd.read_file(shp_path)
df = shp_driver['ID'].to_frame()

names = ["Chen","He","Zhou","Hyde","LUH","Gao"]
for name in names:
    ras_path = "../data/urbandata/" + name + ".tif"
    # Only the affine transform is needed from rasterio; the context manager
    # closes the dataset instead of leaving one open per raster.
    with rasterio.open(ras_path) as ras_driver:
        affine = ras_driver.transform
    array = LoadData(ras_path)
    # The value of the top-left pixel is treated as background and zeroed.
    # NOTE(review): this assumes a single-band (2-D) array — confirm.
    array[np.where(array == array[0][0])] = 0
    zs = zonal_stats(shp_path, array, affine=affine, stats = stats)
    values = [zone[stats[0]] for zone in zs]
    df['{}'.format(name)] = values
    print(name)
df.to_excel("../urba_mean.xlsx")

推荐阅读