#!/usr/bin/python3
import xarray as xr
import pandas as pd
import sys
import os
import dask

from dask.diagnostics import ProgressBar as pbar
from multiprocessing.pool import ThreadPool
# Run dask computations on the threaded scheduler with an 8-thread pool.
# NOTE: the config key must be spelled "scheduler"; dask.config.set accepts
# arbitrary keys, so a misspelling is stored silently and has no effect.
dask.config.set(scheduler='threads', pool=ThreadPool(8))

# Command line: the single positional argument is the input NetCDF file.
if len(sys.argv) < 2:
    sys.exit("usage: {} INPUT.nc".format(os.path.basename(sys.argv[0])))
args = sys.argv
ifile = args[1]


# Name of the per-input output sub-directory: the input file's basename with
# only a trailing ".nc" removed (a ".nc" elsewhere in the name is kept).
gcm = os.path.basename(ifile)
if gcm.endswith('.nc'):
    gcm = gcm[:-len('.nc')]
#odir = "/home/disk/rocinante/DATA/temp/kcp3/forcings/4.forc_nc/"
#odir = "/home/disk/rocinante/DATA/temp/kcp3/forcings/xtra/murphy/data/"
odir = "/home/disk/becassine/jswon11/4.forc_nc/"
# Trailing '' keeps a trailing path separator on the directory name.
gdir = os.path.join(odir, gcm, '')
os.makedirs(gdir, exist_ok=True)
# Table of (lat, lon) points to extract; main() reads its `lat` and `lon`
# columns.
#ldf = pd.read_csv('/home/disk/rocinante/DATA/temp/WRF/scripts/regrid_c2c/pnnl_elev_5d.csv')
ldf = pd.read_csv('/home/disk/rocinante/DATA/temp/kcp3/forcings/xtra/murphy/clip2.csv')


def main():
    """Write one NetCDF file per selected grid cell of ``ifile``.

    Opens the module-level input ``ifile`` and iterates over every (x, y)
    grid cell. A cell is kept only when its lat/lon, rounded to 5 decimal
    places, matches exactly one row of the module-level table ``ldf``. Each
    kept cell is written to ``<gdir>/data_<lat>_<lon>.nc``, and all files are
    saved in a single ``xr.save_mfdataset`` call under a progress bar. When
    no cell matches, nothing is written.
    """
    import itertools
    from collections import Counter

    # Count (lat, lon) pairs once so each cell lookup is O(1) instead of a
    # full DataFrame scan. Counting, rather than using a set, preserves the
    # rule that a cell is kept only when exactly one row matches.
    wanted = Counter(zip(ldf.lat, ldf.lon))

    # The context manager closes the file once the writes below are done.
    # All lazy per-cell selections must be saved before leaving this block.
    with xr.open_dataset(ifile) as ds:
        # One chunk per (x, y) cell spanning the whole time axis.
        # Assumes the time dimension is named 'times'.
        ds = ds.chunk({'times': len(ds.times), 'x': 1, 'y': 1})
        print(ds)

        cells = []
        outs = []
        for i, j in itertools.product(ds.x, ds.y):
            cell = ds.sel(x=i, y=j)
            lat = "{:0.5f}".format(cell.lat.values)
            lon = "{:0.5f}".format(cell.lon.values)
            # Exact float comparison against the table, after rounding the
            # cell coordinates to 5 decimals.
            if wanted[(float(lat), float(lon))] != 1:
                continue
            print(lat, lon)
            cells.append(cell)
            outs.append(os.path.join(gdir, "data_{}_{}.nc".format(lat, lon)))

        print("Saving: {}".format(len(outs)))
        # xr.save_mfdataset fails on empty inputs, so stop here if no cell
        # matched.
        if not outs:
            return
        with pbar():
            xr.save_mfdataset(cells, outs)


if __name__ == "__main__":
    main()
