I wrote a function that crops a GeoTIFF image to a bounding box.
The first image is the original; the second shows the result of my code. I don't use gdalwarp or any other console utilities. However, I have no idea how to cut the image by a GeoJSON file. Also, I can use only the GDAL and numpy modules.
Here is my code:
import os, sys
from osgeo import gdal,gdalconst,osr
def cut_by_bounding_box(min_x, max_x, min_y, max_y):
    """Crop a GeoTIFF to a bounding box given in the raster's own coordinates.

    The input path is read from the module-level name ``filename`` and the
    result is written to the module-level name ``output_file``; both must be
    defined before this function is called.

    The band count and pixel data type of the output are taken from the
    input raster. Each coordinate pair may be passed in either order, so
    both ``(bottom, top)`` and ``(top, bottom)`` work for the y values.
    The requested window is clamped to the raster extent.

    The pixel-offset math uses only the origin and pixel size of the
    geotransform, i.e. it assumes a north-up raster without rotation.

    Args:
        min_x: One x edge of the box (map units of the raster).
        max_x: The opposite x edge of the box.
        min_y: One y edge of the box.
        max_y: The opposite y edge of the box.

    Raises:
        SystemExit: If the input raster cannot be opened (exit status 1).
        ValueError: If the box does not overlap the raster at all.
    """
    # Normalise the box so the math below does not depend on argument order.
    left, right = sorted((min_x, max_x))
    bottom, top = sorted((min_y, max_y))

    dataset = gdal.Open(filename)
    if dataset is None:
        print('Could not open ' + filename)
        sys.exit(1)

    # Image dimensions
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize
    bands = dataset.RasterCount

    # Georeference info
    transform = dataset.GetGeoTransform()
    x_origin = transform[0]
    y_origin = transform[3]
    pixel_width = transform[1]
    pixel_height = -transform[5]  # stored negative for north-up rasters

    # Pixel indices of the upper-left (i1, j1) and lower-right (i2, j2)
    # corners. Floor division keeps coordinates left of / above the origin
    # negative instead of truncating them towards pixel 0.
    i1 = int((left - x_origin) // pixel_width)
    j1 = int((y_origin - top) // pixel_height)
    i2 = int((right - x_origin) // pixel_width)
    j2 = int((y_origin - bottom) // pixel_height)

    # Clamp the window to the raster so ReadAsArray never reads outside it.
    i1, j1 = max(i1, 0), max(j1, 0)
    i2, j2 = min(i2, cols - 1), min(j2, rows - 1)
    if i2 < i1 or j2 < j1:
        raise ValueError('Bounding box does not overlap ' + filename)
    new_cols = i2 - i1 + 1
    new_rows = j2 - j1 + 1

    # Geotransform of the cropped raster: same pixel size, shifted origin.
    new_x = x_origin + i1 * pixel_width
    new_y = y_origin - j1 * pixel_height
    new_transform = (new_x, transform[1], transform[2],
                     new_y, transform[4], transform[5])

    # Create the output with the same band count and data type as the input
    # (a fixed "3 bands of bytes" layout breaks for any other imagery).
    driver = gdal.GetDriverByName('GTiff')
    data_type = dataset.GetRasterBand(1).DataType
    dst_ds = driver.Create(output_file, new_cols, new_rows, bands, data_type)

    # Copy the window band by band (GDAL band indices are 1-based).
    for index in range(1, bands + 1):
        data = dataset.GetRasterBand(index).ReadAsArray(
            i1, j1, new_cols, new_rows)
        dst_ds.GetRasterBand(index).WriteArray(data)

    # Georeference the output exactly like the input.
    dst_ds.SetGeoTransform(new_transform)
    dst_ds.SetProjection(dataset.GetProjection())

    # Dropping the references closes the datasets and flushes the output.
    dataset = None
    dst_ds = None
if __name__ == '__main__':
    # Input/output file names and working directory. `filename` and
    # `output_file` are module-level names read by cut_by_bounding_box.
    os.chdir('/home/sant/test/satellite_images')
    filename = '20160501.tif'
    output_file = '/home/sant/test/20160501_cutted_by_bounding_box.tif'
    # NOTE: the third argument is the northern (larger) y value and the
    # fourth is the southern one.
    cut_by_bounding_box(531961.73, 535987.34, 4894164.57, 4888631.61)
    # Parenthesised form prints the same text on Python 2 and Python 3.
    print('cutter.py script done!')
Answer
Here is my own solution. It works for any number of bands and any geometry type (e.g. MultiPolygon), and it works with images in any zone!
import geojson as gj
import numpy as np
from osgeo import ogr, osr, gdal
# Ask the GDAL bindings to raise Python exceptions on errors instead of
# relying on return values.
gdal.UseExceptions()
# In-memory drivers used by cut_by_geojson: the raster one holds the
# rasterized mask, the vector one holds the layer with the clipping geometry.
GDAL_MEMORY_DRIVER = gdal.GetDriverByName('MEM')
OGR_MEMORY_DRIVER = ogr.GetDriverByName('Memory')
def cut_by_geojson(input_file, output_file, shape_geojson):
    """Clip a raster to a GeoJSON geometry and write the result as a GeoTIFF.

    The geometry is reprojected from WGS84 (EPSG:4326, lon/lat) into the
    raster's spatial reference and rasterized into a mask. The raster window
    covering the geometry is then copied band by band, with every pixel
    outside the geometry set to 0, which is also declared as the NoData
    value. The output keeps the band count and pixel data type of the input.

    The pixel-offset math uses only the origin and pixel size of the
    geotransform, i.e. it assumes a north-up raster without rotation.

    Args:
        input_file: Path of the raster to clip.
        output_file: Path of the GeoTIFF to create.
        shape_geojson: GeoJSON geometry (e.g. Polygon or MultiPolygon) as a
            string, with coordinates in WGS84 lon/lat order.

    Raises:
        IOError: If the input raster cannot be opened.
        ValueError: If the GeoJSON cannot be parsed into a geometry, or the
            geometry does not overlap the raster.
    """
    # Open original data as read only
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    if dataset is None:
        raise IOError('Could not open ' + input_file)
    bands = dataset.RasterCount
    cols = dataset.RasterXSize
    rows = dataset.RasterYSize

    # Georeference info
    transform = dataset.GetGeoTransform()
    projection = dataset.GetProjection()
    x_origin = transform[0]
    y_origin = transform[3]
    pixel_width = transform[1]
    pixel_height = -transform[5]  # stored negative for north-up rasters

    # Spatial references of the raster and of the GeoJSON input (WGS84)
    srs = osr.SpatialReference()
    srs.ImportFromWkt(projection)
    wgs84 = osr.SpatialReference()
    wgs84.ImportFromEPSG(4326)
    # GDAL >= 3 follows the authority axis order (lat, lon for EPSG:4326),
    # while GeoJSON coordinates are (lon, lat). Force the traditional x/y
    # order where the binding offers it; older GDAL already behaves that way.
    if hasattr(osr, 'OAMS_TRADITIONAL_GIS_ORDER'):
        srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wgs84_to_image = osr.CoordinateTransformation(wgs84, srs)

    # Reproject the geometry itself and take its envelope afterwards:
    # reprojecting only two corners of the lon/lat bounding box can cut off
    # parts of the geometry in a projected coordinate system.
    geom = ogr.CreateGeometryFromJson(str(shape_geojson))
    if geom is None:
        raise ValueError('Could not parse a geometry from shape_geojson')
    geom.Transform(wgs84_to_image)
    min_x, max_x, min_y, max_y = geom.GetEnvelope()

    # Pixel indices of the upper-left (i1, j1) and lower-right (i2, j2)
    # corners. Floor division keeps coordinates left of / above the origin
    # negative instead of truncating them towards pixel 0.
    i1 = int((min_x - x_origin) // pixel_width)
    j1 = int((y_origin - max_y) // pixel_height)
    i2 = int((max_x - x_origin) // pixel_width)
    j2 = int((y_origin - min_y) // pixel_height)

    # Clamp the window to the raster so ReadAsArray never reads outside it.
    i1, j1 = max(i1, 0), max(j1, 0)
    i2, j2 = min(i2, cols - 1), min(j2, rows - 1)
    if i2 < i1 or j2 < j1:
        raise ValueError('Geometry does not overlap ' + input_file)
    new_cols = i2 - i1 + 1
    new_rows = j2 - j1 + 1

    # Geotransform of the clipped raster: same pixel size, shifted origin.
    new_x = x_origin + i1 * pixel_width
    new_y = y_origin - j1 * pixel_height
    new_transform = (new_x, transform[1], transform[2],
                     new_y, transform[4], transform[5])

    # Single-band in-memory raster that receives the rasterized mask.
    target_ds = GDAL_MEMORY_DRIVER.Create('', new_cols, new_rows, 1,
                                          gdal.GDT_Byte)
    target_ds.SetGeoTransform(new_transform)
    target_ds.SetProjection(projection)

    # In-memory vector layer holding the geometry to rasterize from.
    ogr_dataset = OGR_MEMORY_DRIVER.CreateDataSource('shapemask')
    ogr_layer = ogr_dataset.CreateLayer('shapemask', srs=srs)
    ogr_feature = ogr.Feature(ogr_layer.GetLayerDefn())
    ogr_feature.SetGeometry(geom)  # the feature stores its own copy
    ogr_layer.CreateFeature(ogr_feature)

    # Burn 1 into every pixel the geometry touches; the rest stays 0.
    gdal.RasterizeLayer(target_ds, [1], ogr_layer, burn_values=[1],
                        options=["ALL_TOUCHED=TRUE"])
    inside = target_ds.GetRasterBand(1).ReadAsArray() == 1
    outside = np.logical_not(inside)

    # Create the output with the input's data type so no precision is lost.
    driver = gdal.GetDriverByName('GTiff')
    data_type = dataset.GetRasterBand(1).DataType
    outds = driver.Create(output_file, new_cols, new_rows, bands, data_type)

    # Copy the window band by band (GDAL band indices are 1-based), zeroing
    # everything outside the geometry.
    for index in range(1, bands + 1):
        data = dataset.GetRasterBand(index).ReadAsArray(
            i1, j1, new_cols, new_rows)
        data[outside] = 0
        out_band = outds.GetRasterBand(index)
        out_band.SetNoDataValue(0)
        out_band.WriteArray(data)

    outds.SetProjection(projection)
    outds.SetGeoTransform(new_transform)

    # Dropping the references closes the datasets and flushes the output.
    target_ds = None
    dataset = None
    outds = None
    ogr_dataset = None
No comments:
Post a Comment