from osgeo import gdal, osr
import math
from shapely.geometry import Polygon
from pyproj import Proj, transform
import warnings
warnings.filterwarnings(action='ignore')
from rasterio.transform import Affine
import rasterio
from shapely.geometry import Polygon
from typing import Tuple
def get_boundary(ds: gdal.Dataset) -> Polygon:
    """Return the georeferenced footprint of a raster as a polygon.

    All four pixel-space corners are mapped through the dataset's
    geotransform, so the footprint is also correct when the geotransform
    contains rotation/shear terms. For a north-up raster the result is the
    same axis-aligned rectangle as before.

    Args:
        ds: An opened GDAL dataset.

    Returns:
        Polygon whose vertices are the georeferenced positions of the pixel
        corners (0, 0), (W, 0), (W, H) and (0, H), in that order, where W
        and H are the raster width and height in pixels.
    """
    # Fetch the geotransform once instead of once per corner.
    geo_transform = ds.GetGeoTransform()
    width, height = ds.RasterXSize, ds.RasterYSize
    pixel_corners = [(0, 0), (width, 0), (width, height), (0, height)]
    return Polygon(
        [
            tuple(gdal.ApplyGeoTransform(geo_transform, float(col), float(row)))
            for col, row in pixel_corners
        ]
    )
def get_intersection(poly1: Polygon, poly2: Polygon) -> Polygon:
    """Return the geometry shared by ``poly1`` and ``poly2``.

    Thin wrapper around ``poly1.intersection(poly2)``; the result is
    returned unchanged.
    """
    return poly1.intersection(poly2)
def get_min_max_lat_lng(poly: Polygon) -> Tuple[float, float, float, float]:
    """Return the bounding-box extremes of a geometry.

    Uses the geometry's ``bounds`` rather than its exterior ring, so it also
    works when the input is not a single polygon (for example a
    multi-part intersection result).

    Args:
        poly: Geometry to measure.

    Returns:
        ``(max_x, max_y, min_x, min_y)`` in the geometry's coordinate units.

    Raises:
        ValueError: If ``poly`` is empty and therefore has no extent.
    """
    if poly.is_empty:
        raise ValueError('cannot compute the extent of an empty geometry')
    # shapely reports bounds as (minx, miny, maxx, maxy); callers expect the
    # maxima first.
    min_x, min_y, max_x, max_y = poly.bounds
    return max_x, max_y, min_x, min_y
def get_crs(ds: gdal.Dataset) -> str:
    """Return the CRS of a dataset as an ``'AUTHORITY:CODE'`` string.

    Args:
        ds: An opened GDAL dataset.

    Returns:
        Identifier such as ``'EPSG:32643'`` built from the authority name
        and code attached to the dataset's spatial reference.

    Raises:
        ValueError: If the spatial reference carries no authority code, in
            which case no identifier can be built (previously this produced
            the unusable string ``'EPSG:None'``).
    """
    srs = osr.SpatialReference(wkt=ds.GetProjection())
    # Passing None asks for the authority of the root node, i.e. of the CRS
    # itself rather than of one of its components.
    code = srs.GetAuthorityCode(None)
    if code is None:
        raise ValueError(
            'dataset projection has no authority code; cannot derive a CRS '
            'identifier'
        )
    # Fall back to EPSG when only the code is present, matching the
    # previously hard-coded prefix.
    authority = srs.GetAuthorityName(None) or 'EPSG'
    return f'{authority}:{code}'
def change_polygon(poly_intersect, src_crs, dst_crs):
    """Reproject the exterior ring of a polygon from one CRS to another.

    Only the exterior ring is transformed; interior rings (holes), if any,
    are not carried over to the result.

    Args:
        poly_intersect: Polygon whose exterior ring is to be reprojected.
        src_crs: CRS identifier of the input coordinates, e.g. ``'EPSG:4326'``.
        dst_crs: CRS identifier of the output coordinates.

    Returns:
        A new ``Polygon`` built from the reprojected exterior vertices.
    """
    # Local import: Transformer is not among the names imported at the top
    # of this file.
    from pyproj import Transformer

    # Transformer replaces the deprecated Proj(init=...) / transform() pair.
    # always_xy=True keeps the traditional (x, y) i.e. (lon, lat) axis order
    # that the init= syntax used.
    transformer = Transformer.from_crs(src_crs, dst_crs, always_xy=True)

    xs, ys = (list(axis) for axis in poly_intersect.exterior.coords.xy)
    # Transform the whole ring in one call instead of vertex by vertex.
    new_xs, new_ys = transformer.transform(xs, ys)
    return Polygon(list(zip(new_xs, new_ys)))
def get_geo_transform(ds, x, y):
    """Map a georeferenced coordinate to pixel space of ``ds``.

    The dataset's geotransform is inverted and applied to ``(x, y)``.

    Args:
        ds: An opened GDAL dataset.
        x: X coordinate in the dataset's georeferenced units.
        y: Y coordinate in the dataset's georeferenced units.

    Returns:
        The result of ``gdal.ApplyGeoTransform`` with the inverted
        geotransform, i.e. the (column, row) position, which may be
        fractional.
    """
    forward = ds.GetGeoTransform()
    inverse = gdal.InvGeoTransform(forward)
    return gdal.ApplyGeoTransform(inverse, x, y)
def crop_scene(ds, dst, crs, p0, p1):
    """Write the pixel window spanned by ``p0`` and ``p1`` to a new GeoTIFF.

    The window is snapped outwards to whole pixels (floor of ``p0``, ceiling
    of ``p1``) so that it fully covers the requested area, and is clamped to
    the raster extent so the read cannot run past the edge of ``ds``.

    Args:
        ds: Opened GDAL dataset to read from.
        dst: Path of the GeoTIFF to create.
        crs: CRS handed to rasterio for the output file.
        p0: ``(col, row)`` of the window corner with the smaller pixel
            coordinates; may be fractional.
        p1: ``(col, row)`` of the opposite window corner; may be fractional.

    Raises:
        ValueError: If the window is empty after clamping to the raster.
        RuntimeError: If GDAL fails to read the window.
    """
    geo_transform = ds.GetGeoTransform()

    x_off = max(math.floor(p0[0]), 0)
    y_off = max(math.floor(p0[1]), 0)
    x_end = min(math.ceil(p1[0]), ds.RasterXSize)
    y_end = min(math.ceil(p1[1]), ds.RasterYSize)
    width, height = x_end - x_off, y_end - y_off
    if width <= 0 or height <= 0:
        raise ValueError(
            f'crop window is empty: offset=({x_off}, {y_off}), '
            f'size=({width}, {height})'
        )

    data = ds.ReadAsArray(x_off, y_off, width, height)
    if data is None:
        raise RuntimeError('GDAL failed to read the requested window')
    if data.ndim == 2:
        # A single-band raster comes back as (rows, cols); add the band axis
        # so the code below can treat every raster as (bands, rows, cols).
        data = data[None, ...]

    # Georeferenced position of the window's first pixel corner becomes the
    # origin of the output transform; pixel size and rotation are inherited.
    origin_x, origin_y = gdal.ApplyGeoTransform(
        geo_transform, float(x_off), float(y_off)
    )
    window_transform = Affine(
        geo_transform[1], geo_transform[2], origin_x,
        geo_transform[4], geo_transform[5], origin_y,
    )

    band_count, rows, cols = data.shape
    with rasterio.open(
        dst,
        'w',
        driver='GTiff',
        height=rows,
        width=cols,
        count=band_count,
        dtype=data.dtype,
        crs=crs,
        transform=window_transform,
    ) as out:
        # rasterio band indexes are 1-based.
        for band_index, band in enumerate(data, start=1):
            out.write(band, band_index)
def main():
    """Crop the ``src`` raster to the area it shares with ``src_intersect``.

    Reads the module-level names ``src``, ``src_intersect`` and ``dst``
    (set in the ``__main__`` section of this file): the raster to crop, the
    raster whose footprint limits the crop, and the output path.

    Raises:
        RuntimeError: If GDAL cannot open one of the input rasters.
        ValueError: If the two raster footprints do not overlap.
    """
    ds = gdal.Open(src)
    # gdal.Open returns None instead of raising when GDAL exceptions are not
    # enabled; fail here with a clear message rather than later with an
    # AttributeError.
    if ds is None:
        raise RuntimeError(f'GDAL could not open {src}')
    ds_intersect = gdal.Open(src_intersect)
    if ds_intersect is None:
        raise RuntimeError(f'GDAL could not open {src_intersect}')

    crs = get_crs(ds)
    footprint = get_boundary(ds)
    # Bring the second footprint into the CRS of the raster being cropped so
    # both polygons are comparable.
    other_footprint = change_polygon(
        get_boundary(ds_intersect), get_crs(ds_intersect), crs
    )

    overlap = get_intersection(footprint, other_footprint)
    if overlap.is_empty:
        raise ValueError('the two rasters do not overlap; nothing to crop')

    max_x, max_y, min_x, min_y = get_min_max_lat_lng(overlap)
    # (min_x, max_y) and (max_x, min_y) are opposite corners of the overlap's
    # bounding box; presumably upper-left / lower-right for a north-up raster.
    p0 = get_geo_transform(ds, min_x, max_y)
    p1 = get_geo_transform(ds, max_x, min_y)
    crop_scene(ds, dst, crs, p0, p1)
if __name__ == '__main__':
    # src           : image to be cropped
    # src_intersect : image used to compute the intersection
    # dst           : output path, derived from src
    src = '/nas/Dataset/IndusRiver/scenes/WV3_20200524_062425_RGB_PS.TIF'
    src_intersect = '/nas/Dataset/IndusRiver/scenes/WV2_20220905_063224_RGB_PS.TIF'
    dst = src.replace('RGB_PS', 'RGB_PS_cropped2')
    main()