import xarray as xr
import pandas as pd
from glob import glob
import xesmf as xe
import os

# --- Input / output locations -------------------------------------------------
# Root directory of the GCM files; each file sits one sub-directory below it
# (see the glob pattern further down).
gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
# Reference wind file (1981-2020 according to its file name).
ref_wnd = (
    "/home/disk/rocinante/DATA/temp/crystal_fire/data/wind/wind_daily/"
    "pnnl_WND_1981-2020.nc"
)
# Destination of the per-GCM CSV files; created up front if missing.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv_reverse/"
os.makedirs(odir, exist_ok=True)

# All RCP8.5 files one directory level below gdir, in sorted (deterministic) order.
gcms = sorted(glob(f"{gdir}/*/*_RCP85.nc"))

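# Analysis notes:
# - Count days with wind at or above the 4 and 11 thresholds only; no temperature restrictions.
# - Baseline 1980-2009 vs. 2010-2039
# - Baseline 1980-2009 vs. 2040-2069
# - Baseline 1980-2009 vs. 2070-2099
# - Take the difference at each location for each year, and either show it year by year
#   (dynamically) or average it over the time period.
# NOTE: the loop at the bottom of this script currently compares only 1981-2009 GCM
#   counts against the reference data.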

def format_ds(ds, styr, edyr, thd, start_month=6, end_month=9):
    """Count, per year and grid cell, the in-season time steps with ``WND >= thd``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain a ``WND`` variable and a datetime ``times`` coordinate.
    styr, edyr : int
        First and last calendar year (inclusive) to keep.
    thd : number
        Wind threshold; a time step is counted when ``WND >= thd``.
    start_month, end_month : int, optional
        Inclusive month range defining the season (default: June-September).
        Must satisfy ``1 <= start_month <= end_month <= 12``. Seasons that wrap
        the year end (e.g. Nov-Feb) are not supported because the counts are
        grouped by calendar year.

    Returns
    -------
    xarray.Dataset
        Non-NaN counts grouped by ``year``; the ``times`` dimension is reduced
        away. For daily input this is a number of days (the reference file is
        daily per its path; TODO confirm the GCM files are daily too).

    Raises
    ------
    ValueError
        If the month range is invalid.
    """
    if not 1 <= start_month <= end_month <= 12:
        raise ValueError(
            "need 1 <= start_month <= end_month <= 12, got {}-{}".format(
                start_month, end_month
            )
        )

    print('\t', styr, edyr, thd)
    stdt = "{}-01-01".format(styr)
    eddt = "{}-12-31".format(edyr)
    dx = ds.sel(times=slice(stdt, eddt))

    # Mask (rather than drop) out-of-season steps: they become NaN and are
    # ignored by count(), while every year in the window stays in the output.
    month = dx.times.dt.month
    dx = dx.where((month >= start_month) & (month <= end_month))

    # Mask steps below the threshold as well (already-NaN steps stay masked
    # because NaN >= thd is False), so count() tallies only exceedances.
    return dx.where(dx.WND >= thd).groupby('times.year').count(dim='times')
    
    


def ds_avg(ds, dr, gcm, thd, out_dir=None):
    """Write per-year exceedance-count differences (``ds`` minus ``dr``) to CSV.

    Despite the name, nothing is averaged: one row is written per
    (lat, lon, year) combination present after the subtraction.

    Parameters
    ----------
    ds, dr : xarray.Dataset
        Yearly counts for the GCM and for the reference respectively (in this
        script, the output of ``format_ds``). Both need a ``WND`` variable;
        after flattening, ``lat``, ``lon`` and ``year`` must be available as
        columns. They are presumably on the same grid (the reference is
        regridded to the GCM grid before this is called); xarray aligns them
        on shared coordinates before subtracting.
    gcm : str
        Label written to the ``gcm`` column and used in the file name.
    thd : number
        Threshold written to the ``threshold`` column and used in the file name.
    out_dir : str, optional
        Output directory; defaults to the module-level ``odir``.

    The file is written as ``<out_dir>/<gcm>_wnd<thd>.csv`` with 5-decimal floats.
    """
    if out_dir is None:
        out_dir = odir

    # assign() returns a new Dataset. Assigning into `ds` directly would
    # silently modify the caller's object.
    diff = ds.assign(WND=ds.WND - dr.WND)

    df = diff.to_dataframe().reset_index()
    # Use assign() on the column subset instead of item assignment, which can
    # trigger pandas' SettingWithCopyWarning on a sliced frame.
    df = df[['lat', 'lon', 'WND', 'year']].assign(threshold=thd, gcm=gcm)

    out_path = os.path.join(out_dir, '{}_wnd{}.csv'.format(gcm, thd))
    df.to_csv(out_path, index=False, float_format='%0.5f')
    
    
    
# Reference wind data, restricted to the 1981-2009 baseline period.
dref = xr.open_dataset(ref_wnd)
dref = dref.sel(times=slice('1981', '2009'))

for g in gcms:
    # File name without its .nc extension, used as the GCM label.
    fn = os.path.splitext(os.path.basename(g))[0]
    print(fn)
    ds = xr.open_dataset(g)[['WND']]

    # Regrid the reference onto this GCM's grid, keeping the result in a
    # separate variable. Overwriting `dref` here (as before) made every later
    # GCM regrid an already-regridded field from the previous GCM's grid
    # instead of the original reference data.
    regridder = xe.Regridder(dref, ds, 'bilinear')
    dref_rg = regridder(dref)

    # Same call order as before: GCM counts, reference counts, then the CSV,
    # once per threshold.
    for thd in (4, 11):
        ds_cnt = format_ds(ds, 1981, 2009, thd)
        dr_cnt = format_ds(dref_rg, 1981, 2009, thd)
        ds_avg(ds_cnt, dr_cnt, fn, thd)


    
