import xarray as xr
import numpy as np
from glob import glob
import xesmf as xe
import os

gdir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data/"
ref_wnd = "/home/disk/rocinante/DATA/temp/crystal_fire/data/wind/wind_daily/pnnl_WND_1981-2020.nc"
odir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/wnd_cmp/csv_scatter/"
os.makedirs(odir, exist_ok=True)


def get_latlon2d(ds, lat, lon):
    """Return ``ds`` reduced to the grid cell nearest to (lat, lon).

    Nearness is the plain Euclidean distance in degrees between the target
    point and the dataset's ``lat``/``lon`` coordinate values (no
    cos(latitude) scaling is applied).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset carrying ``lat`` and ``lon`` coordinates. They may either
        share the same dimensions (e.g. 2-D curvilinear ``(y, x)``) or be
        1-D on separate dimensions (rectilinear ``lat`` / ``lon`` axes).
    lat, lon : float
        Target point, in the same units as ``ds.lat`` / ``ds.lon``.

    Returns
    -------
    xarray.Dataset
        ``ds`` positionally indexed (``isel``) at the nearest cell, so the
        spatial dimensions are dropped.

    Raises
    ------
    ValueError
        If the lat/lon layout is not one of the supported forms, or if every
        candidate distance is NaN.
    """
    lat_vals = np.asarray(ds.lat.values)
    lon_vals = np.asarray(ds.lon.values)
    lat_dims = tuple(ds.lat.dims)
    lon_dims = tuple(ds.lon.dims)

    if lat_dims == lon_dims:
        # Curvilinear (or otherwise co-located) coordinates.
        dims = lat_dims
    elif lat_vals.ndim == 1 and lon_vals.ndim == 1:
        # Rectilinear grid: expand 1-D axes to a (lat, lon) mesh.
        lat_vals, lon_vals = np.meshgrid(lat_vals, lon_vals, indexing='ij')
        dims = lat_dims + lon_dims
    else:
        raise ValueError(
            "unsupported lat/lon layout: lat dims {}, lon dims {}".format(
                lat_dims, lon_dims))

    dist = np.hypot(lat_vals - lat, lon_vals - lon)
    # nanargmin skips cells with missing coordinates; argmin would return the
    # first NaN instead of the nearest real cell.
    idx = np.unravel_index(np.nanargmin(dist), dist.shape)
    # unravel_index yields indices in the coordinate array's own dim order,
    # so pair them with those dim names rather than assuming (x, y).
    # isel is positional; sel would treat the integers as coordinate labels.
    return ds.isel({dim: int(i) for dim, i in zip(dims, idx)})
    
def get_scatter(ds, lat, lon, gcm, out_dir=None):
    """Write the WND_SIM / WND_REF series at the grid cell nearest (lat, lon).

    The nearest cell is found with ``get_latlon2d``. Its time series is
    written to ``<out_dir>/wnd_<gcm>_<cell lat>_<cell lon>.csv`` with columns
    ``date, lat, lon, WND_SIM, WND_REF, gcm``. Floats are written with five
    decimals.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``WND_SIM`` and ``WND_REF`` plus ``lat``/``lon``
        coordinates. Presumably the only remaining non-spatial dimension is
        time, since the index is renamed to a single ``date`` level.
    lat, lon : float
        Query point. The file name uses the matched cell's coordinates, not
        these values. NOTE(review): two query points that snap to the same
        cell will overwrite each other's CSV.
    gcm : str
        Model label, written to the ``gcm`` column and used in the file name.
    out_dir : str, optional
        Output directory. Defaults to the module-level ``odir``.

    Returns
    -------
    str
        Path of the CSV that was written.
    """
    if out_dir is None:
        out_dir = odir

    point = get_latlon2d(ds, lat, lon)
    grid_lat = point.lat.values
    grid_lon = point.lon.values
    print('\t', gcm, grid_lat, grid_lon)

    df = point.to_dataframe()
    df.index.names = ['date']
    df = df.reset_index()[['date', 'lat', 'lon', 'WND_SIM', 'WND_REF']]
    # Plain column key: assigning a scalar through a list key ([['gcm']]) to
    # a column that does not exist yet raises KeyError on older pandas.
    df['gcm'] = gcm

    fname = "wnd_{}_{:.5f}_{:.5f}.csv".format(gcm, float(grid_lat), float(grid_lon))
    out = os.path.join(out_dir, fname)
    df.to_csv(out, index=False, float_format='%0.5f')
    return out

    
# Reference (observed) daily wind, renamed so it can be merged beside the
# regridded simulation.
ref = xr.open_dataset(ref_wnd)
ref = ref.rename({'WND': 'WND_REF'})

# (lat, lon) query points; each is snapped to the nearest reference-grid cell
# inside get_scatter.
sites = [
    (47.872, -123.713),
    (46.884, -121.695),
    (46.193, -121.436),
    (45.700, -122.225),
    (45.496, -122.739),
    (46.381, -123.282),
    (47.114, -123.652),
    (47.598, -122.368),
]

gcms = sorted(glob(os.path.join(gdir, "*", "*_RCP85.nc")))
for g in gcms:
    # Strip only the file extension (str.replace would drop every ".nc").
    fn = os.path.splitext(os.path.basename(g))[0]
    print(fn)
    # Context manager closes each GCM file once its sites are written,
    # instead of leaking one open handle per loop iteration.
    with xr.open_dataset(g) as src:
        dg = src[['WND']]
        # Bilinearly regrid the simulation onto the reference grid.
        dr = xe.Regridder(dg, ref, 'bilinear')
        sim = dr(dg).rename({'WND': 'WND_SIM'})
        # Inner join keeps only coordinates (e.g. dates) present in both.
        ds = sim.merge(ref, join='inner')[['WND_SIM', 'WND_REF']]

        for site_lat, site_lon in sites:
            get_scatter(ds, site_lat, site_lon, fn)

# Query points (lat, lon) sampled by get_scatter above:
#47.872,-123.713
#46.884,-121.695
#46.193,-121.436
#45.700,-122.225
#45.496,-122.739
#46.381,-123.282
#47.114,-123.652
#47.598,-122.368


