
import argparse
import os

import hydroeval as hev
import numpy as np
import pandas as pd


def calc_metric(df):
    """Compute goodness-of-fit statistics of simulated vs observed temperature.

    Rows containing any NaN are dropped before the statistics are computed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``sim_temp`` (simulated) and ``obs_temp`` (observed)
        columns.

    Returns
    -------
    list of float
        ``[kge, kge_component_1, nse]``: the first two entries of the
        ``hydroeval.kge`` output (the second is displayed as "R" by the
        plotting code) followed by the ``hydroeval.nse`` value, all as plain
        Python floats.
    """
    ndf = df.dropna()
    sim = ndf.sim_temp.values
    obs = ndf.obs_temp.values
    # hydroeval returns numpy arrays (kge as a stacked column of components).
    # Flatten them so every statistic converts cleanly to a Python float;
    # float() on an array with ndim > 0 is deprecated in NumPy >= 1.25, and
    # leaving nse unconverted would render as "[0.85]" in plot labels.
    kge = np.ravel(hev.kge(sim, obs))
    nse = np.ravel(hev.nse(sim, obs))
    return [float(kge[0]), float(kge[1]), float(nse[0])]


def format_plot(ax, df, title=None):
    """Annotate a temperature plot with fit statistics, labels and a title.

    The statistics come from ``calc_metric(df)``. The y-axis is extended by
    10% of its visible span to leave room for the annotation. The text
    "R = ...  NSE = ..." is drawn in the top-left corner of the axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that already contain the plotted series.
    df : pandas.DataFrame
        Data with ``sim_temp`` and ``obs_temp`` columns used for the metrics.
    title : str, optional
        Title for the axes.

    Returns
    -------
    matplotlib Figure
        The figure that owns ``ax``.
    """
    met = calc_metric(df)

    r = round(met[1], 2)
    nse = round(met[2], 2)
    txt = "R = {}  NSE = {}".format(r, nse)

    # Additive headroom: multiplying the upper limit by 1.1 (as before) shrinks
    # the axis instead of extending it when the limit is zero or negative.
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(ymin, ymax + 0.1 * (ymax - ymin))
    # Axes-fraction coordinates keep the label in the top-left corner
    # independent of the data range, sign, or x-axis (datetime) units.
    ax.text(0.02, 0.97, txt, transform=ax.transAxes, va='top', fontsize=14)
    ax.set_ylabel('Stream Temperature (C)')
    ax.set_title(title)
    return ax.get_figure()
    

def plot_hourly(df, odir, fname):
    """Plot the hourly series in ``df`` and save it as ``<odir>/HOURLY_<fname>.png``.

    The axes are annotated by ``format_plot`` with ``fname`` as the title.
    """
    print('--| Creating hourly plots: ', fname)
    hourly_axes = df.plot()
    figure = format_plot(hourly_axes, df, fname)
    out_png = '{}/HOURLY_{}.png'.format(odir, fname)
    figure.savefig(out_png)

def plot_dmax(df, odir, fname):
    """Plot daily maximum values of ``df`` and save them as ``<odir>/DMAX_<fname>.png``.

    ``df`` is resampled to calendar days, keeping each day's maximum, before
    plotting and annotation by ``format_plot``.
    """
    print('--| Creating daily max plots: ', fname)
    daily_peak = df.resample('D').max()
    format_plot(daily_peak.plot(), daily_peak, fname).savefig(
        '{}/DMAX_{}.png'.format(odir, fname)
    )

def plot_7dadmax(df, odir, fname):
    """Plot the 7-day average of daily maximum values (7DADMAX) and save it.

    Daily maxima are computed from ``df`` and then averaged over a trailing
    7-day rolling window. A value is produced only when all seven daily
    maxima in the window are present. The figure is written to
    ``<odir>/7DADMAX_<fname>.png``.
    """
    print('--| Creating 7 day average daily max plots: ', fname)
    dmax = df.resample('D').max()
    # 7DADMAX is a moving average of daily maxima (one value per day).
    # resample('7D').mean() would instead yield non-overlapping weekly block
    # means. resample('D') emits every calendar day (NaN when missing), so a
    # 7-row window spans exactly 7 days.
    d7max = dmax.rolling(7, min_periods=7).mean()
    ax = d7max.plot()
    fig = format_plot(ax, d7max, fname)
    fig.savefig('{}/7DADMAX_{}.png'.format(odir, fname))

    
def _lookup_site_name(tbl, seg_id):
    """Return the ``short`` name of the row in ``tbl`` whose ``rbm_id`` equals ``seg_id``.

    Raises ValueError when no such row exists.
    """
    matches = tbl.loc[tbl.rbm_id == seg_id, 'short']
    if matches.empty:
        raise ValueError('Segment {} not found in lookup table'.format(seg_id))
    return matches.values[0]


def _load_reference(rdir, site):
    """Read ``<rdir>/CLEAN_<site>.csv`` and return hourly-mean temps as a ``temp`` DataFrame."""
    rf = pd.read_csv("{}/CLEAN_{}.csv".format(rdir, site))
    # Assumes exactly six columns in this order -- confirm against the files.
    rf.columns = ['year', 'month', 'day', 'hour', 'minute', 'temp']
    rf.index = pd.to_datetime(rf[rf.columns[:-1]])
    # NOTE(review): pandas >= 2.2 deprecates the 'H' alias in favour of 'h'.
    return rf[['temp']].resample('H').mean()


def _merge_with_reference(ifile, ref):
    """Read simulated temps from ``ifile`` and pair them with ``ref`` on shared timestamps.

    Returns a DataFrame with ``sim_temp`` and ``obs_temp`` columns indexed by
    the timestamps present in both inputs.
    """
    sim = pd.read_csv(ifile)
    sim.index = pd.to_datetime(sim.date)
    common = sim.index.intersection(ref.index)
    # Build a fresh frame (no in-place edits on a .loc slice), then align the
    # observed series onto it by index.
    merged = sim.loc[common, ['temp']].rename(columns={'temp': 'sim_temp'})
    merged['obs_temp'] = ref['temp']
    return merged


def main():
    """Command-line entry point: compare simulated and reference temperatures.

    For each input ``seg<N>.temp`` file, look up the site's short name in the
    lookup table and load ``<rdir>/CLEAN_<short>.csv`` as reference data.
    Restrict both series to their common timestamps and write the plots
    selected by ``-c`` / ``-d`` / ``-e`` into ``odir``.
    """
    ## Parser
    parser = argparse.ArgumentParser()
    parser.add_argument('input_file', help="Temp file or input directory")
    parser.add_argument('table', help='Lookup table')
    parser.add_argument('rdir', help='Reference directory')
    parser.add_argument('odir', help='Output directory')
    parser.add_argument('-c', action='store_true', default=False,
                        help='Create hourly plots')
    parser.add_argument('-d', action='store_true', default=False,
                        help='Create daily max plots')
    parser.add_argument('-e', action='store_true', default=False,
                        help='Create 7DADMAX plots')
    parser.add_argument('--filter', default=None,
                        help='List of segments to extract. Treats input_file as directory')
    args = parser.parse_args()

    rdir = args.rdir
    odir = args.odir

    if args.filter:
        flist = ["{}/seg{}.temp".format(args.input_file, x)
                 for x in args.filter.split(',')]
    else:
        flist = [args.input_file]

    print("sites:", flist)

    # The lookup table is the same for every site: read it once.
    tbl = pd.read_csv(args.table)
    # savefig fails if the output directory does not exist.
    os.makedirs(odir, exist_ok=True)

    for ifile in flist:
        # File names look like seg<N>.temp; N is the table's rbm_id.
        seg_id = int(os.path.basename(ifile).replace('seg', '').replace('.temp', ''))
        site = _lookup_site_name(tbl, seg_id)
        ref = _load_reference(rdir, site)
        df = _merge_with_reference(ifile, ref)

        # Suffix the plot name with the directory three levels above the file.
        xname = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(ifile))))
        fname = "{}_{}".format(site, xname)
        print(fname)

        # Create plots
        if args.c:
            plot_hourly(df, odir, fname)

        if args.d:
            plot_dmax(df, odir, fname)

        if args.e:
            plot_7dadmax(df, odir, fname)
                
    

    
if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, not on import.
    main()





