import xarray as xr
import xesmf as xe


def get_seasonal_sum(ds, out):
    """Write and return the multi-year mean of per-year sums of ``ds``.

    The data are grouped by the calendar year of their ``time`` coordinate
    and summed within each year (``skipna=False`` is passed to the sum).
    The yearly sums are then averaged over the resulting ``year`` dimension.

    Parameters
    ----------
    ds : xarray object with a datetime-like ``time`` coordinate
        Data to aggregate, typically already restricted to one season.
    out : path-like
        Destination passed to ``to_netcdf``.

    Returns
    -------
    The averaged object that was written to ``out``.
    """
    yearly_totals = ds.groupby(ds.time.dt.year).sum(skipna=False)
    mean_total = yearly_totals.mean(dim='year')
    mean_total.to_netcdf(out)
    return mean_total

def get_seasonal_sum2(ds, out, mult):
    """Write and return a scaled multi-year mean of per-year means of ``ds``.

    The data are grouped by the calendar year of their ``time`` coordinate
    and averaged within each year (``skipna=False`` is passed to that mean).
    The yearly means are averaged over the ``year`` dimension and the result
    is multiplied by ``mult``.

    Parameters
    ----------
    ds : xarray object with a datetime-like ``time`` coordinate
        Data to aggregate, typically already restricted to one season.
    out : path-like
        Destination passed to ``to_netcdf``.
    mult : number
        Factor applied to the multi-year mean (callers in this file pass
        90, 91 or 92).

    Returns
    -------
    The scaled object that was written to ``out``.
    """
    yearly_means = ds.groupby(ds.time.dt.year).mean(skipna=False)
    scaled_mean = yearly_means.mean(dim='year') * mult
    scaled_mean.to_netcdf(out)
    return scaled_mean

def mk_seasonal_maps(ds, odir, fn):
    """Write one seasonal map file per three-month season (JFM, AMJ, JAS, OND).

    For each season the timesteps whose month falls in that season are
    selected from ``ds`` and handed to ``get_seasonal_sum2`` together with
    the output path and a season-specific multiplier.

    Parameters
    ----------
    ds : xarray object with a datetime-like ``time`` coordinate
        Source data.
    odir : str
        Output directory; joined to the file name with ``/``.
    fn : str
        Prefix of every output file name.
    """
    print('Make Seasonal Maps: ', fn)

    # (months in the season, output path template, multiplier).
    # The multipliers equal the number of days in each season in a non-leap
    # year -- presumably the expected number of daily samples; TODO confirm.
    seasons = (
        ([1, 2, 3], "{}/{}_SEASONAL_TTL_JAN-MAR.nc", 90),
        ([4, 5, 6], "{}/{}_SEASONAL_TTL_APR-JUN.nc", 91),
        ([7, 8, 9], "{}/{}_SEASONAL_TTL_JUL-SEP.nc", 92),
        ([10, 11, 12], "{}/{}_SEASONAL_TTL_OCT-DEC.nc", 92),
    )

    # Select every season first, then aggregate and write them in order.
    month = ds.time.dt.month
    subsets = [ds.sel(time=month.isin(months)) for months, _, _ in seasons]

    for subset, (_, template, mult) in zip(subsets, seasons):
        get_seasonal_sum2(subset, template.format(odir, fn), mult)

    
    

def mk_quantile_maps(ds, odir, fn):
    """Write quantile maps computed over all October-March timesteps.

    Timesteps whose month is <= 3 or >= 10 are selected from ``ds`` and the
    quantiles 0.75, 0.90, 0.95, 0.99, 0.995 and 0.999 are computed along
    ``time`` over that pooled selection. The result is written to
    ``<odir>/<fn>_Quantile_OCT-MAR.nc``; nothing is returned.

    Parameters
    ----------
    ds : xarray object with a datetime-like ``time`` coordinate
        Source data.
    odir : str
        Output directory; joined to the file name with ``/``.
    fn : str
        Prefix of the output file name.
    """
    print('Make Quantile Maps: ', fn)

    month = ds.time.dt.month
    in_oct_mar = (month <= 3) | (month >= 10)
    cold_season = ds.sel(time=in_oct_mar)

    # Disabled alternative: quantiles per water year, then averaged over
    # water years.
    #cold_season['wyear'] = cold_season.time.dt.year + (cold_season.time.dt.month > 9)
    #cold_season = cold_season.set_coords('wyear')
    #quantiles = cold_season.groupby('wyear').quantile(q=levels, dim='time')
    #quantiles = quantiles.mean(dim='wyear')

    levels = [0.75, 0.90, 0.95, 0.99, 0.995, 0.999]
    quantiles = cold_season.quantile(q=levels, dim='time')
    quantiles.to_netcdf('{}/{}_Quantile_OCT-MAR.nc'.format(odir, fn))

    

# Tag substituted into the input file name and the output labels.
setname = "1.0"

# Input files: the bias-corrected dataset for `setname`, and the dataset
# stored under data/wrf that it is compared against.
bc_path = '/home/disk/rocinante/DATA/temp/TNC_stormwater/bias_correction/data/bc_{}bin.nc'.format(setname)
wrf_path = '/home/disk/rocinante/DATA/temp/TNC_stormwater/bias_correction/data/wrf/pnnl_19810101-20201231.nc'
ds = xr.open_dataset(bc_path)
ps = xr.open_dataset(wrf_path)

# Trim ds so that it ends no later than the last timestamp of ps.
ds = ds.sel(time=slice(ds.time[0], ps.time[-1]))

# Aggregate ps to daily sums, relabel each day with the first original
# timestamp found in that day, then keep only the timestamps present in ds.
dates = ps['time'].resample(time='D').first()
psel = ps.resample(time='D').sum()
psel['time'] = dates
psel = psel.sel(time=ds.time)

# Output directory (the alternative location is kept for reference).
#odir = "/home/disk/rocinante/DATA/temp/TNC_stormwater/bias_correction/maps"
odir = "/home/disk/rocinante/DATA/temp/TNC_stormwater/bias_correction/scripts/check/"

prism_label = "PRISM-bc{}".format(setname)
wrf_label = 'WRF-NARR_bc{}-time'.format(setname)

# Seasonal maps are currently switched off; only quantile maps are produced.
#mk_seasonal_maps(ds, odir, prism_label)
#mk_seasonal_maps(psel, odir, wrf_label)
for dataset, label in ((ds, prism_label), (psel, wrf_label)):
    mk_quantile_maps(dataset, odir, label)


    

#ps_new = ps.where(ps.times.isin(ds.time), drop=True)
#dtime = pd.to_datetime([x.isoformat() for x in ds.time.values])

# get the average across the current timestep interval for the sample period
# multiply by the expected number of samples
# average across the years

#ds['mask'] = xr.where(~np.isnan(ds['PREC'].isel(time=0)), 1, 0)
