#!/usr/bin/env python

import sys
import numpy as np
import numpy.ma as ma
import inspect
# Print small floats in fixed-point rather than scientific notation.
# NOTE: this is an import-time side effect that changes NumPy's print
# options, not only output produced by this module.
np.set_printoptions(suppress=True)


# Get matlab style percentile values
def matlab_percentile(x, p):
    """Return percentile(s) ``p`` of ``x`` using MATLAB's ``prctile`` convention.

    MATLAB places the i-th smallest of n samples (1-based) at percentile
    100*(i-0.5)/n, interpolates linearly in between, and clamps to the
    min/max outside that range.  NumPy's default linear method places it at
    100*(i-1)/(n-1), so ``p`` is remapped before calling ``np.percentile``.

    Parameters
    ----------
    x : array_like
        Sample values.  All elements are used (flattened), exactly as
        ``np.percentile`` does with ``axis=None``.
    p : float or array_like of float
        Percentile(s); values outside [0, 100] after remapping are clamped.

    Returns
    -------
    float or ndarray
        Percentile value(s), shaped like ``p``.
    """
    p = np.asarray(p, dtype=float)
    # Count every element to match np.percentile's flattening; len() would
    # only count the first axis of a multi-dimensional input.
    n = np.size(x)
    if n > 1:
        # The remap divides by (n - 1).  With a single sample every
        # percentile equals that sample, so no remap is needed.
        p = (p - 50) * n / (n - 1) + 50
    p = np.clip(p, 0, 100)
    return np.percentile(x, p)

# Percentile(s) of data, ignoring masked entries (MATLAB convention)
def get_pct(data, i):
    """Return MATLAB-style percentile(s) ``i`` of ``data``, skipping masked entries.

    A masked array that actually holds masked values is reduced to its
    unmasked elements first; any other input is passed through untouched.
    """
    values = data.compressed() if ma.is_masked(data) else data
    return matlab_percentile(values, i)

# Lower/upper edges [pct(i-1), pct(i)] of percentile bin i
def get_pctmm(data, i):
    """Return ``[lower, upper]`` edges of percentile bin ``i`` of ``data``.

    The lower edge is the (i-1)-th percentile and the upper edge is the
    i-th percentile, both computed with ``get_pct``.
    """
    lower, upper = (get_pct(data, q) for q in (i - 1, i))
    return [lower, upper]

def apply_bc(obsdata, hisdata, futdata, mode, ddthres=0.0996, max_ratio=5):
    """Bias-correct historical and future model values against observations.

    Observed, historical and future values are each split into
    percentile bins [pct(i-1), pct(i)] for i = pmin..100.  For every bin,
    the mean of the observed values is compared with the mean of the
    historical values.  The resulting correction is applied to the matching
    historical and future bins:

    * ``mode == 'prec'``: multiplicative factor ``obs_avg / his_avg``,
      capped at ``max_ratio``.  Values below ``ddthres`` are treated as
      dry (set to 0).
    * any other mode: additive shift ``obs_avg - his_avg``.

    Parameters
    ----------
    obsdata, hisdata, futdata : array_like
        Observed, historical-model and future-model values.  NaN/inf (and
        masked entries, if masked arrays are passed) are treated as missing.
    mode : str
        ``'prec'`` selects the multiplicative / dry-day handling; anything
        else selects the additive correction.
    ddthres : float
        Dry-day threshold used in ``'prec'`` mode.
    max_ratio : float
        Upper cap on the multiplicative factor in ``'prec'`` mode.

    Returns
    -------
    (bchis, bcfut) : tuple of ndarray
        Corrected historical and future values as float arrays, with NaN
        where the input was missing.  If any of the three inputs has no
        valid value, ``(hisdata, futdata)`` is returned unchanged.
    """
    # Work in float so corrections assigned back into the arrays are not
    # truncated for integer inputs.  NaN/inf become masked entries.
    obs = ma.fix_invalid(ma.asarray(obsdata, dtype=float), fill_value=-9999)
    his = ma.fix_invalid(ma.asarray(hisdata, dtype=float), fill_value=-9999)
    fut = ma.fix_invalid(ma.asarray(futdata, dtype=float), fill_value=-9999)

    # Nothing to correct, or nothing to correct against.  ma.count handles
    # the `nomask` case, where the mask is a scalar.
    if ma.count(obs) == 0 or ma.count(his) == 0 or ma.count(fut) == 0:
        return (hisdata, futdata)

    if mode == 'prec':
        # Observed values below the dry-day threshold count as dry.
        obs = ma.where(obs < ddthres, 0, obs)

        # pmin: first percentile (0..100) at which the observations are wet.
        # argmin over "not wet" finds the first False.  It is 0 if none is wet.
        pmin = int(np.argmin(~(get_pct(obs, np.arange(0, 101)) > 0)))
        # Model values at or below the model's pmin-th percentile are set dry.
        pmin_zero = get_pct(his, pmin)

        bchis = ma.where(his <= pmin_zero, 0, his)
        bcfut = ma.where(fut <= pmin_zero, 0, fut)
    else:
        pmin = 0
        bchis = his.copy()
        bcfut = fut.copy()

    for i in range(pmin, 101):
        obs_imin, obs_imax = get_pctmm(obs, i)
        his_imin, his_imax = get_pctmm(his, i)
        fut_imin, fut_imax = get_pctmm(fut, i)

        # Indices of the (unmasked) members of bin i.  The first bin is open
        # below and the last bin is open above.
        # NOTE(review): in 'prec' mode the first bin's upper edge for `his` is
        # pmin_zero, so the values zeroed above are re-assigned his * ratio
        # here.  Only the final ddthres step re-zeroes them; confirm intended.
        if i == pmin:
            fobs = ma.where(obs <= obs_imax)
            fhis = ma.where(his <= his_imax)
            ffut = ma.where(fut <= fut_imax)
        elif i == 100:
            fobs = ma.where(obs > obs_imin)
            fhis = ma.where(his > his_imin)
            ffut = ma.where(fut > fut_imin)
        else:
            fobs = ma.where((obs > obs_imin) & (obs <= obs_imax))
            fhis = ma.where((his > his_imin) & (his <= his_imax))
            ffut = ma.where((fut > fut_imin) & (fut <= fut_imax))

        # Bin mean.  Fall back to the lower edge for empty or zero-width bins.
        obs_sel = obs[fobs]
        his_sel = his[fhis]
        if len(obs_sel) == 0 or obs_imin == obs_imax:
            obs_avg = obs_imin
        else:
            obs_avg = obs_sel.mean()
        if len(his_sel) == 0 or his_imin == his_imax:
            his_avg = his_imin
        else:
            his_avg = his_sel.mean()

        if mode == 'prec':
            if his_avg == 0:
                # Model is dry in this bin: keep it if obs are dry too,
                # otherwise scale relative to the dry threshold.
                ratio = 1.0 if obs_avg <= ddthres else obs_avg / ddthres
            else:
                ratio = obs_avg / his_avg
            ratio = min(ratio, max_ratio)

            bchis[fhis] = his[fhis] * ratio
            bcfut[ffut] = fut[ffut] * ratio
        else:
            shift = obs_avg - his_avg
            bchis[fhis] = his[fhis] + shift
            bcfut[ffut] = fut[ffut] + shift

    if mode == 'prec':
        # Values pushed below the dry threshold by the correction are dry.
        bchis = ma.where(bchis < ddthres, 0, bchis)
        bcfut = ma.where(bcfut < ddthres, 0, bcfut)

    # Missing entries come back as NaN.  Filling from the mask (rather than
    # comparing data with -9999) keeps genuine -9999 values intact.
    return (ma.filled(bchis, np.nan), ma.filled(bcfut, np.nan))
