import xarray as xr
import os
import time
import numpy as np
import hydrolib as hlib

##==========================================================
## Data Processing
##==========================================================
def get_latlon2d(ds, lat, lon):
    """Find the grid cell nearest to a (lat, lon) point.

    Distance is the plain Euclidean distance, in degrees, between the target
    point and every entry of the 2-D ``ds.lat`` / ``ds.lon`` arrays (no
    great-circle or cos(latitude) correction is applied).

    Returns
    -------
    tuple
        ``(i, j)`` index of the nearest cell, in the axis order of ``ds.lat``.
        Ties resolve to the first such cell in row-major order (``np.argmin``).
    """
    grid_lat = ds.lat.values
    grid_lon = ds.lon.values

    d_lat = grid_lat - lat
    d_lon = grid_lon - lon
    distance = np.sqrt(d_lat**2 + d_lon**2)

    # argmin over the flattened array, then map back to 2-D indices.
    flat_idx = np.argmin(distance, axis=None)
    row, col = np.unravel_index(flat_idx, distance.shape)
    return row, col

def get_data(gcm, styr, edyr, var, metric=None):
    """Load one WRF-daily variable for a GCM, cropped in time and space.

    Opens ``.../WRF-daily/<gcm>/<gcm>_<var>.zarr``, keeps ``<styr>-01-01``
    through ``<edyr>-12-31`` and crops the grid to the index box spanned by
    the cells nearest to (45.347, -124.780) and (49.320, -116.586).

    Parameters
    ----------
    gcm : str
        GCM run name; used as both the directory and the file prefix.
    styr, edyr : int
        First and last calendar year to keep (both inclusive).
    var : str
        Name of the variable inside the zarr store.
    metric : str, optional
        Name given to the returned array.  Defaults to ``var`` (no rename),
        so callers that load several raw inputs can pass just ``var``.

    Returns
    -------
    xarray.DataArray
        The cropped variable, named ``metric`` (or ``var``).
    """
    print("Getting data: {}[{}] - {}:{}".format(gcm, var, styr, edyr))
    path = "/home/disk/becassine/jswon11/DATA/WRF-daily/{}/{}_{}.zarr".format(gcm, gcm, var)

    # Crop time.
    ds = xr.open_mfdataset(path, engine='zarr', parallel=True)
    ds = ds.sel(time=slice('{}-01-01'.format(styr), '{}-12-31'.format(edyr)))

    # Crop location.  get_latlon2d returns indices in the order of ds.lat's
    # dimensions, so pair each index with that dimension's name instead of
    # assuming the first index belongs to `x`.  min/max keeps the slice
    # non-empty when an axis runs opposite to the corner order.
    corner_a = get_latlon2d(ds, 45.347, -124.780)
    corner_b = get_latlon2d(ds, 49.320, -116.586)
    box = {
        dim: slice(min(a, b), max(a, b))
        for dim, a, b in zip(ds.lat.dims, corner_a, corner_b)
    }
    # NOTE(review): `sel` is label-based; this assumes the grid dimensions
    # carry no coordinate labels (xarray then selects by position) or labels
    # equal to positions.  If they hold projected coordinates, `isel` is
    # what is wanted -- confirm against the zarr stores.
    ds = ds.sel(box)

    if metric is None:
        metric = var
    elif metric != var:
        ds = ds.rename({var: metric})
    return ds[metric]

# Daily inputs used by the metrics below:
#   T2   - daily min, max, avg
#   PREC - daily sum

##--------------------------------------------------------

# GCM runs to process.  Each name is used by get_data() as both the directory
# and the file prefix of the zarr store it reads.
gcms = [
    'access1.0_RCP85',
    'access1.3_RCP85',
    'bcc-csm1.1_RCP85',
    'canesm2_RCP85',
    'ccsm4_RCP85',
    'csiro-mk3.6.0_RCP85',
    'fgoals-g2_RCP85',
    'gfdl-cm3_RCP85',
    'giss-e2-h_RCP85',
    'miroc5_RCP85',
    'mri-cgcm3_RCP85',
    'noresm1-m_RCP85',
]

# Metric definitions.  Each entry is (var, func, ctype, metric, other):
#   var    - input variable name, or a tuple of names for multi-variable metrics
#   func   - metric function, passed to hlib.form_function_one / _two
#   ctype  - combining function passed alongside it (e.g. hlib.calc_diff)
#   metric - output name: used to rename the variable and in output file names
#   other  - extra keyword arguments forwarded with func, or None
# Commented-out entries are ready-to-enable alternatives.
func1 = [
    ('PREC', hlib.get_late_summer_precip, hlib.calc_diff, 'LateSummerPrecip', None),
    #('TMAX', hlib.max_temp, hlib.calc_diff, 'LateSummerMaximumTemperature', None),
    # 1/2/3 * 25.4: inches converted to mm (assumes PREC is in mm -- confirm).
    #('PREC', hlib.precip_day, hlib.calc_diff, '1InchPrecipitationDays', {'threshold':1*25.4}),
    #('PREC', hlib.precip_day, hlib.calc_diff, '2InchPrecipitationDays', {'threshold':2*25.4}),
    #('PREC', hlib.precip_day, hlib.calc_diff, '3InchPrecipitationDays', {'threshold':3*25.4}),
    # 18.3333 and 37.7778 are 65 F and 100 F in deg C (assumes deg C inputs -- confirm).
    #('TAVG', hlib.heating_days, hlib.calc_diff, 'HeatingDegreeDays', {'threshold':18.3333}),
    #('TAVG', hlib.cooling_days, hlib.calc_diff, 'CoolingDegreeDays', {'threshold':18.3333}),
    #('TMAX', hlib.hot_days, hlib.calc_diff, 'HotDays', {'threshold':37.7778}),
    # NOTE(review): assumes max_humidex lives in hydrolib and takes 'threshold'
    # like the other metrics (the original comment spelled it 'treshold') -- confirm.
    #(('TAVG', 'RH'), hlib.max_humidex, hlib.calc_diff, 'maxhumidex', {'threshold':90}),
]

# NOTE(review): func2 is not referenced in this script and currently repeats
# func1's active entry.
func2 = [
    ('PREC', hlib.get_late_summer_precip, hlib.calc_diff, 'LateSummerPrecip', None),
]

# Historical reference periods as (start_year, end_year); get_data() keeps
# Jan 1 of start_year through Dec 31 of end_year.
his_times = [
    #(1981, 2010),
    #(1981, 2006),
    (1970, 1999),
]

# Future periods, same (start_year, end_year) convention.
fut_times = [
    #(2020, 2049),
    #(2030, 2059),
    #(2040, 2069),
    #(2050, 2079),
    #(2060, 2089),
    (2070, 2099),
    #(1981, 2006),
    #(1970, 1999),
]


# Output root: change files go here, per-period results under his/ and fut/.
odir = "/home/disk/rocinante/DATA/temp/clearinghouse/data_out/wrf_yrs/"
os.makedirs(odir, exist_ok=True)
os.makedirs(odir+'his/', exist_ok=True)
os.makedirs(odir+'fut/', exist_ok=True)


# Metric definitions processed by main().  The loop used to iterate an
# undefined name `funcs`; func1 holds the active definitions.
# NOTE(review): confirm func1 (rather than func2, or both) is the intended set.
funcs = func1

# True -> only print the planned (gcm, historical, future) runs and stop
# before loading any data (replaces the former debug `exit()` stop).
DRY_RUN = False


def _compute(var, func, ctype, metric, other, gcm, his_t, fut_t):
    """Load the inputs for one run and return ``(delta, his_res, fut_res)``.

    ``var`` is either a single variable name (handled by
    ``hlib.form_function_one``) or a sequence of names (handled by
    ``hlib.form_function_two``).  ``other`` holds extra keyword arguments,
    or None for none.
    """
    his_st, his_ed = his_t
    fut_st, fut_ed = fut_t
    kwargs = other or {}

    if isinstance(var, str):
        his_data = get_data(gcm, his_st, his_ed, var, metric)
        fut_data = get_data(gcm, fut_st, fut_ed, var, metric)
        return hlib.form_function_one(his_data, fut_data, func, ctype, **kwargs)

    # Multi-variable metric: one DataArray per input, in the order of `var`,
    # each keeping its own variable name.  (The lists used to be built nested,
    # [[[], a], b], and get_data was called without its `metric` argument.)
    # NOTE(review): assumes form_function_two expects flat lists ordered like
    # `var` -- confirm.
    his_data = []
    fut_data = []
    for v in var:
        his_data.append(get_data(gcm, his_st, his_ed, v, v))
        fut_data.append(get_data(gcm, fut_st, fut_ed, v, v))
    return hlib.form_function_two(his_data, fut_data, func, ctype, **kwargs)


def _write_outputs(delta, his_res, fut_res, gcm, metric, his_t, fut_t):
    """Write the change, historical and future results; return the change-file path."""
    his_st, his_ed = his_t
    fut_st, fut_ed = fut_t
    # NOTE(review): the change-file name encodes only the future period, so
    # several entries in his_times would overwrite each other's file.
    out = os.path.join(odir, f'{gcm}_WRF_{metric}_{fut_st}-{fut_ed}.nc')
    hout = os.path.join(odir, 'his', f'{gcm}_WRF_{metric}_{his_st}-{his_ed}_his.nc')
    fout = os.path.join(odir, 'fut', f'{gcm}_WRF_{metric}_{fut_st}-{fut_ed}_fut.nc')
    delta.to_netcdf(out)
    his_res.to_netcdf(hout)
    fut_res.to_netcdf(fout)
    return out


def main():
    """Compute and write every metric in `funcs` for every GCM / period pair."""
    runs = [(g, his_t, fut_t) for g in gcms for his_t in his_times for fut_t in fut_times]
    print(runs)
    if DRY_RUN:
        return

    for (var, func, ctype, metric, other) in funcs:
        for (g, his_t, fut_t) in runs:
            print('\n' + g, his_t, fut_t)
            delta, his_res, fut_res = _compute(var, func, ctype, metric, other, g, his_t, fut_t)
            out = _write_outputs(delta, his_res, fut_res, g, metric, his_t, fut_t)
            print(os.path.basename(out))


if __name__ == "__main__":
    main()
        
