import xarray as xr
from glob import glob
import xesmf as xe
import os

# Root directory searched (one level of sub-directories deep) for GCM files.
gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
# Reference wind dataset; also used as the target grid for regridding.
ref_wnd = "/home/disk/rocinante/DATA/temp/crystal_fire/data/wind/wind_daily/pnnl_WND_1981-2020.nc"
# Directory that receives the per-GCM / per-threshold / per-period CSV files.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv/"


# Every "*_RCP85.nc" file found exactly one directory below gdir, in sorted
# order so the processing sequence is reproducible between runs.
gcms = sorted(glob(gdir + "/*/*_RCP85.nc"))

# Days with wind at or above the 4 and 11 thresholds only; no temperature restrictions.
# 1980 - 2009 vs. 2010 - 2039
# 1980 - 2009 vs. 2040 - 2069
# 1980 - 2009 vs. 2070 - 2099
# For each location and each year, take the difference and either show it dynamically or average it over the time period.


def format_ds(ds, styr, edyr, thd):
    """Count, per calendar year, the time steps where ``WND`` reaches ``thd``.

    Parameters
    ----------
    ds : object with a ``WND`` variable and a ``times`` coordinate
        Presumably an ``xarray.Dataset`` -- it must support ``sel``,
        ``where`` and ``groupby`` as used below.
    styr, edyr : first and last year of the window (both inclusive).
    thd : threshold compared against ``WND`` with ``>=``.

    Returns
    -------
    The result of grouping by ``times.year`` and counting along ``times``
    after masking with ``WND >= thd``.
    """
    print('\t', styr, edyr, thd)

    # Restrict to 1 January of the first year through 31 December of the last.
    window = slice(f"{styr}-01-01", f"{edyr}-12-31")
    subset = ds.sel(times=window)

    # Mask everything below the threshold, then count what is left per year.
    meets_threshold = subset.WND >= thd
    by_year = subset.where(meets_threshold).groupby('times.year')
    return by_year.count(dim='times')

def ds_avg(ds, dr, styr, edyr, thd, gcm):
    """Write a CSV of mean yearly threshold counts minus a reference field.

    For the window ``styr``..``edyr`` (inclusive) the time steps with
    ``WND >= thd`` are counted per year and averaged over the years;
    ``dr.WND`` is then subtracted from the averaged ``WND`` field. The
    ``lat``, ``lon`` and ``WND`` columns are written, together with
    ``threshold``, ``period`` and ``gcm`` label columns, to
    ``<odir>/<gcm>_wnd<thd>_<styr>-<edyr>.csv``.

    Parameters
    ----------
    ds : object with a ``WND`` variable and a ``times`` coordinate
        (presumably an ``xarray.Dataset`` already on the reference grid).
    dr : object whose ``WND`` attribute is subtracted from the mean counts.
    styr, edyr : first and last year of the window.
    thd : threshold compared against ``WND`` with ``>=``.
    gcm : label used in the output file name and in the ``gcm`` column.

    Returns
    -------
    None; the only output is the CSV file.
    """
    print('\t', styr, edyr, thd)

    window = slice(f"{styr}-01-01", f"{edyr}-12-31")
    subset = ds.sel(times=window)

    # Per-year count of time steps at or above the threshold, then the
    # average of those yearly counts across the window.
    yearly_counts = subset.where(subset.WND >= thd).groupby('times.year').count(dim='times')
    mean_counts = yearly_counts.mean(dim='year')

    # Express the result as a difference from the supplied reference field.
    mean_counts['WND'] = mean_counts.WND - dr.WND

    table = mean_counts.to_dataframe()[['lat', 'lon', 'WND']]
    table['threshold'] = thd
    table['period'] = f'{styr}-{edyr}'
    table['gcm'] = gcm

    out_path = f'{odir}/{gcm}_wnd{thd}_{styr}-{edyr}.csv'
    table.to_csv(out_path, index=False, float_format='%0.5f')
    
def ds_full(ds, styr, edyr, thd, gcm):
    """Return per-year counts of time steps where ``WND`` reaches ``thd``.

    Unlike ``ds_avg`` the yearly values are not averaged and nothing is
    written to disk.

    Parameters
    ----------
    ds : object with a ``WND`` variable and a ``times`` coordinate
        (presumably an ``xarray.Dataset``).
    styr, edyr : first and last year of the window (both inclusive).
    thd : threshold compared against ``WND`` with ``>=``.
    gcm : currently unused; kept so the signature parallels ``ds_avg``.

    Returns
    -------
    The result of grouping by ``times.year`` and counting along ``times``
    after masking with ``WND >= thd``.
    """
    stdt = "{}-01-01".format(styr)
    eddt = "{}-12-31".format(edyr)
    dx = ds.sel(times=slice(stdt, eddt))
    dx = dx.where(dx.WND >= thd).groupby('times.year').count(dim='times')
    # Previously the result was computed and silently dropped, so the
    # function always returned None; hand it back to the caller.
    return dx
    
    
dref = xr.open_dataset(ref_wnd)

# Reference fields: mean yearly count of time steps with WND >= 4 and
# WND >= 11. They depend only on dref, so they are computed once here rather
# than again for every GCM.
dr4 = dref.where(dref.WND >= 4).groupby('times.year').count(dim='times').mean(dim='year')
dr11 = dref.where(dref.WND >= 11).groupby('times.year').count(dim='times').mean(dim='year')

# (start year, end year) windows to evaluate, in processing order.
_PERIODS = [(1980, 2009), (2010, 2039), (2040, 2069), (2070, 2099)]

# (threshold applied to the GCM data, reference field subtracted from it).
# NOTE(review): thresholds 1-3 are differenced against the >=4 reference and
# threshold 8 against the >=11 reference, exactly as in the original call
# list -- confirm this pairing is intended.
_THRESHOLD_REFS = [(1, dr4), (2, dr4), (3, dr4), (4, dr4), (8, dr11), (11, dr11)]

for g in gcms:
    fn = os.path.basename(g).replace('.nc', '')
    print(fn)
    dg = xr.open_dataset(g)[['WND']]
    # Bilinear regridder from this GCM's grid onto the reference grid.
    dr = xe.Regridder(dg, dref, 'bilinear')
    ds = dr(dg)

    # A stray argument-less ``format_ds()`` call used to sit here; it raised
    # TypeError on the first GCM and its result was overwritten anyway, so it
    # has been removed.
    for _styr, _edyr in _PERIODS:
        for _thd, _ref in _THRESHOLD_REFS:
            ds_avg(ds, _ref, _styr, _edyr, _thd, fn)

#print('Historical')
#ds_avg(dref, 1980, 2009, 4, 'pnnl')
#ds_avg(dref, 1980, 2009, 11, 'pnnl')

    
    
#dx.sel(times=slice('1970-01-01', '2010-12-31'))
#z = dy.where(dy.WND >= 4).groupby('times.year').count(dim='times')
#z.WND.mean(dim='year')
