from osgeo import gdal, osr
import math
from shapely.geometry import Polygon
from pyproj import Proj, transform
import warnings
warnings.filterwarnings(action='ignore')
from rasterio.transform import Affine
import rasterio
from shapely.geometry import Polygon
from typing import Tuple

def get_boundary(ds: gdal.Dataset) -> Polygon:
    """Return the georeferenced footprint of a raster as a polygon.

    All four pixel-space corners of the raster are pushed through the
    dataset's geotransform, so the footprint stays correct when the
    geotransform carries rotation/shear terms.  When those terms are zero
    the result is the same axis-aligned rectangle, with the same vertex
    order, that is obtained from the two opposite corners alone.

    Args:
        ds: An open GDAL dataset.

    Returns:
        A shapely ``Polygon`` whose vertices are the map coordinates of the
        pixel corners ``(0, 0)``, ``(width, 0)``, ``(width, height)`` and
        ``(0, height)``, in that order.
    """
    # Fetch the geotransform once instead of once per corner.
    geo_transform = ds.GetGeoTransform()
    width, height = ds.RasterXSize, ds.RasterYSize
    pixel_corners = [(0, 0), (width, 0), (width, height), (0, height)]
    return Polygon(
        [
            tuple(gdal.ApplyGeoTransform(geo_transform, float(px), float(py)))
            for px, py in pixel_corners
        ]
    )

def get_intersection(poly1: Polygon, poly2: Polygon) -> Polygon:
    """Return the geometry shared by ``poly1`` and ``poly2``.

    This simply delegates to ``poly1.intersection(poly2)`` and hands the
    result back unchanged.
    """
    return poly1.intersection(poly2)

def get_min_max_lat_lng(poly: Polygon) -> Tuple[float, float, float, float]:
    """Return the extreme x/y coordinates of a geometry.

    The extent is read from the geometry's ``bounds`` rather than from its
    exterior ring, so it also works for intersection results that are not a
    plain ``Polygon`` (which have no ``exterior`` attribute).

    Args:
        poly: A non-empty shapely geometry.

    Returns:
        ``(max_x, max_y, min_x, min_y)`` as floats, in the units of the
        geometry's coordinates.

    Raises:
        ValueError: If ``poly`` is empty and therefore has no extent.
    """
    if poly.is_empty:
        raise ValueError('cannot compute the extent of an empty geometry')
    # shapely reports bounds as (minx, miny, maxx, maxy); callers of this
    # function expect the maxima first.
    min_x, min_y, max_x, max_y = poly.bounds
    return max_x, max_y, min_x, min_y

def get_crs(ds: gdal.Dataset) -> str:
    """Return the CRS of a raster as an ``'EPSG:<code>'`` string.

    The code is read from the ``AUTHORITY`` node of the dataset's
    projection WKT.

    NOTE(review): the authority is assumed to be EPSG; the authority name
    itself is not inspected -- confirm this holds for all input scenes.

    Args:
        ds: An open GDAL dataset.

    Returns:
        A string of the form ``'EPSG:<code>'``.

    Raises:
        ValueError: If the projection carries no authority code (for example
            when the dataset has no projection at all).  Without this check
            the function would return the unusable string ``'EPSG:None'``.
    """
    srs = osr.SpatialReference(wkt=ds.GetProjection())
    code = srs.GetAttrValue('AUTHORITY', 1)
    if code is None:
        raise ValueError('dataset projection has no authority code')
    return 'EPSG:' + str(code)
    
    
def change_polygon(poly_intersect, src_crs, dst_crs):
    """Reproject the exterior ring of a polygon from one CRS to another.

    Uses ``pyproj.Transformer`` in place of the deprecated
    ``Proj(init=...)`` / ``pyproj.transform`` pair, and builds the
    transformer once instead of once per vertex.

    Args:
        poly_intersect: Polygon whose exterior ring is reprojected.  Only the
            x and y ordinates of the exterior are used.
        src_crs: CRS of ``poly_intersect``, e.g. ``'EPSG:32643'``.
        dst_crs: CRS to reproject into, in the same notation.

    Returns:
        A new ``Polygon`` built from the reprojected exterior vertices.
    """
    # Local import: Transformer is not part of the file-level import list.
    from pyproj import Transformer

    # always_xy=True fixes the axis order to (x, y) == (easting/longitude,
    # northing/latitude) for both CRSs, whatever order the CRS definitions
    # themselves declare.
    transformer = Transformer.from_crs(src_crs, dst_crs, always_xy=True)
    exterior_xy = zip(*poly_intersect.exterior.coords.xy)
    return Polygon([transformer.transform(x, y) for x, y in exterior_xy])
    
    
def get_geo_transform(ds, x, y):
    """Convert the coordinate ``(x, y)`` with the inverse of ``ds``'s geotransform.

    The dataset's geotransform is inverted and the inverse is applied to
    ``(x, y)``; the result of ``gdal.ApplyGeoTransform`` is returned as is.
    """
    forward = ds.GetGeoTransform()
    return gdal.ApplyGeoTransform(gdal.InvGeoTransform(forward), x, y)
    

def crop_scene(ds, dst, crs, p0, p1):
    """Write the pixel window spanned by ``p0`` and ``p1`` to a GeoTIFF.

    The window is snapped outwards to whole pixels (floor of the start
    corner, ceil of the end corner) so the requested extent is fully
    covered, and is clamped to the raster so that rounding noise in the
    corner coordinates cannot push the read outside the dataset.

    Args:
        ds: Open GDAL dataset to read from.
        dst: Output path handed to ``rasterio.open``.
        crs: CRS written to the output file.
        p0: ``(column, row)`` of the window's start corner in pixel
            coordinates of ``ds``; may be fractional.
        p1: ``(column, row)`` of the window's end corner, same convention.

    Raises:
        ValueError: If the window is empty after clamping to the raster.
        RuntimeError: If GDAL returns no data for the window.
    """
    geo_transform = ds.GetGeoTransform()

    col_start = max(math.floor(p0[0]), 0)
    row_start = max(math.floor(p0[1]), 0)
    col_stop = min(math.ceil(p1[0]), ds.RasterXSize)
    row_stop = min(math.ceil(p1[1]), ds.RasterYSize)
    n_cols = col_stop - col_start
    n_rows = row_stop - row_start
    if n_cols <= 0 or n_rows <= 0:
        raise ValueError('crop window does not cover any pixel of the raster')

    Z = ds.ReadAsArray(col_start, row_start, n_cols, n_rows)
    if Z is None:
        raise RuntimeError('GDAL could not read the requested window')
    if Z.ndim == 2:
        # A single-band read comes back as (rows, cols); add the band axis so
        # the shape handling below is the same for every band count.
        Z = Z[None, :, :]

    # Map coordinate of the window's first pixel corner becomes the origin of
    # the output transform; pixel size and rotation terms are carried over.
    origin_x, origin_y = gdal.ApplyGeoTransform(
        geo_transform, float(col_start), float(row_start)
    )
    window_transform = Affine(
        geo_transform[1], geo_transform[2], origin_x,
        geo_transform[4], geo_transform[5], origin_y,
    )

    with rasterio.open(
        dst,
        'w',
        driver='GTiff',
        height=Z.shape[1],
        width=Z.shape[2],
        count=Z.shape[0],
        dtype=Z.dtype,
        crs=crs,
        transform=window_transform,
    ) as d:
        # rasterio band indexes are 1-based.
        for band_index, band in enumerate(Z, start=1):
            d.write(band, band_index)

def main():
    """Crop ``src`` to the area it shares with ``src_intersect``.

    Reads the module-level names ``src``, ``src_intersect`` and ``dst``.
    The footprint of ``src_intersect`` is reprojected into the CRS of
    ``src``, intersected with the footprint of ``src``, and the bounding box
    of that intersection is cut out of ``src`` and written to ``dst``.

    Raises:
        RuntimeError: If GDAL cannot open one of the input files.
        ValueError: If the two footprints share no area.
    """
    ds = gdal.Open(src)
    if ds is None:
        # gdal.Open signals failure by returning None unless GDAL exceptions
        # have been enabled.
        raise RuntimeError('GDAL could not open {!r}'.format(src))
    ds_intersect = gdal.Open(src_intersect)
    if ds_intersect is None:
        raise RuntimeError('GDAL could not open {!r}'.format(src_intersect))

    poly = get_boundary(ds)
    poly_intersect = get_boundary(ds_intersect)
    crs = get_crs(ds)
    crs_intersect = get_crs(ds_intersect)

    # Bring the second footprint into the CRS of the scene being cropped.
    poly_intersect_new = change_polygon(poly_intersect, crs_intersect, crs)
    overlap = get_intersection(poly, poly_intersect_new)
    if overlap.is_empty or overlap.area == 0:
        # Disjoint footprints, or footprints that only touch along an edge or
        # at a point, leave nothing to crop.
        raise ValueError('the two scenes do not overlap')

    max_x, max_y, min_x, min_y = get_min_max_lat_lng(overlap)
    # (min_x, max_y) and (max_x, min_y) are opposite corners of the overlap's
    # bounding box; for a north-up raster they are its upper-left and
    # lower-right corners.
    p0 = get_geo_transform(ds, min_x, max_y)
    p1 = get_geo_transform(ds, max_x, min_y)
    crop_scene(ds, dst, crs, p0, p1)
    

if __name__ == '__main__':
    # src           : scene that will be cropped
    # src_intersect : scene whose footprint is intersected with src
    # dst           : output path, derived from src by renaming its
    #                 'RGB_PS' part
    src = '/nas/Dataset/IndusRiver/scenes/WV3_20200524_062425_RGB_PS.TIF'
    src_intersect = '/nas/Dataset/IndusRiver/scenes/WV2_20220905_063224_RGB_PS.TIF'
    dst = src.replace('RGB_PS', 'RGB_PS_cropped2')

    main()