import xarray as xr
from glob import glob
import xesmf as xe
import os

# Root directory of the WRF-downscaled GCM wind files; the main script globs
# <gdir>/*/*_RCP85.nc to find one file per model.
gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
# Reference daily wind dataset (PNNL; period 1981-2020 per the file name).
# Used both as the baseline for the q fields and as the target grid for
# regridding the GCM data.
ref_wnd = "/home/disk/rocinante/DATA/temp/crystal_fire/data/wind/wind_daily/pnnl_WND_1981-2020.nc"
# Destination for the per-model / per-threshold CSV tables.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv_quantile/"
# Create the output directory up front; exist_ok avoids failing on re-runs.
os.makedirs(odir, exist_ok=True)

def format_ds(ds, styr, edyr, thd, time_dim="times"):
    """Return the fraction of timesteps whose WND is below ``thd`` in a year window.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain a ``WND`` variable and a ``time_dim`` coordinate that
        accepts date-string slicing (presumably datetime-like -- confirm).
    styr, edyr : int
        First and last calendar year of the window (inclusive). The selection
        runs from ``"<styr>-01-01"`` to ``"<edyr>-12-31"``.
    thd : number
        Threshold. A timestep counts when ``WND < thd``; NaN never counts.
    time_dim : str, optional
        Name of the time dimension to slice and reduce over ("times" by default).

    Returns
    -------
    xarray.Dataset
        A Dataset containing only the variable ``q``, the per-grid-point
        fraction of selected timesteps below ``thd``, plus the coordinates
        that do not depend on ``time_dim``.

    Raises
    ------
    ValueError
        If the year window selects no timesteps.
    """
    print('\t', styr, edyr, thd)
    stdt = "{}-01-01".format(styr)
    eddt = "{}-12-31".format(edyr)
    # .sel returns a new Dataset, so adding 'q' below does not touch `ds`.
    dx = ds.sel({time_dim: slice(stdt, eddt)})

    ntime = dx.sizes[time_dim]
    if ntime == 0:
        raise ValueError(
            "no timesteps in {} between {} and {}".format(time_dim, stdt, eddt)
        )

    # Reduce explicitly along the time dimension. A builtin sum()/len() pair
    # walks the *first* dimension of WND (wrong if time is not first) in a
    # Python-level loop; the named-dimension reduction is exact and vectorized.
    dx['q'] = (dx.WND < thd).sum(dim=time_dim) / ntime
    return dx[['q']]

def save_csv(ds, gcm, thd, x):
    """Write the ``q`` field of ``ds`` to a CSV file in ``odir``.

    The table has one row per grid point with the columns
    lat, lon, q, threshold, gcm, diff (in that order). The constant columns
    carry ``thd``, ``gcm`` and ``x``. ``x`` is a name suffix (e.g. "" or
    "-Ref") and is also embedded in the file name ``<gcm><x>_wnd<thd>.csv``.
    Floats are written with five decimals and no index column.
    """
    table = ds.to_dataframe().reset_index()[['lat', 'lon', 'q']]
    # Append the constant metadata columns after lat/lon/q.
    table = table.assign(threshold=thd, gcm=gcm, diff=x)

    out_path = '{}/{}{}_wnd{}.csv'.format(odir, gcm, x, thd)
    table.to_csv(out_path, index=False, float_format='%0.5f')


def save_ds(ds, dr, gcm, thd):
    """Write ``ds.q`` and its difference from a reference field as two CSVs.

    Writes ``<gcm>_wnd<thd>.csv`` with ``ds.q`` as-is and
    ``<gcm>-Ref_wnd<thd>.csv`` with ``ds.q - dr.q``, both via ``save_csv``.
    Neither ``ds`` nor ``dr`` is modified.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``q`` variable to export.
    dr : xarray.Dataset
        Reference Dataset with a ``q`` variable on the same grid.
    gcm : str
        Model label used in the file names and the ``gcm`` column.
    thd : number
        Threshold label used in the file names and the ``threshold`` column.
    """
    save_csv(ds, gcm, thd, "")
    # Build the anomaly on a new Dataset rather than overwriting ds['q'] in
    # place. When the caller passes the same object as both ds and dr, an
    # in-place update would zero out the reference field itself.
    diff = ds.assign(q=ds.q - dr.q)
    save_csv(diff, gcm, thd, "-Ref")
    
    
# Wind thresholds for which the below-threshold fraction q is mapped.
thresholds = (4, 11)
# Baseline period (inclusive years) shared by the reference and every model.
base_start, base_end = 1981, 2009

# Context managers close every NetCDF file once its data has been used,
# instead of leaking one open handle per model file.
with xr.open_dataset(ref_wnd) as dref:
    # Reference q fields, keyed by threshold.
    ref_q = {thd: format_ds(dref, base_start, base_end, thd) for thd in thresholds}
    gcms = sorted(glob("{}/*/*_RCP85.nc".format(gdir)))

    for g in gcms:
        fn = os.path.basename(g).replace('.nc', '')
        print(fn)
        with xr.open_dataset(g) as src:
            dg = src[['WND']]
            # Interpolate the model field onto the reference grid so the two
            # q fields can be differenced point by point.
            regridder = xe.Regridder(dg, dref, 'bilinear')
            ds = regridder(dg)

            gcm_q = {thd: format_ds(ds, base_start, base_end, thd) for thd in thresholds}
            for thd in thresholds:
                save_ds(gcm_q[thd], ref_q[thd], fn, thd)

    # The reference against itself: its own q plus an all-zero "-Ref" table.
    for thd in thresholds:
        save_ds(ref_q[thd], ref_q[thd], 'pnnl', thd)


    
    
