import xarray as xr


# Input PRISM record.  The file name suggests monthly data spanning
# 1981-01 to 2023-05 -- not verified here.
ds = xr.open_dataset(
    filename_or_obj='/home/disk/rocinante/DATA/temp/TNC_stormwater/prism/data/PRISM_198101-202305.nc',
)
# Destination directory for the derived map files (note the trailing slash).
odir = "/home/disk/rocinante/DATA/temp/TNC_stormwater/prism/data/3.maps/"


# Seasonal maps: mean seasonal totals for JFM, AMJ, JAS and OND
def get_seasonal_sum(ds, min_count=1):
    """Return the multi-year mean of per-year totals for a season subset.

    ``ds`` is expected to already be restricted to the time steps of a single
    season (e.g. only the Jan-Mar steps).  Values are summed over time within
    each calendar year, and those yearly totals are then averaged over years.

    Parameters
    ----------
    ds : xarray.Dataset
        Season subset with a datetime ``time`` coordinate.
    min_count : int or None, optional
        Minimum number of valid (non-NaN) steps a cell needs within a year
        for its yearly sum to be valid; forwarded to the groupby ``sum``.
        The default of 1 keeps a cell that is NaN in every step of a year as
        NaN.  Without it, xarray's default turns such all-NaN sums into 0,
        which shows up as zero totals on a map (e.g. for no-data cells, if the
        grid has any) and pulls the multi-year mean down.  Pass ``None`` to
        get xarray's default behaviour.

    Returns
    -------
    xarray.Dataset
        Mean seasonal total, with the ``year`` dimension reduced away.

    Notes
    -----
    Years in which the season is only partly present in ``ds`` (e.g. at the
    start or end of the record) are summed and averaged like complete years.
    """
    # Group on each step's calendar year; the groupby reduction collapses the
    # grouped time dimension into a new ``year`` dimension.
    yearly_totals = ds.groupby(ds.time.dt.year).sum(min_count=min_count)
    # NaN yearly totals (possible when min_count is set) are skipped by mean.
    return yearly_totals.mean(dim='year')

# Split the record into calendar-quarter seasons.  An explicit month list per
# season keeps all three months visible at a glance (the earlier chained
# comparisons repeated a month and silently dropped May and August).
d01_03 = ds.sel(time=ds.time.dt.month.isin([1, 2, 3]))     # JFM
d04_06 = ds.sel(time=ds.time.dt.month.isin([4, 5, 6]))     # AMJ
d07_09 = ds.sel(time=ds.time.dt.month.isin([7, 8, 9]))     # JAS
d10_12 = ds.sel(time=ds.time.dt.month.isin([10, 11, 12]))  # OND

# Mean seasonal total across years for each season.
# NOTE(review): the input file name suggests the record ends in 2023-05; if so,
# the 2023 AMJ season is partial (Apr-May only) but is averaged in like a full
# season.  Confirm, and drop incomplete seasons if that bias matters.
d01_03 = get_seasonal_sum(d01_03)
d04_06 = get_seasonal_sum(d04_06)
d07_09 = get_seasonal_sum(d07_09)
d10_12 = get_seasonal_sum(d10_12)

# Output paths.  odir already ends with "/", so strip it before joining to
# avoid a doubled separator.  The writes stay commented out (manual toggle);
# any AMJ/JAS files produced while May/Aug were being dropped should be
# regenerated.
out = "{}/PRISM_SEASONAL_TTL_JAN-MAR.nc".format(odir.rstrip("/"))
#d01_03.to_netcdf(out)
out = "{}/PRISM_SEASONAL_TTL_APR-JUN.nc".format(odir.rstrip("/"))
#d04_06.to_netcdf(out)
out = "{}/PRISM_SEASONAL_TTL_JUL-SEP.nc".format(odir.rstrip("/"))
#d07_09.to_netcdf(out)
out = "{}/PRISM_SEASONAL_TTL_OCT-DEC.nc".format(odir.rstrip("/"))
#d10_12.to_netcdf(out)

# Quantile maps: keep only the Oct-Mar cool season and tag each time step with
# its water year.  Oct-Dec count toward the following year (e.g. Nov 1990 ->
# water year 1991); Jan-Mar keep their calendar year.
# NOTE(review): an earlier comment called these "hourly" quantiles, but the
# input file name (PRISM_198101-202305) suggests monthly data -- confirm.
dq = ds.sel(time=ds.time.dt.month.isin([10, 11, 12, 1, 2, 3]))
dq = dq.assign_coords(wyear=dq.time.dt.year + (dq.time.dt.month >= 10))

