import xarray as xr
import pandas as pd
from glob import glob
import xesmf as xe
import os

# --- Input / output locations -------------------------------------------
# Directory searched (one level down) for the per-model "*_RCP85.nc" files.
gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
# Reference daily wind file; its grid is the regridding target further down.
ref_wnd = "/home/disk/rocinante/DATA/temp/crystal_fire/data/wind/wind_daily/pnnl_WND_1981-2020.nc"
# Directory that receives the CSV files; created up front if missing.
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv_doy/"
os.makedirs(odir, exist_ok=True)

# Collect every matching file, then order the paths so runs are reproducible.
gcms = glob(gdir + "/*/*_RCP85.nc")
gcms.sort()


def format_ds(ds, styr, edyr, gcm):
    """Write a day-of-year mean of ``WND`` for one dataset and period to CSV.

    The input is restricted to ``styr``-``edyr`` (inclusive), masked to the
    months June through September, averaged per calendar day-of-year over
    the ``times`` axis, and trimmed to day-of-year 121..274.  The result is
    flattened to a table with the columns ``lat``, ``lon``, ``WND``,
    ``dayofyear``, ``gcm`` and ``period`` and written to
    ``<odir>/<gcm>_wnd<styr>-<edyr>.csv``.

    Parameters
    ----------
    ds : xarray object
        Must have a ``times`` datetime coordinate and a ``WND`` variable,
        and must yield ``lat`` / ``lon`` columns when converted to a
        DataFrame.
    styr, edyr : int or str
        First and last year of the period (both inclusive).
    gcm : str
        Label stored in the ``gcm`` column and used in the file name.

    Returns
    -------
    None
        The function prints progress information and writes a CSV file.
    """
    print('\t', styr, edyr, gcm)
    # Defined once: used both for the 'period' column and the file name.
    period = '{}-{}'.format(styr, edyr)
    stdt = "{}-01-01".format(styr)
    eddt = "{}-12-31".format(edyr)

    dx = ds.sel(times=slice(stdt, eddt))
    # `where` masks (does not drop) values outside June-September, so the
    # days kept by the slice below that fall outside those months are NaN.
    dx = dx.where((dx.times.dt.month >= 6) & (dx.times.dt.month <= 9))
    dx = dx.groupby('times.dayofyear').mean()
    dx = dx.sel(dayofyear=slice(121, 274))
    print(dx)

    # Convert the aggregated climatology (dx), not the raw input: only dx
    # carries the 'dayofyear' coordinate that is selected below.
    df = dx.to_dataframe().reset_index()
    print(df)
    # .copy() so the column assignments below act on an independent frame.
    df = df[['lat', 'lon', 'WND', 'dayofyear']].copy()
    df['gcm'] = gcm
    df['period'] = period

    df.to_csv('{}/{}_wnd{}.csv'.format(odir, gcm, period), index=False, float_format='%0.5f')

        
    
# Driver: regrid each model file onto the reference grid and export its
# day-of-year table, then export the reference dataset itself.
# Every dataset is opened through a context manager so its file handle is
# released as soon as it is no longer needed.
with xr.open_dataset(ref_wnd) as dref_full:
    # Reference is limited to the historical comparison period.
    dref = dref_full.sel(times=slice('1981', '2009'))

    for g in gcms:
        # File name without extension doubles as the model label.
        fn = os.path.basename(g).replace('.nc', '')
        print(fn)
        with xr.open_dataset(g) as dg_full:
            dg = dg_full[['WND']]
            # Bilinear regridding from the model grid to the reference grid.
            dr = xe.Regridder(dg, dref, 'bilinear')
            ds = dr(dg)

            # NOTE: disabled exploratory code, kept for reference.
            #dref = ds.where((dref.times.dt.month >= 5) & (dref.times.dt.month <= 9))
            #dr4 = dref.where(dref.WND >= 4).groupby(['times.year', 'times.month']).count(dim='times')
            #dr11 = dref.where(dref.WND >= 11).groupby(['times.year', 'times.month']).count(dim='times')

            format_ds(ds, 1981, 2009, fn)
            #format_ds(ds, 2010, 2039, fn)
            #format_ds(ds, 2040, 2069, fn)
            #format_ds(ds, 2070, 2099, fn)

    format_ds(dref, 1981, 2009, 'pnnl')


    
