import pandas as pd
import numpy as np
import sys
import os
from itertools import compress
#import warnings
#warnings.filterwarnings("ignore", category=FutureWarning)

def create_delta(liv_cells, wrf_cells):
    """Build one monthly delta table per grid cell (WRF relative to Livneh).

    ``liv_cells[k]`` and ``wrf_cells[k]`` are expected to describe the same
    cell: frames with Prec, Tmax, Tmin and Wind columns on matching indexes.

    Returns:
        A list with one DataFrame per cell, columns Prec, Tmax, Tmin, Wind.
        Precipitation is the ratio WRF / Livneh; the other variables are the
        difference WRF - Livneh.
    """
    def one_cell(liv, wrf):
        # NOTE(review): a month with zero Livneh precipitation gives an
        # inf/NaN ratio here, which would propagate into corrected forcings.
        return pd.DataFrame({
            'Prec': wrf.Prec / liv.Prec,
            'Tmax': wrf.Tmax - liv.Tmax,
            'Tmin': wrf.Tmin - liv.Tmin,
            'Wind': wrf.Wind - liv.Wind,
        })

    return [one_cell(liv, wrf_cells[k]) for k, liv in enumerate(liv_cells)]

def create_lowelv_correction(liv_cells, wrf_cells):
    """Average the Livneh-vs-WRF offsets over a set of (low-elevation) cells.

    The direction is Livneh relative to WRF, the opposite of ``create_delta``.
    Tmax, Tmin and Wind use the difference Livneh - WRF. Prec uses the ratio
    Livneh / WRF. Each quantity is the arithmetic mean over all cells.

    Returns:
        A frame with the same index and columns as ``liv_cells[0]``, or
        ``np.nan`` when ``liv_cells`` is empty.
    """
    count = len(liv_cells)
    if count == 0:
        return np.nan
    # Zeroed copy of the first cell serves as the accumulator, so the result
    # keeps that cell's index and column layout.
    total = liv_cells[0].copy() * 0
    for k, liv in enumerate(liv_cells):
        wrf = wrf_cells[k]
        for col in ('Tmax', 'Tmin', 'Wind'):
            total[col] += liv[col] - wrf[col]
        total['Prec'] += liv['Prec'] / wrf['Prec']
    return total / count

def read_livold_monthly(ifile, start_year=None, end_year=None):
    """Read a Livneh NetCDF cell and return its monthly-mean climatology.

    Args:
        ifile: path of the NetCDF file (variables Prec, Tmax, Tmin, Wind on a
            ``time`` coordinate).
        start_year, end_year: inclusive analysis period. They default to the
            module-level ``styr`` / ``edyr`` set in the ``__main__`` section.

    Returns:
        A DataFrame with columns Prec, Tmax, Tmin, Wind and one row per
        calendar month present in the period, in month order. Returns an
        empty DataFrame when the cell's precipitation is entirely NaN.

    NOTE(review): ``xr`` (xarray) is used here but is not imported anywhere
    in this file. Add ``import xarray as xr`` at the top before calling this.
    """
    print(ifile)
    # The context manager closes the NetCDF handle. Every value returned is
    # extracted with .values before the block exits.
    with xr.open_dataset(ifile) as ds:
        if np.isnan(ds.Prec.max()):
            print('  |---> This cell is empty')
            return pd.DataFrame()
        if start_year is None:
            start_year = styr
        if end_year is None:
            end_year = edyr
        period = ds.sel(time=slice("{}-01-01".format(start_year),
                                   "{}-12-31".format(end_year)))
        # NaN-skipping mean per calendar month, equivalent to sum / count.
        monthly = period.groupby(period.time.dt.month).mean()
        return pd.DataFrame({
            'Prec': monthly.Prec.values,
            'Tmax': monthly.Tmax.values,
            'Tmin': monthly.Tmin.values,
            'Wind': monthly.Wind.values,
        })

def read_liv_monthly(ifile, start_year=None, end_year=None):
    """Read a daily Livneh forcing CSV and return its monthly-mean climatology.

    The file has no header and four columns (Prec, Tmax, Tmin, Wind). Its
    rows are taken to be consecutive days from 1950-01-01 to 2013-12-31. The
    date assignment raises ValueError if the row count differs.

    Args:
        ifile: path of the CSV file.
        start_year, end_year: inclusive analysis period. They default to the
            module-level ``styr`` / ``edyr`` set in the ``__main__`` section.

    Returns:
        A DataFrame with columns Prec, Tmax, Tmin, Wind and one row per
        calendar month present in the period, in month order. Returns an
        empty DataFrame when ``ifile`` does not exist.
    """
    print(ifile)
    if not os.path.exists(ifile):
        print('  |---> This livneh file does not exist')
        return pd.DataFrame()
    if start_year is None:
        start_year = styr
    if end_year is None:
        end_year = edyr
    df = pd.read_csv(ifile, header=None)
    df.columns = ['Prec', 'Tmax', 'Tmin', 'Wind']
    df.index = pd.date_range('1950-01-01', '2013-12-31')
    df = df[(df.index.year >= start_year) & (df.index.year <= end_year)]
    # NaN-skipping mean per calendar month (equivalent to sum / count).
    monthly = df.groupby(df.index.month).mean()
    return pd.DataFrame({
        'Prec': monthly.Prec.values,
        'Tmax': monthly.Tmax.values,
        'Tmin': monthly.Tmin.values,
        'Wind': monthly.Wind.values,
    })


def read_wrf_monthly(ifile, start_year=None, end_year=None):
    """Read a daily raw-WRF forcing CSV and return its monthly-mean climatology.

    The file has no header and four columns (Prec, Tmax, Tmin, Wind). Its
    rows are taken to be consecutive days from 1981-01-01 to 2015-12-31. The
    date assignment raises ValueError if the row count differs.

    Args:
        ifile: path of the CSV file.
        start_year, end_year: inclusive analysis period. They default to the
            module-level ``styr`` / ``edyr`` set in the ``__main__`` section.

    Returns:
        A DataFrame with columns Prec, Tmax, Tmin, Wind and one row per
        calendar month present in the period, in month order. Returns an
        empty DataFrame when ``ifile`` does not exist.
    """
    print(ifile)
    if not os.path.exists(ifile):
        print('  |---> This wrf file does not exist')
        return pd.DataFrame()
    if start_year is None:
        start_year = styr
    if end_year is None:
        end_year = edyr
    df = pd.read_csv(ifile, header=None)
    df.columns = ['Prec', 'Tmax', 'Tmin', 'Wind']
    df.index = pd.date_range('1981-01-01', '2015-12-31')
    df = df[(df.index.year >= start_year) & (df.index.year <= end_year)]
    # NaN-skipping mean per calendar month (equivalent to sum / count).
    monthly = df.groupby(df.index.month).mean()
    return pd.DataFrame({
        'Prec': monthly.Prec.values,
        'Tmax': monthly.Tmax.values,
        'Tmin': monthly.Tmin.values,
        'Wind': monthly.Wind.values,
    })



def read_wrf_monthly_forcdhsvm(ifile, start_year=None, end_year=None):
    """Read a sub-daily DHSVM-format WRF forcing file and return the
    monthly-mean climatology of its daily values.

    The file is space-separated with no header and the columns date, Temp,
    Wind, RH, SW, LW, Prec. Dates look like ``%m/%d/%Y-%H:%M:%S``. The
    daily values are built as follows:

    - Tmax and Tmin: daily maximum and minimum of Temp.
    - Wind: daily mean.
    - Prec: daily sum multiplied by 1000.

    Args:
        ifile: path of the forcing file.
        start_year, end_year: inclusive analysis period. They default to the
            module-level ``styr`` / ``edyr`` set in the ``__main__`` section.

    Returns:
        A DataFrame with columns Prec, Tmax, Tmin, Wind and one row per
        calendar month present in the period, in month order. Returns an
        empty DataFrame when ``ifile`` does not exist.
    """
    print(ifile)
    if not os.path.exists(ifile):
        print('  |---> This wrf file does not exist')
        return pd.DataFrame()
    if start_year is None:
        start_year = styr
    if end_year is None:
        end_year = edyr
    df = pd.read_csv(ifile, header=None, sep='[ ]+', engine='python')
    df.columns = ['date', 'Temp', 'Wind', 'RH', 'SW', 'LW', 'Prec']
    df.index = pd.to_datetime(df.date, format='%m/%d/%Y-%H:%M:%S')
    df = df[(df.index.year >= start_year) & (df.index.year <= end_year)]
    df = df[['Temp', 'Wind', 'Prec']]
    daily = df.resample('D').mean()
    # min_count=1 keeps days with no records as NaN, matching Tmax/Tmin/Wind.
    # A plain sum() would record them as 0 and bias the monthly
    # precipitation low.
    # NOTE(review): the *1000 presumably converts metres to millimetres —
    # confirm.
    daily['Prec'] = df.Prec.resample('D').sum(min_count=1) * 1000
    daily['Tmax'] = df.Temp.resample('D').max()
    daily['Tmin'] = df.Temp.resample('D').min()
    # NaN-skipping mean per calendar month (equivalent to sum / count).
    monthly = daily.groupby(daily.index.month).mean()
    return pd.DataFrame({
        'Prec': monthly.Prec.values,
        'Tmax': monthly.Tmax.values,
        'Tmin': monthly.Tmin.values,
        'Wind': monthly.Wind.values,
    })

def apply_biascorr(ifile):
    """Apply the monthly deltas to one Livneh NetCDF cell and write it as text.

    Inputs:

    - ``ifile``: named ``data_<lat>_<lon>.nc``.
    - ``<delta_dir>/delta_<lat>_<lon>.csv``: the per-cell table.
    - ``<delta_dir>/low_elev_delta.csv``: the low-elevation table.

    Row ``i`` of each table is applied to every time step in month ``i + 1``:

    - Tmax and Tmin: add both tables' deltas.
    - Prec: multiply by both tables' ratios.
    - Wind: add only the per-cell delta.

    The result is written as space-separated text, without header or index,
    to ``<out_dir>/data_<lat>_<lon>`` with columns Prec, Tmax, Tmin, Wind.

    NOTE(review): ``xr`` (xarray) is used here but is not imported anywhere
    in this file. Add ``import xarray as xr`` at the top. ``delta_dir`` and
    ``out_dir`` are module-level globals set in the ``__main__`` section.
    """
    # The context manager closes the NetCDF handle once the corrected data
    # has been copied into a DataFrame.
    with xr.open_dataset(ifile) as ds:
        coord = os.path.basename(ifile).replace('.nc', '').replace('data_', '')
        df1 = pd.read_csv('{}/delta_{}.csv'.format(delta_dir, coord))
        df2 = pd.read_csv('{}/low_elev_delta.csv'.format(delta_dir))
        for i in range(len(df1)):
            # Compute the month mask once and reuse it for all four variables.
            in_month = ds.time.dt.month == (i + 1)
            ds.Tmax[in_month] += df1.Tmax[i] + df2.Tmax[i]
            ds.Tmin[in_month] += df1.Tmin[i] + df2.Tmin[i]
            ds.Prec[in_month] *= df1.Prec[i] * df2.Prec[i]
            # NOTE(review): wind ignores the low-elevation delta, unlike the
            # other variables — confirm this is intentional.
            ds.Wind[in_month] += df1.Wind[i]
        df = ds.to_dataframe()
    out = "{}/data_{}".format(out_dir, coord)
    df = df[['Prec', 'Tmax', 'Tmin', 'Wind']]
    df.to_csv(out, index=False, header=False, float_format='%.4f', sep=' ')
    print(out)


# Create bias-correction deltas
def main():
    """Create the per-cell and low-elevation bias-correction delta tables.

    Uses the module-level configuration set in the ``__main__`` section:

    - ``cell_df`` and ``lelv_df``: the cell tables.
    - ``liv_dir``, ``wrf_dir`` and ``delta_dir``: input and output folders.
    - ``styr`` / ``edyr``: the period, read by the monthly readers.

    Writes ``delta_<lat>_<lon>.csv`` for every cell that has both Livneh and
    WRF data. Also writes ``low_elev_delta.csv``, averaged over those cells
    that are also listed in ``lelv_df``.

    Raises:
        RuntimeError: if no low-elevation cell has both datasets, so the
            low-elevation delta cannot be computed.
    """
    # Cell identifiers "<lat>_<lon>" match the suffix of the per-cell files.
    latlons = ["{}_{}".format(lat, lon) for lat, lon in zip(cell_df.lat, cell_df.lon)]

    # Monthly climatologies of both datasets, one frame per cell.
    print('\nGet Livneh Data: Livneh 1950-2013\n' + '--'*50)
    liv_cells = [read_liv_monthly("{}/data_{}".format(liv_dir, x)) for x in latlons]
    print('\nGet WRF data: \n' + '--'*50)
    wrf_cells = [read_wrf_monthly("{}/data_{}".format(wrf_dir, x)) for x in latlons]

    # Keep only cells where both readers returned data (they return an empty
    # frame for missing files).
    has_data = [not liv.empty and not wrf.empty for liv, wrf in zip(liv_cells, wrf_cells)]
    latlons = list(compress(latlons, has_data))
    liv_cells = list(compress(liv_cells, has_data))
    wrf_cells = list(compress(wrf_cells, has_data))

    # Per-cell corrections.
    print('\nCreate bias-correction:')
    for latlon, pdel in zip(latlons, create_delta(liv_cells, wrf_cells)):
        pdel.to_csv("{}/delta_{}.csv".format(delta_dir, latlon), index=False)

    # Restrict to low-elevation cells. A set makes each membership test O(1).
    ll_lelv = {"{}_{}".format(lat, lon) for lat, lon in zip(lelv_df.lat, lelv_df.lon)}
    is_low = [x in ll_lelv for x in latlons]
    liv_cells = list(compress(liv_cells, is_low))
    wrf_cells = list(compress(wrf_cells, is_low))

    # Low-elevation correction averaged over the selected cells.
    print('\nCreate low elevation delta:')
    if not liv_cells:
        # create_lowelv_correction returns NaN for an empty selection, and NaN
        # has no to_csv. Fail with an explicit message instead.
        raise RuntimeError('No low-elevation cell has both Livneh and WRF data; '
                           'cannot write {}/low_elev_delta.csv'.format(delta_dir))
    ldelta = create_lowelv_correction(liv_cells, wrf_cells)
    ldelta.to_csv('{}/low_elev_delta.csv'.format(delta_dir), index=False)


def main_apply():
    """Bias-correct every cell listed in the module-level ``cell_df``.

    For each (lat, lon) row this calls ``apply_biascorr`` on
    ``<liv_dir>/data_<lat>_<lon>.nc``. ``cell_df`` and ``liv_dir`` are set in
    the ``__main__`` section.
    """
    print('Create bias-corrected forcings: \n')
    for lat, lon in zip(cell_df.lat, cell_df.lon):
        nc_path = '{}/data_{}_{}.nc'.format(liv_dir, lat, lon)
        apply_biascorr(nc_path)


if __name__ == '__main__':
    # Select the processing stage on the command line:
    #   python <script>          -> "delta": compute the delta tables (default)
    #   python <script> apply    -> apply existing deltas to the Livneh cells
    # The "apply" configuration used to sit unreachable after an exit() call
    # and could only be run by editing the file.
    stage = sys.argv[1] if len(sys.argv) > 1 else 'delta'
    if stage not in ('delta', 'apply'):
        sys.exit("Unknown stage {!r}: expected 'delta' or 'apply'".format(stage))

    # Analysis period, read as globals by the monthly readers.
    styr = 1981
    edyr = 2013
    cell_df = pd.read_csv('tables/wrfcells.csv', dtype='str')
    lelv_df = pd.read_csv('tables/wrfpoints_15km_blw500m.csv', dtype='str')
    # To process only the low-elevation cells, use instead:
    # cell_df = lelv_df

    if stage == 'delta':
        liv_dir = "data_bclivneh/"
        wrf_dir = "data_rawWRF/"
        delta_dir = "delta/"
        out_dir = "output/"
    else:
        liv_dir = "/home/disk/rocinante/DATA/Livneh/Livneh_WWA_1915_2011/vic_forc_nc/"
        wrf_dir = "/home/disk/rocinante/DATA/WRF/NNRP/vic_16d/WWA_1950_2010/raw/forcings_ascii/"
        delta_dir = "/home/disk/tsuga2/jswon11/workdir/2020-08_skagit_water_supply_project/biascorr_forc/data/delta/delta{}_winds5/".format(styr)
        out_dir = "/home/disk/tsuga2/jswon11/workdir/2020-08_skagit_water_supply_project/biascorr_forc/data/his/period_{}-{}_winds5/".format(styr, edyr)

    os.makedirs(delta_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    print('Bias-correction period: {} - {}'.format(styr, edyr))
    print('Delta Directory: {}'.format(delta_dir))
    print('Output Directory: {}'.format(out_dir))
    print('--'*40 + '\n')

    if stage == 'delta':
        main()
    else:
        main_apply()



    
# Take the buffer zones from ArcGIS to find the cells that fall in each zone (15 km and 10 km buffers).
# Feed the buffer cells into Python, compare them with DEM values to select the low-elevation cells, and average them.
# Calculate the same average from the original raw WRF cells, using the same buffer cells.
# Create a monthly delta from the comparison of the two averages.
