import xarray as xr
import xesmf as xe


def get_seasonal_sum(ds, out):
    """Write and return the multi-year mean of per-calendar-year totals.

    Values are summed within each calendar year of ``ds.time`` with
    ``skipna=False`` (a NaN anywhere in a year makes that year's total NaN).
    The yearly totals are then averaged over the resulting ``year``
    dimension using xarray's default NaN handling.

    Args:
        ds: xarray object with a ``time`` coordinate.
        out: path of the netCDF file to write.

    Returns:
        The averaged result (the same object that is written to ``out``).
    """
    yearly_totals = ds.groupby(ds.time.dt.year).sum(skipna=False)
    climatology = yearly_totals.mean(dim='year')
    climatology.to_netcdf(out)
    return climatology

def get_seasonal_sum2(ds, out, mult):
    """Estimate a seasonal total as (mean value per time step) * ``mult``.

    The mean is taken within each calendar year of ``ds.time`` with
    ``skipna=False``, then averaged over the ``year`` dimension (default NaN
    handling), and finally scaled by ``mult``.

    Args:
        ds: xarray object with a ``time`` coordinate, already restricted to
            the season of interest.
        out: path of the netCDF file to write.
        mult: expected number of samples in the season (the call sites pass
            the season's day count).

    Returns:
        The scaled result (the same object that is written to ``out``).
    """
    per_year_mean = ds.groupby(ds.time.dt.year).mean(skipna=False)
    seasonal_total = per_year_mean.mean(dim='year') * mult
    seasonal_total.to_netcdf(out)
    return seasonal_total

def mk_seasonal_maps(ds, odir, fn):
    """Write seasonal-total maps for JAN-MAR, APR-JUN, JUL-SEP and OCT-DEC.

    For each quarter-year season the time steps in that season's months are
    selected and passed to ``get_seasonal_sum2``. That helper scales the
    multi-year mean value per time step by the season's day count and writes
    ``{odir}/{fn}_SEASONAL_TTL_<SEASON>.nc``.

    Args:
        ds: xarray Dataset with a ``time`` coordinate (daily data at the
            call sites in this script).
        odir: output directory.
        fn: filename prefix for the output files.

    Returns:
        dict mapping each season label (e.g. ``'JAN-MAR'``) to the map
        returned by ``get_seasonal_sum2``. Callers that ignore the return
        value are unaffected.
    """
    print('Make Seasonal Maps: ', fn)
    # (label used in the output filename, months in the season, days in the season).
    # NOTE(review): the day counts are for non-leap years; JAN-MAR has 91 days
    # in leap years. 90 is kept to preserve existing outputs — consider 90.25
    # if a climatological mean total is intended.
    seasons = (
        ('JAN-MAR', [1, 2, 3], 90),
        ('APR-JUN', [4, 5, 6], 91),
        ('JUL-SEP', [7, 8, 9], 92),
        ('OCT-DEC', [10, 11, 12], 92),
    )
    month = ds.time.dt.month
    results = {}
    for label, season_months, ndays in seasons:
        season_ds = ds.sel(time=month.isin(season_months))
        out = "{}/{}_SEASONAL_TTL_{}.nc".format(odir, fn, label)
        results[label] = get_seasonal_sum2(season_ds, out, ndays)
    return results

    
    

def mk_quantile_maps(ds, odir, fn,
                     quantiles=(0.75, 0.90, 0.95, 0.99, 0.995, 0.999)):
    """Write Oct-Mar quantile maps averaged across water years.

    Time steps in October-March are kept. Each is tagged with a water year
    (Oct-Dec count toward the following calendar year). Quantiles are taken
    along ``time`` within each water year and then averaged over water years.
    The result is written to ``{odir}/{fn}_Quantile_OCT-MAR.nc``.

    Args:
        ds: xarray Dataset with a ``time`` coordinate. The call sites in this
            script pass daily data, so these are quantiles of the per-step
            (daily) values.
        odir: output directory.
        fn: filename prefix for the output file.
        quantiles: quantile levels to compute. The default matches the
            originally hard-coded levels.

    Returns:
        The averaged quantile maps (the same object that is written to disk).
    """
    print('Make Quantile Maps: ', fn)
    month = ds.time.dt.month
    dq = ds.sel(time=(month <= 3) | (month >= 10))
    # Water year label: Oct-Dec belong to the next calendar year's Oct-Mar season.
    dq['wyear'] = dq.time.dt.year + (dq.time.dt.month > 9)
    # A coordinate (not a data variable), so it is used for grouping and is
    # not itself reduced by quantile().
    dq = dq.set_coords('wyear')
    # NOTE(review): if the record does not start in October or end in March,
    # the first/last water years contain only part of the season, yet they get
    # equal weight in the mean below — confirm this is intended.
    dx = dq.groupby('wyear').quantile(q=list(quantiles), dim='time')
    dx = dx.mean(dim='wyear')
    dx.to_netcdf('{}/{}_Quantile_OCT-MAR.nc'.format(odir, fn))
    return dx

    


# Script body: build seasonal-total and Oct-Mar quantile maps for the PRISM
# daily data and for the PNNL data aggregated to days and matched to PRISM's
# timestamps.

# PRISM daily data (the filename indicates 1981-01-02 .. 2024-01-07).
ds = xr.open_dataset('/home/disk/rocinante/DATA/temp/TNC_stormwater/prism-daily/data_4km/PRISM_19810102-20240107.nc')
# PNNL data (the filename indicates 1981-2020).
ps = xr.open_dataset('/home/disk/rocinante/DATA/temp/TNC_stormwater/pnnl/pnnl_1981-2020.nc')
# Trim PRISM so it ends at the last PNNL time step.
ds = ds.sel(time=slice(ds.time[0],ps.time[-1]))
#dates = ps.resample(time='M').first()
# First PNNL timestamp in each calendar day; used to relabel the daily sums
# below so their time labels can be matched against PRISM's timestamps.
dates = ps['time'].resample(time='D').first()
# Per-day sum of every PNNL variable.
# NOTE(review): resample().sum() uses the default skipna=True with no
# min_count, so a day whose samples are all NaN sums to 0 rather than NaN.
# The seasonal helpers use skipna=False — confirm this is acceptable for
# masked cells or missing time steps.
psel = ps.resample(time='D').sum()
psel['time'] = dates
# Keep only the days present in PRISM. Exact-label selection raises KeyError
# if any PRISM timestamp has no matching relabelled PNNL day.
psel = psel.sel(time=ds.time)

# Output directory. It already ends with '/', and the helpers add another
# '/', so paths contain '//' (harmless on POSIX).
odir = "/home/disk/rocinante/DATA/temp/TNC_stormwater/prism-daily/maps/"

mk_seasonal_maps(ds, odir, 'PRISM-daily')
mk_quantile_maps(ds, odir, 'PRISM-daily')

# Same maps for the PNNL (WRF-NARR) daily sums on PRISM's days.
mk_seasonal_maps(psel, odir, 'WRF-NARR_PRISM-daily-time')
mk_quantile_maps(psel, odir, 'WRF-NARR_PRISM-daily-time')


    

#ps_new = ps.where(ps.times.isin(ds.time), drop=True)
#dtime = pd.to_datetime([x.isoformat() for x in ds.time.values])

# Seasonal totals: get the average across the current timestep interval for the sample period,
# multiply by the expected number of samples,
# then average across the years

#ds['mask'] = xr.where(~np.isnan(ds['PREC'].isel(time=0)), 1, 0)
