#!/usr/bin/python3
import sys
import argparse
import pandas as pd
import numpy as np
import xarray as xr
import bc_lib as bc
import multiprocessing as mp
from tqdm import tqdm
import itertools as it

# Command-line interface: three positional file paths plus two optional
# year ranges, each given as a pair of integers (start year, end year).
parser = argparse.ArgumentParser()

# Positional file arguments
parser.add_argument('obsf', help='Observational data for reference')
parser.add_argument('simf', help='Historical simulation for training')
parser.add_argument('outf', help='Training outfile')

# Optional year ranges
parser.add_argument(
    '--his_yrs',
    type=int,
    nargs=2,
    default=[1970,1999],
    help='Historical years',
)
parser.add_argument(
    '--fut_yrs',
    type=int,
    nargs=2,
    default=[2000,2099],
    help='Future years',
)

args = parser.parse_args()

# Pull the parsed command-line values into module-level names
obsf, simf, outf = args.obsf, args.simf, args.outf
his_yrs, fut_yrs = args.his_yrs, args.fut_yrs

# Variables to process; process_cell treats the first entry specially
# (it is run in 'prec' mode).
# Alternative variable naming kept for reference: ['pr', 'tasmax', 'tasmin']
varlist = ['PREC', 'TMAX', 'TMIN']

# Open both inputs, rename x/y to lon/lat and put the dimensions in
# (lon, lat, times) order so both datasets share the same layout.
obs = (
    xr.open_dataset(obsf)
    .rename({'x':'lon', 'y':'lat'})
    .transpose('lon', 'lat', 'times')
)
sim = (
    xr.open_dataset(simf)
    .rename({'x':'lon', 'y':'lat'})
    .transpose('lon', 'lat', 'times')
)

# Split the simulation into the historical and future periods; the year
# bounds are passed to the label-based slice as strings.
his = sim.sel(times=slice(*map(str, his_yrs)))
fut = sim.sel(times=slice(*map(str, fut_yrs)))

# Copies that will receive the corrected values for each period
bchis, bcfut = his.copy(), fut.copy()

# Grid size: number of lon points, number of lat points, total cell count
x, y = len(his.lon), len(his.lat)
n = x * y

# Worker function run in the process pool: bias-correct one grid cell
def process_cell(cell):
    """Apply ``bc.apply_bc`` to a single grid cell.

    ``cell`` is an indexable record laid out as
    ``[i, j, var, obs_series, his_series, fut_series]``.

    The mode handed to ``bc.apply_bc`` is ``'prec'`` when ``var`` equals the
    first entry of the module-level ``varlist``; otherwise the variable name
    itself is used as the mode. The constant ``0.09906`` is forwarded as the
    last positional argument; its meaning is defined by ``bc.apply_bc``.

    Returns the tuple ``(i, j, h, f)`` where ``h`` and ``f`` are the two
    values returned by ``bc.apply_bc``.
    """
    var_name = cell[2]
    if var_name == varlist[0]:
        mode = 'prec'
    else:
        mode = var_name
    corrected = bc.apply_bc(cell[3], cell[4], cell[5], mode, 0.09906)
    (h, f) = corrected
    return (cell[0], cell[1], h, f)

# Reshape a list of per-cell result rows into an xarray dataset
def results2ds(res):
    """Convert row-wise results into a dataset indexed by (x, y, time).

    ``res`` is a sequence of equal-length rows. The first two entries of each
    row are the cell indices; every remaining entry is a value for that cell.

    The rows are loaded into a DataFrame, unpivoted so that each value gets
    its own record, and indexed by ``x``, ``y`` and ``time``. The ``time``
    index holds the original DataFrame column labels of the values (not
    timestamps). The returned dataset has a single variable named ``val``.
    """
    frame = pd.DataFrame(res)
    # Columns 0 and 1 stay as identifiers; all other columns are melted
    long_form = frame.melt(id_vars=[0, 1])
    long_form.columns = ['x', 'y', 'time', 'val']
    # The cell indices may arrive as floats; store them as integers
    long_form = long_form.astype({'x': int, 'y': int})
    return long_form.set_index(['x', 'y', 'time']).to_xarray()
    


# Worker pool shared by every variable pass. The context manager terminates
# the workers on exit (normal or error) so they are not left running.
# NOTE: the pool is created at module level without an
# `if __name__ == '__main__':` guard, which the multiprocessing docs require
# for the spawn/forkserver start methods.
with mp.Pool(10) as p:
    for var in varlist:
        print('--' + var)
        pbar = tqdm(total=n)

        # Pick the variable once per pass instead of slicing the whole
        # dataset again for every grid cell.
        obs_var = obs[var]
        his_var = his[var]
        fut_var = fut[var]

        # i and j are integer positions produced by range(), so index by
        # position (isel) rather than by coordinate label (sel).
        cells = [
            [i, j, var,
             obs_var.isel(lon=i, lat=j),
             his_var.isel(lon=i, lat=j),
             fut_var.isel(lon=i, lat=j)]
            for i, j in it.product(range(x), range(y))
        ]

        # Submit one task per cell; the callback advances the progress bar
        # as each task completes.
        pending = [p.apply_async(process_cell, args=(cell,),
                                 callback=lambda _: pbar.update(1))
                   for cell in cells]
        # Collect in submission order; get() re-raises any worker exception.
        results = [job.get() for job in pending]
        pbar.close()

        print('--Saving results: {}'.format(var))
        # Prefix each corrected series with its cell indices so results2ds
        # can rebuild the (x, y, time) layout.
        hres = [np.concatenate([[i, j], h]) for i, j, h, _ in results]
        fres = [np.concatenate([[i, j], f]) for i, j, _, f in results]
        hr_set = results2ds(hres)
        fr_set = results2ds(fres)
        bchis[var].data = hr_set['val'].data
        bcfut[var].data = fr_set['val'].data

# Join the corrected historical and future periods along the time dimension
ds = xr.concat([bchis, bcfut], 'times')
ds.to_netcdf(outf)
print("\n", outf)










    
    #print(bchis[var].shape)
    #for i,j,h,f in results:
    #    print(h.shape)
    #    print(bchis[var].data[:,i,j].shape)
    #    print('set his {} {}'.format(i,j))
    #    bchis[var].data[:, i, j] = h
    #    print('set fut {} {}'.format(i,j))
    #    bcfut[var].data[:, i, j] = f
    #pbar.close()

    #pbar.close()
    #pbar = tqdm(total=len(results))
    #print('saving')
    #print(len(results))
    #res = [p.apply_async(save_bc, args=(results[l], var,),
    #                     callback=lambda _:pbar.update(1)) for l in range(len(results))]
    #print('done')
    #pbar.close()















