import xarray as xr
import numpy as np



# Point sites to extract, one [name, lat, lon] entry per site. The name ends up
# in the output CSV filename; lat/lon are matched to the nearest grid cell of
# each dataset, so they must use the same convention as the files' lat/lon.
pts = (
    ['Goodell',   48.683,   -121.227],
    ['EagleCr',   45.618,   -121.942],
    ['Maple',     47.579,   -123.151],
    ['BeachieCr', 44.84014, -122.37729],
)

# GCM runs to process. Each name is the basename of a NetCDF file under
# bcdata_wndf/ (opened by get_fut) and is also used in the output CSV filename.
# Currently disabled runs (add them back to the tuple to include them):
#   'access1.0_RCP85', 'bcc-csm1.1_RCP85', 'ccsm4_RCP85',
#   'csiro-mk3.6.0_RCP85', 'fgoals-g2_RCP85', 'giss-e2-h_RCP85',
#   'miroc5_RCP85', 'mri-cgcm3_RCP85'
gcms = (
    'access1.3_RCP85',
    'canesm2_RCP85',
    'gfdl-cm3_RCP85',
    'noresm1-m_RCP85',
)


def get_latlon2d(ds, lat, lon):
    """Return the (i, j) index of the grid point in ``ds`` nearest to (lat, lon).

    ``ds.lat`` and ``ds.lon`` are expected to be same-shaped 2-D coordinate
    arrays (a curvilinear grid); the unpacking below requires a 2-D result.
    ``ds.lat.values[i, j]`` / ``ds.lon.values[i, j]`` is the nearest point, so
    ``i`` indexes ``ds.lat.dims[0]`` and ``j`` indexes ``ds.lat.dims[1]``.

    Distance uses a local equirectangular approximation: longitude offsets are
    wrapped into [-180, 180) (so 0..360 and -180..180 grids both work) and
    scaled by cos(lat), because a degree of longitude shrinks away from the
    equator. NaN grid points are ignored.

    Returns:
        tuple[int, int]: the array index of the nearest grid point.

    Raises:
        ValueError: if every grid point is NaN (from ``np.nanargmin``).
    """
    lats = np.asarray(ds.lat.values, dtype=float)
    lons = np.asarray(ds.lon.values, dtype=float)

    dlat = lats - lat
    # Wrap to [-180, 180) before scaling so a 359E grid point is 1 degree from -1E.
    dlon = ((lons - lon + 180.0) % 360.0 - 180.0) * np.cos(np.radians(lat))

    # Squared distance has the same argmin as distance, so the sqrt is skipped.
    dist2 = dlat ** 2 + dlon ** 2
    i, j = np.unravel_index(np.nanargmin(dist2), dist2.shape)
    return int(i), int(j)

def ds2df(ds, gcm, loc=None):
    """Flatten a point time series to long format and write it to ``{gcm}_{loc}.csv``.

    Adds trailing 7- and 60-step rolling means of ``Tmax`` as ``T7`` and ``T60``
    (the first 6 / 59 values are NaN). Converts the result to a DataFrame
    indexed by ``times``, and melts every column other than ``lat``, ``lon``
    and ``WND`` into ``variable``/``value`` rows. Each row is then tagged with:

    - its day of year (``doy``);
    - the month/day with the year fixed to 2000 (``fdate``);
    - the location (``loc``);
    - the GCM label (``gcm``);
    - a ``YYYY/MM/DD`` date string (``times``).

    Floats are written with 4 decimal places and the index is not written.

    Parameters
    ----------
    ds : xarray.Dataset
        Single-point data with a ``times`` dimension, a ``Tmax`` variable and
        ``lat``/``lon``/``WND`` fields (names required by the code below).
        It is not modified.
    gcm : str
        Label stored in the ``gcm`` column and used in the filename
        (``'his'`` for historical data in this script).
    loc : str, optional
        Location label for the ``loc`` column and the filename. When omitted,
        the module-level ``loc`` is used, which is what existing callers in
        this script rely on.

    Returns
    -------
    pandas.DataFrame
        The long-format table that was written.

    Raises
    ------
    NameError
        If ``loc`` is not given and no module-level ``loc`` exists.
    """
    if loc is None:
        # Backward compatibility: callers do not pass ``loc`` and rely on the
        # module-level variable of the same name.
        loc = globals().get('loc')
        if loc is None:
            raise NameError("ds2df needs a location: pass loc= or define a module-level 'loc'")

    # Build a new Dataset rather than adding variables to the caller's object.
    ds = ds.assign(
        T7=ds.Tmax.rolling(times=7).mean(),
        T60=ds.Tmax.rolling(times=60).mean(),
    )
    df = ds.to_dataframe()
    df = df.melt(id_vars=['lat', 'lon', 'WND'], ignore_index=False)
    df['doy'] = df.index.dayofyear
    df['fdate'] = df.index.strftime('2000/%m/%d')
    df['loc'] = loc
    df['gcm'] = gcm
    df['times'] = df.index.strftime('%Y/%m/%d')
    # '%.4f' = 4 decimal places ('%4f' would only set a minimum width of 4).
    df.to_csv('{}_{}.csv'.format(gcm, loc), float_format='%.4f', index=False)
    return df

def get_his(loc, lat, lon):
    """Extract historical wind and Tmax at the grid cells nearest (lat, lon) and write CSV.

    Steps:

    - Open the PNNL wind file and the gridMET file.
    - Pick the cell nearest the requested point in each; the two grids are
      searched independently.
    - Drop duplicate ``times`` entries from the temperature series, keeping
      the last.
    - Restrict the temperature series to the wind series' time stamps.
    - Attach ``tmmx`` as ``Tmax`` and hand the result to ``ds2df`` with the
      label ``'his'``.

    Parameters
    ----------
    loc : str
        Site name; only used in the progress message here.
    lat, lon : float
        Target point, matched against each dataset's ``lat``/``lon``.

    Note: ``ds2df`` is called without a location argument, so the CSV is
    named from the module-level ``loc`` rather than this ``loc`` parameter.
    """
    print('Getting his data:', loc)

    wnd_path = '/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data_wndf/pnnl/pnnl_WND_1981-2020.nc'
    tmp_path = '/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/gridMET.nc'

    # ``with`` closes both files even if extraction or writing raises.
    with xr.open_dataset(wnd_path) as wnd, xr.open_dataset(tmp_path) as tmp:
        # get_latlon2d returns indices in the order of ``lat``'s dimensions, so
        # pair them with those dimension names; each grid is then indexed
        # correctly whatever its x/y order. isel = positional selection.
        wnd_data = wnd.isel(dict(zip(wnd.lat.dims, get_latlon2d(wnd, lat, lon))))
        tmp_data = tmp.isel(dict(zip(tmp.lat.dims, get_latlon2d(tmp, lat, lon))))

        # Keep one value per time stamp so selecting by the wind times is unambiguous.
        tmp_data = tmp_data.drop_duplicates(dim='times', keep='last')
        tmp_data = tmp_data.sel(times=wnd_data.times)

        wnd_data['Tmax'] = tmp_data.tmmx
        ds2df(wnd_data, 'his')

def get_fut(gcm, loc, lat, lon):
    """Extract one GCM run at the grid cell nearest (lat, lon) and write it to CSV.

    Reads ``bcdata_wndf/<gcm>.nc``, selects the nearest cell, renames ``T2``
    to ``Tmax`` and passes the result to ``ds2df`` labelled with ``gcm``.

    Parameters
    ----------
    gcm : str
        Run name; also the NetCDF file basename.
    loc : str
        Site name; only used in the progress message here.
    lat, lon : float
        Target point, matched against the dataset's ``lat``/``lon``.

    Note: ``ds2df`` is called without a location argument, so the CSV is
    named from the module-level ``loc`` rather than this ``loc`` parameter.
    """
    print('Getting ', gcm, 'data:', loc)
    path = '/home/disk/rocinante/DATA/temp/crystal_fire/wrf/bcdata_wndf/{}.nc'.format(gcm)

    # ``with`` closes the file after the CSV is written, or if anything raises.
    with xr.open_dataset(path) as ds:
        # Pair the nearest-cell indices with ``lat``'s own dimension names so
        # the x/y order cannot be swapped; isel = positional selection.
        data = ds.isel(dict(zip(ds.lat.dims, get_latlon2d(ds, lat, lon))))
        data = data.rename({'T2': 'Tmax'})
        ds2df(data, gcm)
    



# Driver: for each site, write the historical CSV and then one CSV per GCM run.
# Guarded so that importing this module does not start the heavy file I/O.
if __name__ == '__main__':
    # ``loc`` must remain a module-level name: ds2df is called without a
    # location and reads the global ``loc`` to label and name its output.
    for loc, lat, lon in pts:
        get_his(loc, lat, lon)
        for g in gcms:
            get_fut(g, loc, lat, lon)
