import pandas as pd
from glob import glob
import os
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
import hydroeval as hev
import numpy as np
from PIL import Image, ImageOps
from datetime import datetime

# --- Paths -------------------------------------------------------------------
# Observed daily flows are read from <ref_dir>/<basin>/<site><refext>.
ref_dir = "/home/disk/rocinante/DATA/temp/kcp3/calibration/data/hyak/dhsvm/obs/"
#ref_dir = "/home/disk/rocinante/DATA/temp/kcp3/output/snoho/rawWRF_ptmin_delta_tlapse_pmap/"
#plot_dir = "/home/disk/tsuga2/jswon11/workdir/2021_09_KingCounty-Phase3/plots/"
# Figures are written under <plot_dir>/<bc>/.
plot_dir = "/home/disk/rocinante/DATA/temp/kcp3/calibration/data/hyak/dhsvm/plots/"
refext = "_daily.csv"
#refext = ".day"

# Simulated flows are looked up as <data_dir>/<bc>/<site>*.day.
data_dir = "/home/disk/rocinante/DATA/temp/kcp3/calibration/data/hyak/dhsvm/"
basins = ["snoho"]

# Model-run directories to evaluate: SAVE_00001 ... SAVE_00010.
bcs = ["SAVE_{:05d}".format(run) for run in range(1, 11)]
#bcs = ['SAVE_00001']

# Gauges to evaluate, keyed by gauge id.  Each entry is
#   [short name,
#    lower y-limit  (CDF, daily/monthly/annual timeseries),
#    CDF upper y-limit,
#    mean-monthly hydrograph upper y-limit (x1.25 for monthly/annual series),
#    daily-timeseries upper y-limit]
# Flow limits are in the plots' units (kcfs).
table = {
    "12134500": ["SkykomishRnrGoldBar", .01, 100, 250, 50],
    "12145500": ["RagingRnrFallCity", .001, 10, 10, 3],
    "12147500": ["NFToltRnrCarnation", .001, 15, 20, 8],
    "12149000": ["SnoqualmieRnrCarnation", .01, 125, 250, 75],
    "12150800": ["SnohomishRNrMonroe", .1, 200, 700, 175],
    "12155300": ["PilchuckRnrSnohomish", .005, 20, 50, 10],
}

# Alternative gauge set (Tolt sub-basins):
#table = dict({
#    "12147500": ["NFToltRnrCarnation", .001, 15, 20, 8],
#    "12148500": ["ToltRNrCarnation", .001, 15, 30, 6],
#    "12147470": ["NFToltRAbvYellowCr", .001, 15, 25, 8],
#    "12147600": ["SFToltRNrIndex", .001, 15, 5, 1],
#    "12148000": ["SFToltRNrCarnation", .001, 15, 10, 2],
#    "12148300": ["SFToltRBlwRegBasin", .001, 15, 12, 2]
#})

# Older 4-column layout: it has no daily-timeseries limit, so it is not
# compatible with the 5-entry lookup used by the plotting loop.
#table = dict({
#    "12134500": ["SkykomishRnrGoldBar", .01, 400, 250],
#    "12145500": ["RagingRnrFallCity", .001, 20, 10],
#    "12147500": ["NFToltRnrCarnation", .001, 35, 20],
#    "12149000": ["SnoqualmieRnrCarnation", .01, 300, 250],
#    "12150800": ["SnohomishRNrMonroe", .1, 800, 700],
#    "12155300": ["PilchuckRnrSnohomish", .005, 45, 50]
#})


def plot_monhydro(df, out, s, y1, y2):
    """Draw the mean monthly hydrograph of every column in *df*, in water-year order.

    Daily values are summed within each calendar month, and the monthly sums
    are then averaged across years for each month of the year.  Months are
    placed on dummy dates (Oct-Dec 1999, Jan-Sep 2000) so that the x axis runs
    from October to September.

    Parameters
    ----------
    df : pandas.DataFrame
        Date-indexed frame; each column becomes one line, labelled by its name.
    out : str
        Path of the PNG file to write.
    s : str
        Plot title.
    y1, y2 : float
        y-axis limits.

    NOTE(review): the y label reads "kcfs", but the plotted values are monthly
    sums of daily values. Confirm the intended units.
    """
    monthly_totals = df.resample('M').sum()
    climatology = monthly_totals.groupby(monthly_totals.index.month).mean()

    def _dummy_date(month):
        # Oct-Dec go in the year before Jan-Sep so the axis follows the water year.
        year = 1999 if month > 9 else 2000
        return pd.to_datetime("{}/1/{}".format(month, year))

    climatology.index = [_dummy_date(m) for m in climatology.index]
    climatology.index.name = "date"
    long_form = climatology.melt(
        var_name="Source", value_name="Flow", ignore_index=False
    ).reset_index()

    fig, ax = plt.subplots(1, 1)
    sns.lineplot(data=long_form, x="date", y="Flow", hue="Source", ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("Streamflow (kcfs)")
    ax.set_title(s)
    ax.set_ylim(y1, y2)
    ax.xaxis.set_major_formatter(DateFormatter('%h'))
    fig.savefig(out, dpi=300)
    plt.close(fig)
    
def plot_cdf(df, out, s, y1, y2):
    """Plot flow-duration curves (flow quantiles) of simulated vs observed flow.

    Parameters
    ----------
    df : pandas.DataFrame
        Date-indexed frame with ``Sim`` and ``Obs`` flow columns.
    out : str
        Path of the PNG file to write.
    s : str
        Plot title.
    y1, y2 : float
        Lower/upper limits of the log10 y axis.

    The figure is annotated with three goodness-of-fit scores computed on the
    daily values: Pearson r (taken from hydroeval's KGE decomposition), NSE,
    and NSE of log10 flows.
    """
    step = 0.005
    # Quantiles 0.005 ... 0.995 of each series, one row per quantile.  Columns
    # are selected by name so the result does not depend on df's column order.
    dq = df[["Sim", "Obs"]].quantile(np.arange(step, 1, step))
    dq.index.name = "q"
    dq = dq.reset_index().melt(var_name="Source", value_name="Flow", id_vars="q")

    fig, ax = plt.subplots(1, 1)
    sns.lineplot(data=dq, x="q", y="Flow", hue="Source", ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("Flow (kcfs)")
    ax.set_yscale('log', base=10)
    ax.legend(loc='upper left')
    ax.set_ylim(bottom=y1, top=y2)
    ax.set_title(s)

    obs = df.Obs.values
    sim = df.Sim.values
    # hydroeval's metrics take (simulations, evaluation).  NSE is not
    # symmetric, so passing the observations first would normalise by the
    # variance of the simulation instead of that of the observations.
    r = round(hev.kge(sim, obs)[1][0], 2)
    nse = round(hev.nse(sim, obs), 2)
    # NOTE(review): zero flows become -inf under log10, so this score may not
    # be finite for gauges that record zero flow.
    nselog = round(hev.nse(np.log10(sim), np.log10(obs)), 2)
    txt = "R = {}  NSE = {}  NSE Log = {}".format(r, nse, nselog)

    # Put the scores near the lower-left corner.  The y position is doubled
    # rather than offset because the y axis is logarithmic.
    xs = ax.get_xlim()
    ys = ax.get_ylim()
    x_text = xs[0] + (xs[1] - xs[0]) / 50
    y_text = ys[0] * 2
    ax.text(x_text, y_text, txt, fontsize=14)

    fig.savefig(out, dpi=300)
    plt.close(fig)

def plot_as(df, out, s, y1, y2):
    """Draw every column of *df* as a line over its date index and save it.

    The caller passes annual totals.  The y axis spans ``y1`` to ``8 * y2``
    and the x ticks show years.

    Parameters
    ----------
    df : pandas.DataFrame
        Date-indexed frame; each column becomes one line, labelled by its name.
        Side effect: its index is renamed to "date".
    out : str
        Path of the PNG file to write.
    s : str
        Plot title.
    y1, y2 : float
        Lower y limit, and the value whose 8-fold is the upper y limit.
    """
    df.index.name = "date"
    long_form = (
        df.melt(var_name="Source", value_name="Flow", ignore_index=False)
        .reset_index()
    )

    fig, ax = plt.subplots(1, 1)
    sns.lineplot(data=long_form, x="date", y="Flow", hue="Source", ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("Streamflow (kcfs)")
    ax.set_title(s)
    ax.set_ylim(y1, y2 * 8)
    ax.xaxis.set_major_formatter(DateFormatter('%Y'))
    fig.savefig(out, dpi=300)
    plt.close(fig)
    
def plot_ts(df, out, s, y1, y2):
    """Plot simulated vs observed flow with one panel per water year.

    Parameters
    ----------
    df : pandas.DataFrame
        Date-indexed frame with ``Sim`` and ``Obs`` flow columns.  It is not
        modified.
    out : str
        Path of the PNG file to write.
    s : str
        Figure title.
    y1, y2 : float
        y-axis limits shared by all panels.

    Each panel covers one water year (Oct-Sep, labelled by the calendar year in
    which it ends, with x limits from Oct 15 to Oct 15).  It is annotated with
    Pearson r, NSE and NSE of log10 flows computed over that year's values.
    """
    date_form = DateFormatter("%h")

    # Work on a copy.  Callers keep using their frame after this call (e.g. to
    # resample it), and the helper columns added here must not leak into it.
    long_df = df.copy()
    long_df['year'] = long_df.index.year + (long_df.index.month > 9)
    long_df['time'] = long_df.index
    long_df = long_df.melt(id_vars=['time', 'year'])
    long_df["value"] = long_df.value.astype(float)

    # FacetGrid creates its own figure.  A separate plt.subplots() call here
    # would only leak an empty figure that is never closed.
    g = sns.FacetGrid(long_df, row='year', sharex=False, height=2.5, aspect=3)
    g = g.map_dataframe(sns.lineplot, x='time', y='value', hue='variable', alpha=0.8)
    g.set_ylabels('Flow (kcfs)')
    g.set_titles("{row_name}")
    g.fig.suptitle(s)
    g.set(ylim=(y1, y2))

    # Sorted ascending so years pair with the grid rows.
    # NOTE(review): relies on FacetGrid's default ascending order for numeric
    # row levels.
    years = sorted(long_df.year.unique())
    show_legend = True
    for ax, wy in zip(g.axes.flat, years):
        in_year = long_df.year == wy
        obs = long_df[in_year & (long_df.variable == "Obs")].value.values
        sim = long_df[in_year & (long_df.variable == "Sim")].value.values
        if len(obs) == 0 or len(sim) == 0:
            r = nse = nselog = "NaN"
        else:
            # hydroeval's metrics take (simulations, evaluation).  NSE is not
            # symmetric, so the argument order matters.
            r = round(hev.kge(sim, obs)[1][0], 2)
            nse = round(hev.nse(sim, obs), 2)
            nselog = round(hev.nse(np.log10(sim), np.log10(obs)), 2)
        txt = "R = {}  NSE = {}  NSE Log = {}".format(r, nse, nselog)

        if show_legend:
            ax.legend()
            show_legend = False

        year_end = int(wy)
        ax.set_xlim(datetime(year_end - 1, 10, 15), datetime(year_end, 10, 15))
        xs = ax.get_xlim()
        ys = ax.get_ylim()
        x_text = xs[0] + (xs[1] - xs[0]) / 50
        y_text = ys[1] - (ys[1] - ys[0]) * .1
        ax.text(x_text, y_text, txt, fontsize=14)
        ax.xaxis.set_major_formatter(date_form)

    # Leave room above the first row for the suptitle.  These are heuristic
    # sizes: roughly row_px per panel row, plus pad_px for the title.
    row_px = 360
    pad_px = 120
    n = len(years)
    top = (row_px * n) / (row_px * n + pad_px)
    g.fig.subplots_adjust(top=top - 0.02)
    g.savefig(out, dpi=300)
    plt.close(g.fig)

def format_df(df):
    """Turn a raw 4-column (year, month, day, flow) table into a daily flow frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Exactly four columns, in order: year, month, day, flow.  The input is
        not modified; it may be a slice of a larger frame.

    Returns
    -------
    pandas.DataFrame
        A single ``flow`` column indexed by the parsed dates.  Flow values that
        cannot be parsed as numbers become NaN.  All values are divided by 1000
        (presumably cfs -> kcfs, matching the "kcfs" plot labels).

    Raises
    ------
    ValueError
        If ``df`` does not have exactly four columns.
    """
    # Copy so the caller's frame keeps its columns and index.  This also avoids
    # writing into a row slice such as ``sim[90:]``.
    out = df.copy()
    out.columns = ["year", "month", "day", "flow"]
    out.index = pd.to_datetime(out[["year", "month", "day"]])
    out["flow"] = pd.to_numeric(out["flow"], errors="coerce") / 1000
    return out[["flow"]]


def _merge_images(pdir, t, bc):
    """Tile every '<pdir>/121*_<t>.png' per-site plot into '<pdir>/plot6_<t>_<bc>.png'.

    Timeseries plots are laid out in a single row.  Every other plot type is
    laid out in a two-row grid.  If no per-site image exists (e.g. every site
    was skipped), a message is printed and nothing is written.
    NOTE(review): stale PNGs left over from earlier runs also match the glob.
    """
    paths = sorted(glob("{}/121*_{}.png".format(pdir, t)))
    if not paths:
        print('\t  no images to merge for', t)
        return

    imgs = []
    try:
        for path in paths:
            imgs.append(Image.open(path))
        widths, heights = zip(*(im.size for im in imgs))
        if "Timeseries" in t:
            total_wdt = sum(widths)
            max_hgt = max(heights)
        else:
            total_wdt = int(max(widths) * np.ceil(len(imgs) / 2))
            max_hgt = int(2 * max(heights))

        new_im = Image.new('RGB', (total_wdt, max_hgt))
        x_offset = 0
        y_offset = 0
        for im in imgs:
            new_im.paste(im, (x_offset, y_offset))
            x_offset += im.size[0]
            # Wrap to the next row once the current row is full.
            if x_offset >= total_wdt:
                x_offset = 0
                y_offset += max(heights)
        new_im.save('{}/plot6_{}_{}.png'.format(pdir, t, bc))
    finally:
        # Close every opened image explicitly instead of relying on GC.
        for im in imgs:
            im.close()


palette = ['#2596be', "#333333"]
sns.set_palette(sns.color_palette(palette))

for basin in basins:
    sites = sorted(glob("{}/{}/*{}".format(ref_dir, basin, refext)))
    #sites = sorted(glob("{}/*{}".format(ref_dir, refext)))

    # (obs file path, site id) pairs, restricted to the gauges listed in `table`.
    sites = [(x, os.path.basename(x).split(refext)[0]) for x in sites]
    sites = [(x, y) for x, y in sites if y in table]

    for bc in bcs:
        #ddir = "{}/{}/{}/".format(data_dir, basin, bc)
        #ddir = "/home/disk/rocinante/DATA/temp/kcp3/output/prism_map/prismWRF_pmap/"
        #pdir = "{}/{}/{}/".format(plot_dir, basin, bc)
        ddir = "{}/{}/".format(data_dir, bc)
        pdir = "{}/{}/".format(plot_dir, bc)
        os.makedirs(pdir, exist_ok=True)
        print(basin, "-", bc)

        print(sites)
        for obs_file, site in sites:
            print("\t", site)
            sn, y1, y2, y3, y4 = table[site]
            print("\t", sn)

            # get sim (sorted so the chosen file does not depend on glob order)
            sim_prefix = site.split('_')[0]
            ifile = sorted(glob("{}/{}*.day".format(ddir, sim_prefix)))
            if not ifile:
                print(site, " Sim file not found: {}/{}*.day".format(ddir, sim_prefix))
                continue
            sim = pd.read_csv(ifile[0], header=None, sep='[ \t]+', engine='python')
            # NOTE(review): the first 90 rows of the simulation are dropped,
            # presumably model spin-up. Confirm.
            sim = sim[90:]
            # 35.3147 ft^3 per m^3: presumably converts simulated m^3/s to cfs
            # so it matches the observations. Confirm.
            sim = format_df(sim) * 35.314666721489

            # get obs
            obs = pd.read_csv(obs_file, header=None, sep='[ \t,]+', engine='python')
            obs = format_df(obs)

            # merge data: inner join on date; columns become 'Sim' and 'Obs'
            df = sim.merge(obs, left_index=True, right_index=True, suffixes=("_Sim", "_Obs"))
            df.columns = [x.split('_')[1] for x in df.columns]

            # Keep water years 2013-2015.  A water year runs Oct-Sep and is
            # named for the calendar year in which it ends.
            wyear = df.index.year + (df.index.month > 9)
            df = df[(wyear >= 2013) & (wyear <= 2015)]
            title = "{}: {}".format(sn, bc)

            if df.empty:
                print("\t-> No data overlap")
                print(sim)
                print(obs)
                continue

            # Aggregate before any plotting helper sees `df`, so the aggregates
            # only ever contain the Sim/Obs columns.  'MS' bins by calendar
            # month; a business-month rule ('BMS') would push weekend
            # month-starts into the previous month's total.
            mf = df.resample('MS').sum()
            af = df.resample('AS-OCT').sum()

            out = "{}/{}_{}_monavg_hydrograph.png".format(pdir, site, sn)
            plot_monhydro(df, out, title, 0, y3)

            out = "{}/{}_{}_cdf.png".format(pdir, site, sn)
            plot_cdf(df, out, title, y1, y2)

            out = "{}/{}_{}_dailyTimeseries.png".format(pdir, site, sn)
            plot_ts(df, out, title, y1, y4)

            out = "{}/{}_{}_monthlyTimeseries.png".format(pdir, site, sn)
            plot_ts(mf, out, title, y1, y3 * 1.25)

            out = "{}/{}_{}_AnnualTotal.png".format(pdir, site, sn)
            plot_as(af, out, title, y1, y3 * 1.25)

        # merge images
        print('Merging Images')
        #types = ['monthlyTimeseries']
        #types = ['cdf', 'monavg_hydrograph', 'Timeseries']
        types = ['cdf', 'monavg_hydrograph', 'dailyTimeseries', 'monthlyTimeseries', 'AnnualTotal']
        #types = ['cdf', 'monavg_hydrograph']
        #types = ['Timeseries', 'monthlyTimeseries']
        #types = ['AnnualTotal']
        for t in types:
            print('\t', t)
            _merge_images(pdir, t, bc)
