首页 > 解决方案 > sklearn 的 2D lat/lon KernelDensity Estimator

问题描述

使用 sklearn.neighbors 的 KernelDensity,我得到的密度值比我预期的要小得多。密度估计值大约是我预期的 1/200。

我已经查看了 sklearn 的示例 Kernel Density Estimate of Species Distribution,并按照其中的做法,将输入的纬度/经度数据转换为弧度并使用半正矢(haversine)距离度量,但得到了奇怪的结果。

我已经考虑了很多,这里是对我来说最有意义的参数。

问题

  1. 这似乎是解决这个问题的合理方法吗?
  2. 为什么密度值比我预期的要小得多?

这是我的函数和我传递给它的参数。

import rasterio
from rasterio.crs import CRS
from sklearn.neighbors import KernelDensity
import numpy as np

def kernel_density_lat_lon(positions, bandwidth, metric, kernel,
                           cell_size, extent, output_raster, multiplier=None):
    """Fit a kernel density estimate to points and write it to a GeoTIFF.

    A regular grid is built over ``extent`` with spacing ``cell_size``, the
    fitted estimator is evaluated at every grid node, and the resulting
    density surface is written as a single-band raster tagged EPSG:4326.

    Args:
        positions: Sequence of ``[x, y]`` sample coordinates, in the same
            units and axis order as ``extent``.
        bandwidth: Kernel bandwidth, in the units of the coordinates.
        metric: Distance metric name passed to ``KernelDensity``.
        kernel: Kernel name passed to ``KernelDensity``.
        cell_size: Grid spacing (and output pixel size) in coordinate units.
        extent: ``[x_min, x_max, y_min, y_max]`` bounds of the output raster.
        output_raster: Path of the GeoTIFF to create.
        multiplier: Optional factor applied to every density value before it
            is written. ``None`` (the default) leaves the densities unscaled.

    Notes:
        * ``score_samples`` returns log *probability densities*; the values
          written here are densities per unit area of the coordinate space,
          not counts, so they can be much smaller than a per-cell tally.
        * No unit conversion is done here. NOTE(review): scikit-learn's
          ``haversine`` metric is documented as expecting ``[lat, lon]`` in
          radians -- confirm inputs are prepared accordingly before using it.
        * Densities are evaluated at the lower-left corner of each output
          pixel, because ``np.arange`` starts at ``x_min`` / ``y_min``.
    """
    x_min, x_max, y_min, y_max = extent[:4]

    # Coordinates of every grid node; the half-open range excludes x_max/y_max.
    xs = np.arange(x_min, x_max, cell_size)
    ys = np.arange(y_min, y_max, cell_size)
    xx, yy = np.meshgrid(xs, ys)
    grid_points = np.column_stack((xx.ravel(), yy.ravel()))

    # Fit the estimator to the samples.
    kde = KernelDensity(bandwidth=bandwidth, metric=metric,
                        kernel=kernel, algorithm='ball_tree')
    kde.fit(positions)

    # score_samples yields log-density, so exponentiate to get the density.
    density = np.exp(kde.score_samples(grid_points))
    print(np.max(density))

    # Previously this argument was accepted but silently ignored.
    if multiplier is not None:
        density = density * multiplier

    # Rows of the mesh run south -> north; rasters are stored north -> south,
    # so flip vertically, then add the leading band axis rasterio expects.
    output_arr = density.reshape(xx.shape)[::-1, :][np.newaxis, :, :]
    _, height, width = output_arr.shape

    # Write the densities to a raster.
    with rasterio.open(
            output_raster,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            dtype=output_arr.dtype,
            crs=CRS.from_epsg(4326),
            count=1,
            transform=rasterio.transform.from_bounds(
                x_min, y_min, x_max, y_max, width, height),
    ) as dst:
        dst.write(output_arr)

if __name__ == "__main__":
    # Sample points as [x, y] pairs in degrees. The first column exceeds
    # +/-90, so it is presumably longitude and the second latitude.
    positions = [
        [126.82800884821953, 8.021550450814345],
        [123.0835913004416, 15.887493017360754],
        [122.87172138544588, 15.155979776107289],
        [122.48465193221716, 15.233649683534475],
        [122.26320643954872, 16.71625103407011],
        [122.13275884500477, 15.941644592949958],
        [120.63772441542471, 7.078277119741588],
        [120.57180822188472, 7.537689414917545],
        [119.53047809084589, 1.396741864447578],
        [119.51652407635684, 1.7028166423529711],
        [119.35538543402562, 7.795232293743844],
        [119.35371605376332, 1.7139590065581176],
        [118.21983976700818, 0.2725608428591114],
        [116.32507063966972, -2.0478066628388163],
        [115.9455871941716, -2.2758686356158915],
        [110.54879990595637, 4.849182291868757],
        [109.00373897612512, 12.330559666134512],
        [108.56317006080423, 23.10356852435795],
        [107.95374212609899, -3.878293744564539],
        [107.6618148392204, -4.215545933851648],
        [107.39598092145678, -3.3557991558597426],
        [107.38347877309276, -4.243848824653475],
        [107.3802332907293, -4.724984303635246],
        [106.92298020128571, 3.3377440975999058],
        [106.8467663232349, -3.427384435159751],
        [106.6198566766759, 3.327030211530555],
        [106.59035576911651, 3.409433089119516],
        [106.48649132403538, 3.5936300679047966],
        [106.2879431146126, 3.039670857739856],
        [105.96323043582797, 2.5103916023650656],
        [105.9540323861389, 2.596746532847891],
        [105.80111748849575, 3.388380151516756],
        [105.62119198430719, 3.2169296961449554],
        [105.43276377101233, 2.6840109661437204],
        [105.29236334314527, 2.420170430982717],
        [104.94141265184744, 3.091707354213681],
        [103.08902291491331, 3.1932135322924133],
        [102.59488296531873, 14.93503092216549],
        [100.7213889691745, 5.834246665586201],
        [100.70491932538964, 5.2594820067014245],
        [100.51665775078591, 6.0369426594605855],
        [100.51156199546038, 5.491942119998682],
        [100.45311457726862, 5.281343969279209],
        [99.984116880471, 5.658350660638604],
        [93.51170627287425, 24.024373245961645],
        [93.34991893283902, 23.04050533807432],
        [84.93884193888668, 19.384547030288207],
        [84.30999142795147, 18.825326243832105],
        [84.1630944193751, 19.06013889689632],
        [83.80094785724114, 18.57306909774846],
        [74.16321921976069, 23.579347585345776],
        [72.4113965790803, 21.875517403679595],
        [49.40472412468231, 32.2487630729451],
        [42.90510332039255, -12.821849510976579],
        [42.408207428324495, -12.31050970009727],
        [42.36825610793828, -13.083052941231413],
        [42.30285486383656, -12.234780003717532],
        [15.328057669295298, -7.460883355600632],
        [14.631592099379093, -7.440778982157976],
        [14.563929300312948, -7.140268202440664],
        [14.446656807020666, -4.699494598106393],
        [14.188788859460905, -6.430418645148537],
        [13.44490187975298, -2.8654279482460323],
        [13.301089335672936, -2.593387816196834],
        [13.131727857324034, -3.412434046655619],
        [11.637624067618695, 5.306602656962694],
        [11.537324701566494, 1.5773310360579327],
        [11.056051828014489, 5.372994263069668],
        [10.981944105212998, 6.05789466930291],
        [10.978615683124655, 5.7586879077143225],
        [10.384229532923067, 2.6509917300959476],
        [10.293978958054748, 5.6087142487617045],
        [9.724503564938162, 5.965801337392755],
        [9.228154036572047, 6.4564328707855605],
        [8.847083818460739, 4.696640992862242],
        [8.724622829999017, 5.5476494764785516],
        [8.483278678008926, 6.612624047942372],
        [8.44366045716664, 6.2122982089038725],
        [8.4255624128847, 4.755664077859387],
        [8.11860899795907, 5.659724263701104],
        [7.912362077517271, 4.87480562915889],
        [7.563449250527216, 2.842579773546474],
        [7.2608575851074, 5.16577516485171],
        [7.004069229900638, 3.5416918941072804],
        [6.9915716303567494, 3.7362296571866294],
        [6.468876406999725, 5.010859767233725],
        [6.203147917904825, 4.992482439632923],
        [5.4017709770599325, 7.676092696459705],
        [5.350100368207385, 7.762605113995827],
        [5.279221956366327, 4.915935839020336],
        [5.213104554080347, 8.281676925077297],
        [5.1108484406102805, 7.9040681892696485],
        [5.059337403465768, 8.140534352024792],
        [4.861618772269268, 8.322655646328752],
        [4.80376638793241, 8.062341031849334],
        [4.665446704573248, 7.477404025788393],
        [4.6477402888853145, 7.797020093234158],
        [4.609044098910636, 38.765860093618905],
        [4.555126307535386, 7.873929016757312],
        [4.4195324599539845, 7.394848626095032],
        [4.400283930670644, 8.038284539940614],
        [4.347819621721147, 8.443859742876246],
        [4.240704264765369, 6.955830447603886],
        [4.227870824209585, 5.751072313355475],
        [4.033821062618696, 7.0740805209122595],
        [3.665972118522844, 6.545536856751896],
        [3.4165849005141005, 7.191717476638518],
        [3.121450235674562, 8.103710628355616],
        [1.8057346437941182, 1.3314371195302515],
        [0.21998421850813876, 6.744306925430884],
        [-12.310298533627448, 11.362835062050264],
        [-49.352317054841336, 2.010101652464972],
        [-49.56587070660965, 1.366869361066606],
        [-49.5821860267535, 1.824258170311353],
        [-70.58665807820438, 20.03257364630837],
        [-70.6803277335339, 19.902301232265422],
        [-70.78620439744233, 20.024999949922996],
        [-70.86459827149523, 20.273742251629713],
        [-71.02033226779315, 19.891866165854587],
        [-73.57317798569044, 12.265930473198331],
        [-75.32300214385347, -10.734649751468147],
        [-75.36631826293349, -10.206201123969526],
        [-75.37463804230384, -10.724232696199014],
        [-75.40829227919468, -10.817431611704407],
        [-75.46984739081694, -10.195876463554633],
        [-75.56266706716431, -10.202240256127965],
        [-75.74233061116121, -10.647556252995775],
        [-75.90503122834087, -10.297561312609464],
        [-75.94114020328095, -10.530481915516726],
        [-78.13302896559648, -1.2629721839381856],
        [-78.42506520505198, -0.6805387090496724],
        [-78.68351568134375, -1.1006283268898114],
        [-79.09221180056895, -1.5423219306900116],
        [-90.05839881111541, 21.022199691388156],
        [-91.3208074507767, 20.58263399988673],
        [-91.86906142999138, 20.169783366358622],
        [-91.89838954465436, 20.49386425203851],
    ]

    # Everything below stays inside the main guard so that importing this
    # module never triggers a KDE run (or a NameError on `positions`).

    # The parameters that I think should do the trick
    bandwidth = 1.0
    cell_size = 0.1
    extent = [-180, 180, -90, 90]
    metric = "euclidean"
    kernel = "gaussian"
    # Only the integer part of the bandwidth goes into the file name.
    output_raster = f"{metric}_{kernel}_{str(bandwidth).split('.')[0]}.tif"
    kernel_density_lat_lon(positions, bandwidth, metric, kernel, cell_size,
                           extent, output_raster)

    # The parameters that get me closest to the desired output
    # This requires multiplying all of the density probabilities by 205...
    bandwidth = 1.0
    cell_size = 0.1
    extent = [-180, 180, -90, 90]
    metric = "euclidean"
    kernel = "epanechnikov"
    multiplier = 205
    output_raster = f"{metric}_{kernel}_{str(bandwidth).split('.')[0]}.tif"
    kernel_density_lat_lon(positions, bandwidth, metric, kernel, cell_size,
                           extent, output_raster, multiplier)

我已经考虑了很多,我很难理解为什么密度估计值低于我的预期。谢谢你的帮助。

标签: python, scikit-learn, gis, kernel-density

解决方案


推荐阅读