import xarray as xr
from glob import glob
import xesmf as xe
import os

# Root input directory; presumably one sub-directory per GCM holding the
# WRF NetCDF output (TODO confirm layout).
gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
# Destination for the per-file / per-threshold / per-period CSV tables.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv_bias/"
# Every file ending in "_RCP85.nc" exactly one directory level below gdir,
# sorted so the processing order is stable between runs.
gcms = sorted(glob(f"{gdir}/*/*_RCP85.nc"))
# Make sure the output directory exists; an existing one is not an error.
os.makedirs(odir, exist_ok=True)

# Count days with wind at or above the 4 and 11 thresholds only (no temperature restrictions).
# Compare the 1980-2009 baseline against each future period:
#   1980-2009 vs. 2010-2039
#   1980-2009 vs. 2040-2069
#   1980-2009 vs. 2070-2099
# Take the difference at each location, either per year (to show the change dynamically) or averaged over the period.


def format_ds(ds, styr, edyr, thd):
    """Return the mean annual count of time steps with ``WND >= thd``.

    Parameters
    ----------
    ds : xarray Dataset with a ``times`` dimension and a ``WND`` variable.
    styr, edyr : first and last calendar year (inclusive) of the period.
    thd : threshold; a time step is counted when ``WND >= thd``.

    Returns
    -------
    Dataset whose ``WND`` holds, at every location, the number of qualifying
    time steps per year averaged over the years in the period. (Time steps
    are presumably days, per the header notes -- TODO confirm data frequency.)
    """
    print('\t', styr, edyr, thd)
    window = ds.sel(times=slice(f"{styr}-01-01", f"{edyr}-12-31"))
    # Sub-threshold values become NaN, so count() tallies exceedances only.
    exceedances = window.where(window.WND >= thd)
    yearly_counts = exceedances.groupby('times.year').count(dim='times')
    return yearly_counts.mean(dim='year')

def ds_avg(ds, dr, styr, edyr, thd, gcm):
    """Write the change in mean annual exceedance count vs. a baseline to CSV.

    Runs ``format_ds(ds, styr, edyr, thd)`` for the requested period and
    subtracts the baseline ``dr.WND``. In this script the baseline is the
    1980-2009 ``format_ds`` result for the same threshold. One row per grid
    point is written to ``<odir>/<gcm>_wnd<thd>_<styr>-<edyr>.csv``.

    Parameters
    ----------
    ds : xarray Dataset with a ``times`` dimension and a ``WND`` variable.
    dr : baseline Dataset as returned by ``format_ds``.
    styr, edyr : first and last calendar year (inclusive) of the period.
    thd : wind threshold (must match the one used to build ``dr``).
    gcm : label stored in the ``gcm`` column and used in the file name.

    Returns
    -------
    pandas.DataFrame
        The table that was written, with columns lat, lon, WND, threshold,
        period and gcm. Callers that only want the CSV can ignore it.
    """
    # format_ds already prints the period/threshold line; printing it here
    # too would log every period twice.
    dx = format_ds(ds, styr, edyr, thd)
    dx['WND'] = dx.WND - dr.WND
    # .assign() builds a new frame rather than adding columns to a column
    # subset, which avoids pandas' SettingWithCopyWarning.
    df = dx.to_dataframe()[['lat', 'lon', 'WND']].assign(
        threshold=thd,
        period='{}-{}'.format(styr, edyr),
        gcm=gcm,
    )
    fname = '{}_wnd{}_{}-{}.csv'.format(gcm, thd, styr, edyr)
    df.to_csv(os.path.join(odir, fname), index=False, float_format='%0.5f')
    return df
    
def ds_full(ds, styr, edyr, thd, gcm):
    """Return per-year exceedance counts for one period, without averaging.

    This matches ``format_ds`` except that the ``year`` dimension is kept, so
    the change can be examined year by year rather than as a period mean.

    Parameters
    ----------
    ds : xarray Dataset with a ``times`` dimension and a ``WND`` variable.
    styr, edyr : first and last calendar year (inclusive) of the period.
    thd : threshold; a time step is counted when ``WND >= thd``.
    gcm : currently unused (presumably intended for output naming, as in
        ``ds_avg`` -- TODO confirm).

    Returns
    -------
    Dataset with a ``year`` dimension whose ``WND`` holds, for each year, the
    number of time steps with ``WND >= thd``.
    """
    stdt = "{}-01-01".format(styr)
    eddt = "{}-12-31".format(edyr)
    dx = ds.sel(times=slice(stdt, eddt))
    # where() turns sub-threshold values into NaN so count() tallies only
    # exceedances; a year with none gets a count of 0.
    dx = dx.where(dx.WND >= thd).groupby('times.year').count(dim='times')
    return dx
    
    
# For every RCP8.5 file: compute the 1980-2009 baseline per threshold, then
# write the future-minus-baseline CSV for each 30-year period and threshold.
for g in gcms:
    # splitext strips only the trailing extension, unlike replace('.nc', ''),
    # which would also remove any '.nc' appearing elsewhere in the name.
    fn = os.path.splitext(os.path.basename(g))[0]
    print(fn)
    # Close the file once this GCM is done instead of leaking one open
    # handle per iteration for the lifetime of the script.
    with xr.open_dataset(g) as full:
        ds = full[['WND']]
        # Baselines are computed first (4, then 11), as before.
        baselines = {thd: format_ds(ds, 1980, 2009, thd) for thd in (4, 11)}
        # Same call order as before: each period with threshold 4, then 11.
        for styr, edyr in ((2010, 2039), (2040, 2069), (2070, 2099)):
            for thd, dr in baselines.items():
                ds_avg(ds, dr, styr, edyr, thd, fn)

    
    
#dx.sel(times=slice('1970-01-01', '2010-12-31'))
#z = dy.where(dy.WND >= 4).groupby('times.year').count(dim='times')
#z.WND.mean(dim='year')
