#!/usr/bin/env python

import sys
import numpy as np
import numpy.ma as ma
import inspect
# Print floating point numbers in fixed-point notation instead of scientific
# notation (affects every NumPy array printed by this process).
np.set_printoptions(suppress=True)


# Get matlab style percentile values
def matlab_percentile(x, p):
    """Return MATLAB-style percentile(s) `p` of the samples in `x`.

    The requested levels are rescaled around 50 by ``n / (n - 1)`` (``n``
    being the number of samples), clipped to ``[0, 100]`` and handed to
    ``np.percentile``, which interpolates linearly over the flattened input.

    Parameters
    ----------
    x : array_like
        Samples.  All elements are used, whatever the shape.
    p : float or array_like of float
        Percentile level(s); values that fall outside ``[0, 100]`` after
        rescaling are clipped to the nearest bound.

    Returns
    -------
    scalar or numpy.ndarray
        Percentile value(s), with the shape of `p`.
    """
    p = np.asarray(p, dtype=float)
    # np.percentile works on the flattened input, so count every element.
    n = np.size(x)
    # A single sample is its own percentile at every level; rescaling would
    # divide by zero and produce NaN/inf levels.
    if n > 1:
        p = (p - 50) * n / (n - 1) + 50
    p = np.clip(p, 0, 100)
    return np.percentile(x, p)

# MATLAB-style percentile of the unmasked values of a (possibly masked) array
def get_pct(data, i):
    """Return the MATLAB-style percentile(s) `i` of `data`.

    If `data` contains masked entries, only its unmasked values (flattened
    by ``compressed()``) are used; otherwise `data` is passed on unchanged.
    """
    values = data.compressed() if ma.is_masked(data) else data
    return matlab_percentile(values, i)

# Lower and upper bound of a percentile bin: percentiles i-1 and i of the data
def get_pctmm(data, i):
    """Return ``[lower, upper]``: the percentiles ``i - 1`` and ``i`` of `data`."""
    return [get_pct(data, level) for level in (i - 1, i)]

def apply_bc(obsdata, hisdata, futdata, mode, ddthres, max_ratio=5):
    """Bias-correct model data against observations, one percentile bin at a time.

    Non-finite entries of the three inputs are treated as missing.  Each
    series is split into bins bounded by consecutive integer percentiles of
    its own valid values (see ``get_pct``).  For every bin a correction is
    derived from the observed and historical bin averages and applied to the
    historical and future values that fall into that bin of their own series:

    * ``mode == 'prec'``: multiplicative factor ``obs_avg / his_avg``, capped
      at `max_ratio`.  Observations below `ddthres` are set to zero first and
      the bins start at the first percentile whose observed value is positive.
    * any other `mode`: additive offset ``obs_avg - his_avg``.

    Parameters
    ----------
    obsdata, hisdata, futdata : array_like
        Observed, historical-model and future-model values.
    mode : str
        ``'prec'`` for the multiplicative correction; anything else selects
        the additive correction.
    ddthres : float
        Threshold: in ``'prec'`` mode observations below it count as zero,
        and in every mode corrected values below it are returned as zero.
    max_ratio : float, optional
        Upper limit of the multiplicative factor (``'prec'`` mode only).

    Returns
    -------
    tuple
        ``(bchis, bcfut)`` as plain arrays in which entries equal to the
        internal ``-9999`` fill value are replaced by NaN.  If `hisdata` or
        `futdata` holds no valid value, ``(hisdata, futdata)`` is returned
        unchanged.
    """
    obs = ma.fix_invalid(obsdata, fill_value=-9999)
    his = ma.fix_invalid(hisdata, fill_value=-9999)
    fut = ma.fix_invalid(futdata, fill_value=-9999)

    # Without a single valid model value there is nothing to correct.
    if his.count() == 0 or fut.count() == 0:
        return (hisdata, futdata)

    is_prec = (mode == 'prec')
    if is_prec:
        # Apply dry threshold
        obs = ma.where(obs < ddthres, 0, obs)

    # The bin edges do not change inside the loop, so compute them once per
    # series.  edges[k] is percentile k - 1, hence bin i spans
    # edges[i] (lower bound) .. edges[i + 1] (upper bound).
    levels = np.arange(-1, 101)
    obs_edges = np.asarray(get_pct(obs, levels))
    his_edges = np.asarray(get_pct(his, levels))
    fut_edges = np.asarray(get_pct(fut, levels))

    if is_prec:
        # First percentile (0..100) with a positive observed value (0 when
        # there is none); model values at or below the matching historical
        # percentile start out as zero.
        pmin = int(np.argmin(~(obs_edges[1:] > 0)))
        pmin_zero = his_edges[pmin + 1]
        bchis = ma.where(his <= pmin_zero, 0, his)
        bcfut = ma.where(fut <= pmin_zero, 0, fut)
    else:
        pmin = 0
        bchis = his.copy()
        bcfut = fut.copy()

    # Correction per percentile, indexed by the percentile itself.
    ratio_avg = np.zeros(101)

    for i in range(pmin, 101):
        obs_imin, obs_imax = obs_edges[i], obs_edges[i + 1]
        his_imin, his_imax = his_edges[i], his_edges[i + 1]
        fut_imin, fut_imax = fut_edges[i], fut_edges[i + 1]

        if i == pmin:
            # Lowest bin: everything up to the upper bound.
            # NOTE(review): the "> 0" filter is applied in every mode, so in
            # non-'prec' mode values <= 0 are never selected here -- confirm
            # this is intended.
            fobs = ma.where((obs <= obs_imax) & (obs > 0))
            fhis = ma.where((his <= his_imax) & (his > 0))
            ffut = ma.where((fut <= fut_imax) & (fut > 0))
        elif i == 100:
            # Highest bin: everything above the lower bound.
            fobs = ma.where(obs > obs_imin)
            fhis = ma.where(his > his_imin)
            ffut = ma.where(fut > fut_imin)
        else:
            fobs = ma.where((obs > obs_imin) & (obs <= obs_imax))
            fhis = ma.where((his > his_imin) & (his <= his_imax))
            ffut = ma.where((fut > fut_imin) & (fut <= fut_imax))

        # A bin whose bounds coincide is represented by that bound.
        # NOTE(review): a bin with distinct bounds can still be empty, in
        # which case the mean is masked and the value stored in ratio_avg
        # below is not a meaningful correction -- confirm this cannot happen
        # for the intended sample sizes.
        obs_avg = obs_imin if (obs_imin == obs_imax) else obs[fobs].mean()
        his_avg = his_imin if (his_imin == his_imax) else his[fhis].mean()

        if is_prec:
            if ((his_avg == 0) & (obs_avg <= ddthres)):
                # Both dry: leave the values as they are.
                ratio_avg[i] = 1
            elif (his_avg == 0):
                # Avoid dividing by a zero historical average.
                ratio_avg[i] = obs_avg / ddthres
            else:
                ratio_avg[i] = obs_avg / his_avg

            if ratio_avg[i] > max_ratio:
                ratio_avg[i] = max_ratio

            bchis[fhis] = his[fhis] * ratio_avg[i]
            bcfut[ffut] = fut[ffut] * ratio_avg[i]
        else:
            ratio_avg[i] = obs_avg - his_avg
            bchis[fhis] = his[fhis] + ratio_avg[i]
            bcfut[ffut] = fut[ffut] + ratio_avg[i]

    # Corrected values below the threshold become zero.
    # NOTE(review): this runs for every mode, not only 'prec' -- confirm that
    # callers pass a suitable ddthres for additive variables.
    bchis = ma.where(bchis < ddthres, 0, bchis)
    bcfut = ma.where(bcfut < ddthres, 0, bcfut)

    # Drop the masks and turn the -9999 fill value back into NaN.
    bchis = bchis.data
    bcfut = bcfut.data
    bchis = np.where(bchis == -9999, np.nan, bchis)
    bcfut = np.where(bcfut == -9999, np.nan, bcfut)

    return (bchis, bcfut)


    





    






































































































    

    
    
def apply_bc2(obsdata, hisdata, futdata, mode, ddthres, max_ratio=5):
    """Bias-correct `hisdata` and `futdata`; thin wrapper around ``apply_bc``.

    This function shares the signature of ``apply_bc`` and forwards all
    arguments to it unchanged, so both entry points give the same result and
    there is a single implementation to maintain.  ``apply_bc`` treats
    non-finite input values as missing and returns arrays of the input shape.

    Parameters
    ----------
    obsdata, hisdata, futdata : array_like
        Observed, historical-model and future-model values.
    mode : str
        ``'prec'`` for the multiplicative correction, anything else for the
        additive correction.
    ddthres : float
        Threshold passed through to ``apply_bc``.
    max_ratio : float, optional
        Upper limit of the multiplicative factor in ``'prec'`` mode.

    Returns
    -------
    tuple
        ``(bchis, bcfut)`` exactly as returned by ``apply_bc``.
    """
    return apply_bc(obsdata, hisdata, futdata, mode, ddthres,
                    max_ratio=max_ratio)
        
        
