import xarray as xr
import os
import time
from glob import glob
import pandas as pd
import numpy as np
import sys
import hydrolib as hlib


##==========================================================

def calc_diff(his, fut):
    """Return the absolute change between two periods, ``fut - his``.

    Args:
        his: Historical-period value(s); anything supporting subtraction.
        fut: Future-period value(s), compatible with ``his``.

    Returns:
        The result of ``fut - his``.
    """
    print("Calculating difference")
    return fut - his

def calc_change(his, fut):
    """Return the percent change from the historical to the future value.

    Computed as ``(fut - his) / his * 100``: identical inputs give 0, a
    doubling gives 100 and a halving gives -50.

    Args:
        his: Historical-period value(s), used as the baseline (denominator).
        fut: Future-period value(s), compatible with ``his``.

    Returns:
        The percent change of ``fut`` relative to ``his``.

    Note:
        Zero baselines are not special-cased; the division by ``his`` behaves
        however the operand types define it.
    """
    print("Calculating Percent Change")
    # Subtract the baseline first: fut / his * 100 alone is the ratio in
    # percent (100 for "no change"), not the change.
    return (fut - his) / his * 100

def get_latlon2d(ds, lat, lon):
    """Locate the grid cell of ``ds`` closest to a target latitude/longitude.

    Closeness is the plain Euclidean distance between the target point and
    the values in ``ds.lat`` / ``ds.lon``; the first minimum wins.

    Args:
        ds: Object exposing 2-D ``lat.values`` and ``lon.values`` arrays of
            the same shape.
        lat: Target latitude, in the same units as ``ds.lat``.
        lon: Target longitude, in the same units as ``ds.lon``.

    Returns:
        Tuple ``(i, j)`` indexing the nearest cell in the 2-D lat/lon arrays.
    """
    grid_lat = ds.lat.values
    grid_lon = ds.lon.values
    separation = ((grid_lat - lat)**2 + (grid_lon - lon)**2)**0.5
    # argmin() works on the flattened array; map it back to a 2-D index.
    nearest_flat = separation.argmin()
    row, col = np.unravel_index(nearest_flat, separation.shape)
    return (row, col)

def get_data(fpath, styr, edyr, var):
    """Open one variable's daily WRF Zarr store and subset it in time and space.

    Args:
        fpath: Path or name identifying the GCM run. Only its basename is
            used, both as the sub-directory and as the file-name prefix of
            the Zarr store under the hard-coded WRF-daily data root.
        styr: First year to keep (selection starts at ``styr-01-01``).
        edyr: Last year to keep (selection ends at ``edyr-12-31``).
        var: Variable name, used as the suffix of the Zarr store name.

    Returns:
        The opened dataset restricted to the requested years and cropped to
        the box spanned by the grid cells nearest (45.347, -124.780) and
        (49.320, -116.586).
    """
    gcm = os.path.basename(fpath)
    print("Getting data: {}[{}] - {}:{}".format(gcm, var, styr, edyr))
    store = "/home/disk/becassine/jswon11/DATA/WRF-daily/{}/{}_{}.zarr".format(gcm, gcm, var)
    ds = xr.open_mfdataset(store, engine='zarr', parallel=True)
    ds = ds.sel(time=slice('{}-01-01'.format(styr), '{}-12-31'.format(edyr)))

    # Crop to the box between the grid cells nearest the two corner points.
    # NOTE(review): get_latlon2d returns (i, j) indices into the 2-D lat/lon
    # arrays, in those arrays' own axis order, yet they are named x/y here and
    # handed to .sel(), which is label-based when x/y carry coordinate labels.
    # Confirm the axis order and that positional selection is what happens.
    (x1, y1) = get_latlon2d(ds, 45.347, -124.780)
    (x2, y2) = get_latlon2d(ds, 49.320, -116.586)
    return ds.sel(x=slice(x1, x2), y=slice(y1, y2))

#T2 - daily min, max, avg
#prec - daily sum


def form_function_one(his_dir, fut_dir, his_times, fut_times, var, func, change, **kwargs):
    """Compute a metric for a historical and a future period and compare them.

    Args:
        his_dir: GCM identifier/path passed to ``get_data`` for the
            historical period.
        fut_dir: GCM identifier/path passed to ``get_data`` for the future
            period.
        his_times: ``(start_year, end_year)`` of the historical period.
        fut_times: ``(start_year, end_year)`` of the future period.
        var: Name of the variable to load and to pull out of each dataset.
        func: Metric function, called as ``func(data_array, **kwargs)``.
        change: Comparison function, called as ``change(his_res, fut_res)``.
        **kwargs: Extra keyword arguments forwarded to ``func``.

    Returns:
        Tuple ``(diff, his_res, fut_res)``: the comparison converted to a
        dataset carrying the historical dataset's ``lat``/``lon``
        coordinates, followed by the two per-period metric results.
    """
    his_st, his_ed = his_times
    fut_st, fut_ed = fut_times
    print("Running {} with {} for {}-{} / {}-{}".format(func.__name__, var, his_st, his_ed, fut_st, fut_ed))

    # Load both periods first, then evaluate the metric on each, in that order.
    periods = ((his_dir, his_st, his_ed), (fut_dir, fut_st, fut_ed))
    his_ds, fut_ds = [get_data(src, first, last, var) for (src, first, last) in periods]
    his_res, fut_res = [func(period_ds[var], **kwargs) for period_ds in (his_ds, fut_ds)]

    delta = change(his_res, fut_res).to_dataset()
    delta = delta.assign_coords(lat=his_ds.lat, lon=his_ds.lon)
    return (delta, his_res, fut_res)

##--------------------------------------------------------

# GCM runs to process. Each entry is "<model>_<scenario>" and is handed to
# form_function_one as both the historical and the future data source.
_SCENARIO = 'RCP85'
_MODELS = (
    'access1.0',
    'access1.3',
    'bcc-csm1.1',
    'canesm2',
    'ccsm4',
    'csiro-mk3.6.0',
    'fgoals-g2',
    'gfdl-cm3',
    'giss-e2-h',
    'miroc5',
    'mri-cgcm3',
    'noresm1-m',
)
gcms = ['{}_{}'.format(model, _SCENARIO) for model in _MODELS]

# Metrics to compute, one tuple per output product:
#   (input variable, hydrolib metric function, historical/future comparison
#    function, label used in the output file name, extra keyword arguments
#    for the metric function or None)
_MM_PER_INCH = 25.4          # precipitation thresholds are given in inches
_DEGREE_DAY_BASE = 18.3333   # presumably 65 degF expressed in degC -- TODO confirm
_HOT_DAY_LIMIT = 37.7778     # presumably 100 degF expressed in degC -- TODO confirm

funcs = [
    ('TMAX', hlib.max_temp, calc_diff, 'LateSummerMaximumTemperature', None),
    ('PREC', hlib.precip_day, calc_diff, 'precipday1in', dict(threshold=1 * _MM_PER_INCH)),
    ('PREC', hlib.precip_day, calc_diff, 'precipday2in', dict(threshold=2 * _MM_PER_INCH)),
    ('PREC', hlib.precip_day, calc_diff, 'precipday3in', dict(threshold=3 * _MM_PER_INCH)),
    ('TAVG', hlib.heating_days, calc_diff, 'heatingdegreedays', dict(threshold=_DEGREE_DAY_BASE)),
    ('TAVG', hlib.cooling_days, calc_diff, 'coolingdegreedays', dict(threshold=_DEGREE_DAY_BASE)),
    ('TMAX', hlib.hot_days, calc_diff, 'hotdays', dict(threshold=_HOT_DAY_LIMIT)),
    # TODO: humidex metric is not wired up yet:
    # (('TAVG', 'RH'), max_humidex, calc_diff, 'maxhumidex', {'threshold': 90})
]

# Historical baseline period(s) as (first year, last year); both years are
# included when the data is sliced.
his_times = [(1981, 2010)]

# Future 30-year windows as (first year, last year), one starting every
# decade from 2020 through 2070.
fut_times = [(first_year, first_year + 29) for first_year in range(2020, 2071, 10)]

# Directory that receives one NetCDF file per GCM / metric / future window.
odir = "/home/disk/rocinante/DATA/temp/clearinghouse/data_out/"

# Driver: for every GCM, metric and historical/future period pair, compute the
# change in the metric and write it to its own NetCDF file under odir.
for g in gcms:
    for (var, func, ctype, label, other) in funcs:
        for his_t in his_times:
            for fut_t in fut_times:
                print(g, his_t, fut_t)

                (fut_st, fut_ed) = fut_t
                print(fut_st, fut_ed)

                # The same GCM run supplies both periods. `other` is None for
                # metrics that take no extra keyword arguments. The per-period
                # results are returned as well, but only the change is saved.
                (d, _, _) = form_function_one(g, g, his_t, fut_t, var, func, ctype, **(other or {}))

                # os.path.join avoids the doubled separator that plain string
                # formatting produced, since odir already ends in '/'.
                # NOTE: the file name omits the historical period, so more
                # than one entry in his_times would overwrite earlier output.
                out = os.path.join(odir, '{}_{}_{}-{}.nc'.format(g, label, fut_st, fut_ed))
                d.to_netcdf(out)
                print(os.path.basename(out))
        
