import xarray as xr
from glob import glob
import os

## Build, per GCM, a WRF dataset of wind and temperature where the
## temperature is bias-corrected against the GridMET dataset.

# GCM run directories to process, named "<model>_<scenario>".
# The single RCP45 entry comes first, followed by the RCP85 entries.
gcms = ["access1.0_RCP45"]
gcms += [
    "access1.0_RCP85",
    "access1.3_RCP85",
    "bcc-csm1.1_RCP85",
    "canesm2_RCP85",
    "ccsm4_RCP85",
    "csiro-mk3.6.0_RCP85",
    "fgoals-g2_RCP85",
    "gfdl-cm3_RCP85",
    "giss-e2-h_RCP85",
    "miroc5_RCP85",
    "mri-cgcm3_RCP85",
    "noresm1-m_RCP85",
]

# Root of the WRF input tree (one sub-directory per GCM run).
idir = "/home/disk/rocinante/DATA/temp/WRF/var/"
# Root under which one output directory per GCM run is created.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"

# Helper function to open and format dataset
def open_ds(idir, gcm, var):
    """Open one WRF variable for a GCM run and reduce it to daily values.

    Globs ``{idir}/{gcm}/2d/{var}_year/*.nc``, opens the sorted matches as a
    single multi-file dataset, subsets it with ``x=slice(51, 75)`` and
    ``y=slice(53, 91)``, and resamples along ``times`` to one value per day.

    Parameters
    ----------
    idir : str
        Root directory containing one sub-directory per GCM run.
    gcm : str
        Name of the GCM run sub-directory.
    var : str
        Variable directory prefix. ``"T2"`` is reduced with the daily
        maximum; every other variable with the daily mean.

    Returns
    -------
    xarray.Dataset
        The subsetted, daily-resampled dataset.

    Raises
    ------
    FileNotFoundError
        If no file matches the glob pattern.
    """
    print("\tOpening dataset: ", gcm, var)
    pattern = "{}/{}/2d/{}_year/*.nc".format(idir, gcm, var)
    paths = sorted(glob(pattern))
    if not paths:
        # Fail here, naming the pattern, rather than letting the empty list
        # reach xarray where the error no longer says which path was wrong.
        raise FileNotFoundError(
            "No NetCDF files match pattern {!r}".format(pattern)
        )
    ds = xr.open_mfdataset(paths)
    # NOTE(review): whether these bounds act as coordinate labels or as
    # positions depends on whether x/y carry coordinate values -- confirm.
    ds = ds.sel(x=slice(51, 75), y=slice(53, 91))
    daily = ds.resample(times='1D')
    if var == "T2":
        ds = daily.max()
    else:
        ds = daily.mean()
    print("\tDataset loaded...")
    return ds


# Load the GridMET reference data that the delta is computed against
print('Opening GridMET dataset')
gms = xr.open_dataset('/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/gridMET_WRF_interp.nc')

# Restrict to the years 1979 through 2006 (both inclusive)
_gm_years = gms.times.dt.year
_gm_in_period = (_gm_years >= 1979) & (_gm_years <= 2006)
gms = gms.sel(times=_gm_in_period)

# Per-calendar-month average, written as grouped sum over grouped count
_gm_by_month = gms.groupby(gms.times.dt.month)
gmm = _gm_by_month.sum() / _gm_by_month.count()

# For every GCM run: derive a monthly temperature delta against GridMET over
# 1979-2006, save it, then apply it to the full daily T2 series and save the
# adjusted dataset together with wind speed and rolling temperature means.
for gcm in gcms:
    print("\nWorking on: ", gcm)
    gdir = "{}/{}/".format(odir, gcm)
    os.makedirs(gdir, exist_ok=True)

    # Comparison period 1979 - 2006 (inclusive), taken from the daily T2
    tds = open_ds(idir, gcm, "T2")
    t_years = tds.times.dt.year
    fmax = tds.sel(times=((t_years >= 1979) & (t_years <= 2006)))

    # Create delta: per-calendar-month average of T2 (grouped sum / count),
    # compared against the GridMET monthly `tmmx` held in `gmm`.
    print("\tCreating delta")
    f_by_month = fmax.groupby(fmax.times.dt.month)
    fmm = f_by_month.sum() / f_by_month.count()
    # Presumably Kelvin -> Celsius so T2 is comparable with tmmx; confirm
    # the units of both inputs.
    fmm['T2'] -= 273.15
    cs = gmm.merge(fmm)
    cs['delta'] = cs.tmmx - cs.T2
    print('\tSaving delta')
    cs.to_netcdf('{}/{}_delta.nc'.format(gdir, gcm)) # save delta

    # Create adjusted dataset
    uds = open_ds(idir, gcm, "U10")
    vds = open_ds(idir, gcm, "V10")
    print('\tCreating adjusted dataset')
    # Magnitude of the (U10, V10) vector, built from the daily-mean components
    tds['WND'] = (uds.U10**2 + vds.V10**2)**0.5
    # Same 273.15 offset as applied to the monthly means above
    tds['T2'] -= 273.15
    # Add each calendar month's delta to every day falling in that month
    tds['T2'] = tds.T2.groupby(tds.times.dt.month) + cs.delta
    # Rolling means of the adjusted T2 over 7 and 60 steps along `times`
    tds['tmx7'] = tds.T2.rolling(times=7).mean()
    tds['tmx60'] = tds.T2.rolling(times=60).mean()
    print('\tSaving dataset')
    tds.to_netcdf("{}/{}.nc".format(gdir, gcm)) # save data

    # Clean-up: close all three datasets before moving on to the next GCM
    tds.close()
    uds.close()
    vds.close()
    
