#!/usr/bin/env python

import sys
import numpy as np
import numpy.ma as ma
import inspect
import xarray as xr
import pandas as pd
np.set_printoptions(suppress=True)


# MATLAB-style percentiles (same interpolation as MATLAB's prctile)
def matlab_percentile(x, p):
    """Return percentile(s) ``p`` of ``x`` using MATLAB's ``prctile`` definition.

    MATLAB places the k-th smallest of n values at percentile
    100 * (k - 0.5) / n and clamps to the min/max outside that range, whereas
    ``np.percentile`` (linear) places it at 100 * (k - 1) / (n - 1). ``p`` is
    remapped from the first convention to the second before calling
    ``np.percentile``.

    Parameters
    ----------
    x : array-like
        Sample values. All elements are used (multi-dimensional input is
        flattened, exactly as ``np.percentile`` does).
    p : float or array-like
        Percentile(s); values that map outside [0, 100] give the min/max.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Percentile value(s), shaped like ``p``.
    """
    p = np.asarray(p, dtype=float)
    # Count every element: np.percentile flattens x, so len() (first axis
    # only) would use the wrong n for multi-dimensional input.
    n = np.size(x)
    # With fewer than two samples the remap divides by zero (n - 1 == 0);
    # every percentile of a single value is that value, so skip the remap.
    if n > 1:
        p = (p - 50) * n / (n - 1) + 50
    p = np.clip(p, 0, 100)
    return np.percentile(x, p)

# Helper: MATLAB-style percentile of the unmasked values
def get_pct(data, i):
    """Return the MATLAB-style percentile(s) ``i`` of ``data``.

    ``data`` may be a plain array or a ``numpy.ma`` masked array; when it has
    at least one masked entry, only the unmasked values are used.
    """
    values = data.compressed() if ma.is_masked(data) else data
    return matlab_percentile(values, i)

# Helper: [lower, upper] bounds of the percentile bin (i-1, i]
def get_pctmm(data, i):
    """Return ``[lower, upper]`` for the percentile bin ending at ``i``.

    ``lower`` is the (i-1)-th and ``upper`` the i-th percentile of ``data``,
    both computed by ``get_pct``.
    """
    return [get_pct(data, q) for q in (i - 1, i)]


def apply_coldsnap_bc(obs, his, fut):
    """Shift the lowest tail of ``his`` and ``fut`` by the observed bias there.

    Each series' tail is the values at or below its 0.01th percentile
    (via ``get_pctmm``). The shift is mean(obs tail) - mean(his tail); a tail
    mean falls back to the tail's lower bound when the tail is empty or has
    zero width. The shift is added to the tail entries of ``his`` and ``fut``
    in place (assumes NumPy / masked arrays), and ``(his, fut)`` is returned.
    """
    def tail(arr):
        # Bounds of the lowest bin plus the indices of entries inside it.
        lo, hi = get_pctmm(arr, 0.01)
        return lo, hi, ma.where(arr <= hi)

    def tail_mean(arr, lo, hi, idx):
        picked = arr[idx]
        return lo if (len(picked) == 0 or lo == hi) else picked.mean()

    obs_lo, obs_hi, obs_idx = tail(obs)
    his_lo, his_hi, his_idx = tail(his)
    _, _, fut_idx = tail(fut)

    shift = (tail_mean(obs, obs_lo, obs_hi, obs_idx)
             - tail_mean(his, his_lo, his_hi, his_idx))

    his[his_idx] = his[his_idx] + shift
    fut[fut_idx] = fut[fut_idx] + shift
    return (his, fut)

def _bin_indices(data, lo, hi, i, pmin):
    """Return indices of the unmasked entries of ``data`` in percentile bin ``i``.

    The first bin (``i == pmin``) is everything ``<= hi``, the last bin
    (``i == 100``) everything ``> lo``, and every other bin is ``(lo, hi]``.
    """
    if i == pmin:
        return ma.where(data <= hi)
    if i == 100:
        return ma.where(data > lo)
    return ma.where((data > lo) & (data <= hi))


def _bin_mean(data, idx, lo, hi):
    """Return the mean of ``data[idx]``, or ``lo`` if the bin is empty or has zero width."""
    selected = data[idx]
    if len(selected) == 0 or lo == hi:
        return lo
    return selected.mean()


def get_delta(obsdata, hisdata, futdata, mode, ddthres=0.0996, max_ratio=5):
    """Compute per-percentile bias-correction factors and apply them to ``hisdata``.

    Observed and modelled series are split into integer percentile bins
    (MATLAB-style percentiles via ``get_pctmm``). Within each bin the mean
    observed and mean modelled value are compared, and the model values in
    that bin are corrected:

    * ``'prec'``: multiplicative factor obs/his (capped at ``max_ratio``).
      Observed values below ``ddthres`` count as dry. Corrected values below
      ``ddthres`` are set to 0.
    * ``'U10'`` / ``'V10'``: multiplicative factor (capped at ``max_ratio``),
      with sign-aware handling when the modelled bin mean is 0.
    * any other mode: additive shift obs - his.

    Parameters
    ----------
    obsdata, hisdata, futdata : array-like
        Observed, historical-model and future-model series; non-finite
        entries are treated as missing.
    mode : str
        Variable name selecting the correction type (see above).
    ddthres : float
        Dry / near-zero threshold.
    max_ratio : float
        Upper cap on multiplicative correction factors.

    Returns
    -------
    list
        ``[obs_out, his_out, ratio_avg, bc_out, bchis]``.

        * The first four are length-100 per-percentile diagnostics: the obs
          bin mean, the his bin mean, the factor/shift, and the mean of the
          corrected his values in the bin.
        * ``bchis`` is the corrected historical series with missing values
          as NaN.
        * If obs, his or fut has no valid values, the diagnostics are all
          NaN and ``bchis`` is ``hisdata`` unchanged.
    """
    obs = ma.fix_invalid(obsdata, fill_value=-9999)
    his = ma.fix_invalid(hisdata, fill_value=-9999)
    fut = ma.fix_invalid(futdata, fill_value=-9999)

    # Nothing to correct against. Return the same 5-element structure as the
    # main path so callers can always unpack it. count() is the number of
    # unmasked entries and is correct even when the mask is ``nomask``.
    if obs.count() == 0 or his.count() == 0 or fut.count() == 0:
        return [np.full(100, np.nan) for _ in range(4)] + [hisdata]

    if mode == 'prec':
        # Observed amounts below the dry threshold count as zero.
        obs = ma.where(obs < ddthres, 0, obs)

        # pmin: first integer percentile at which obs is > 0 (0 if none is).
        pmin = np.argmin(~(get_pct(obs, np.arange(0, 101)) > 0))
        # Model values at or below the model's pmin-th percentile start dry.
        pmin_zero = get_pct(his, pmin)

        bchis = ma.where(his <= pmin_zero, 0, his)
        bcfut = ma.where(fut <= pmin_zero, 0, fut)
    else:
        pmin = 0
        bchis = his.copy()
        bcfut = fut.copy()

    multiplicative = mode == 'prec' or mode in ('U10', 'V10')

    # Diagnostics for percentile i are stored at index i - 1.
    # NOTE(review): percentile 0 (visited when pmin == 0) lands at index -1
    # and is later overwritten by percentile 100.
    obs_out = np.zeros(100)
    his_out = np.zeros(100)
    ratio_avg = np.zeros(100)
    bc_out = np.zeros(100)

    for i in range(pmin, 101):
        obs_imin, obs_imax = get_pctmm(obs, i)
        his_imin, his_imax = get_pctmm(his, i)
        fut_imin, fut_imax = get_pctmm(fut, i)

        fobs = _bin_indices(obs, obs_imin, obs_imax, i, pmin)
        fhis = _bin_indices(his, his_imin, his_imax, i, pmin)
        ffut = _bin_indices(fut, fut_imin, fut_imax, i, pmin)

        obs_avg = _bin_mean(obs, fobs, obs_imin, obs_imax)
        his_avg = _bin_mean(his, fhis, his_imin, his_imax)
        obs_out[i - 1] = obs_avg
        his_out[i - 1] = his_avg

        if mode == 'prec':
            if his_avg == 0 and obs_avg <= ddthres:
                ratio_avg[i - 1] = 1
            elif his_avg == 0:
                # Avoid dividing by zero: scale relative to the dry threshold.
                ratio_avg[i - 1] = obs_avg / ddthres
            else:
                ratio_avg[i - 1] = obs_avg / his_avg
        elif mode in ('U10', 'V10'):
            if his_avg != 0:
                ratio_avg[i - 1] = obs_avg / his_avg
            elif 0 < obs_avg <= ddthres:
                ratio_avg[i - 1] = 1
            elif -ddthres <= obs_avg < 0:
                ratio_avg[i - 1] = -1
            else:
                ratio_avg[i - 1] = obs_avg / ddthres
        else:
            ratio_avg[i - 1] = obs_avg - his_avg

        if multiplicative:
            ratio_avg[ratio_avg > max_ratio] = max_ratio
            bchis[fhis] = his[fhis] * ratio_avg[i - 1]
            bcfut[ffut] = fut[ffut] * ratio_avg[i - 1]
        else:
            bchis[fhis] = his[fhis] + ratio_avg[i - 1]
            bcfut[ffut] = fut[ffut] + ratio_avg[i - 1]

        bc_out[i - 1] = bchis[fhis].mean()

    if mode == 'prec':
        # Corrected amounts below the dry threshold become zero. No such
        # clean-up is applied to U10/V10, whose values can be negative.
        bchis = ma.where(bchis < ddthres, 0, bchis)
        bcfut = ma.where(bcfut < ddthres, 0, bcfut)

    # Missing entries still carry the -9999 fill value set by fix_invalid;
    # hand them back as NaN in plain ndarrays.
    bchis = np.where(bchis.data == -9999, np.nan, bchis.data)
    # NOTE(review): bcfut is corrected alongside bchis but is not returned.
    bcfut = np.where(bcfut.data == -9999, np.nan, bcfut.data)

    return [obs_out, his_out, ratio_avg, bc_out, bchis]
        
def run_bc_nc(pnnl, wrf, gcm, window_size=45,
              outdir="/home/disk/rocinante/DATA/temp/crystal_fire/wrf/bcdata_wnd"):
    """Run ``get_delta`` for every variable and grid cell of ``wrf`` against ``pnnl``.

    For each variable, the reference and model data are prepared first:
    U10/V10 use May-Sep data only with no window; other variables get a
    trailing rolling mean of ``window_size`` steps when it is non-zero.

    The data are then split into a historical period (1981-02-14..2010-12-31)
    and a future period (2011-01-01..2099-12-31), and passed cell by cell to
    ``get_delta``. The observed, historical and corrected historical series
    of each cell are written to ``<outdir>/data_<var>_<lat>_<lon>.csv``.

    Parameters
    ----------
    pnnl : xarray.Dataset
        Reference ("observed") data with a ``times`` dimension and ``lat``/
        ``lon`` per cell. Assumed to share ``x``/``y`` coordinate values with
        ``wrf`` (cells are selected by wrf's coordinates) -- TODO confirm.
    wrf : xarray.Dataset
        Model data; every data variable is processed.
    gcm : str
        Label used in progress messages.
    window_size : int
        Rolling-mean window (time steps) for non-wind variables; 0 disables it.
    outdir : str
        Directory that receives the per-cell CSV files.

    Returns
    -------
    xarray.Dataset
        Model data for 1981-02-14..2099-12-31. Each variable holds its own
        prepared values, and cells without valid reference data are NaN.
        Corrected values are written only to the CSV files, not to this
        dataset.
    """
    # Drop Feb 29 so both datasets share a 365-day calendar.
    pnnl = pnnl.convert_calendar('noleap', dim='times')
    wrf = wrf.convert_calendar('noleap', dim='times')

    # Unmodified model data. Each variable's preparation starts from here so
    # one variable's season mask or rolling mean is never re-applied to, or
    # leaked into, the next variable.
    base = wrf.sel(times=slice('1981-01-01', '2099-12-31'))
    out = base.sel(times=slice('1981-02-14', '2099-12-31'))

    for vname in list(wrf.keys()):
        print(vname)
        if vname in ('U10', 'V10'):
            # Wind is corrected with May 1 - Sep 30 data only and no window.
            summer = [5, 6, 7, 8, 9]
            obs = pnnl.where(pnnl.times.dt.month.isin(summer))
            mod = base.where(base.times.dt.month.isin(summer))
            print('wnd filter')
        elif window_size != 0:
            obs = pnnl.rolling(times=window_size).mean()
            mod = base.rolling(times=window_size).mean()
        else:
            obs = pnnl
            mod = base

        # Periods start 1981-02-14, presumably to skip the NaN-padded start
        # of a 45-step trailing window -- TODO confirm.
        obs = obs.sel(times=slice('1981-02-14', '2010-12-31'))
        his = mod.sel(times=slice('1981-02-14', '2010-12-31'))
        fut = mod.sel(times=slice('2011-01-01', '2099-12-31'))
        out[vname] = mod.sel(times=slice('1981-02-14', '2099-12-31'))[vname]

        for i in wrf.x:
            for j in wrf.y:
                print(gcm, vname, i.item(), j.item())
                cell_obs = obs.sel(x=i, y=j)
                obsdata = cell_obs[vname].values
                hisdata = his.sel(x=i, y=j)[vname].values
                futdata = fut.sel(x=i, y=j)[vname].values
                lat = "{:5f}".format(cell_obs.lat.values)
                lon = "{:5f}".format(cell_obs.lon.values)
                print(lat, lon)

                if np.isnan(obsdata).all():
                    # No reference data for this cell. Assign by coordinate
                    # label: i/j are coordinate values, not positions.
                    out[vname].loc[dict(x=i.item(), y=j.item())] = np.nan
                    continue

                _, _, _, _, bcdata = get_delta(
                    obsdata, hisdata, futdata, vname, ddthres=0.0996, max_ratio=5)
                dd = pd.DataFrame({'obs': obsdata, 'his': hisdata, 'bc': bcdata},
                                  index=obs.times.values)
                dd.to_csv('{}/data_{}_{}_{}.csv'.format(outdir, vname, lat, lon))

    return out








# Model runs processed by main(), in processing order: the single RCP4.5
# run first, followed by all RCP8.5 runs.
_RCP45_RUNS = ['access1.0_RCP45']
_RCP85_RUNS = [
    'access1.0_RCP85',
    'access1.3_RCP85',
    'bcc-csm1.1_RCP85',
    'canesm2_RCP85',
    'ccsm4_RCP85',
    'csiro-mk3.6.0_RCP85',
    'fgoals-g2_RCP85',
    'gfdl-cm3_RCP85',
    'giss-e2-h_RCP85',
    'miroc5_RCP85',
    'mri-cgcm3_RCP85',
    'noresm1-m_RCP85',
]
gcms = _RCP45_RUNS + _RCP85_RUNS


def main():
    """Run ``run_bc_nc`` (no rolling window) for every entry in ``gcms`` and
    write each resulting dataset to ``bcdata_wndf/<gcm>.nc``."""
    pnnl_file = '/home/disk/rocinante/DATA/temp/crystal_fire/wrf/scripts/hisdata.nc'
    wrf_dir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/data_wndf"
    out_dir = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/bcdata_wndf"

    # The reference dataset is reused for every GCM, so keep it open for the
    # whole loop. Closing it after the first GCM would break the lazy reads
    # that later iterations make.
    with xr.open_dataset(pnnl_file) as pnnl:
        for gcm in gcms:
            wrf_file = "{}/{}/{}.nc".format(wrf_dir, gcm, gcm)
            with xr.open_dataset(wrf_file) as wrf:
                ds = run_bc_nc(pnnl, wrf, gcm, 0)
                ds.to_netcdf("{}/{}.nc".format(out_dir, gcm))

# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()



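# Notes:
# - no rolling window is applied (main passes window_size=0)
# - wind is corrected using May 1 to Sept 30 data only
# - interpolate from WRF resolution to PNNL resolution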
