import pandas as pd
import sys
import hydroeval as hev
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.dates import DateFormatter
import numpy as np

#/home/disk/tsuga2/jswon11/workdir/2020-07_Snoho-autocalibration/obs


#12134500_DailyFlows.csv - Skykomish River near Gold Bar
#12147500_DailyFlows.csv - North Fork Tolt River near Carnation
#12149000_DailyFlows.csv - Snoqualmie River near Carnation
#12155300_DailyFlows.csv - Pilchuck River near Snohomish

#SkykomishRNrGoldBar
#NFToltRNrCarnation
#SnoqualmieRNrCarnation
#PilchuckRNrSnohomish


def plot_dy(df, out, title=None):
    """Plot simulated vs. observed daily streamflow, one facet row per year.

    Every facet is annotated with the correlation coefficient (the ``r``
    component returned by ``hydroeval.kge``) and the Nash-Sutcliffe
    efficiency (``hydroeval.nse``) of the simulated series against the
    ``"Obs"`` series for that year.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with the columns ``time``, ``year``, ``variable``
        and ``value``.  Rows whose ``variable`` equals ``"Obs"`` are the
        observations; every other row is treated as the simulation.
        NOTE(review): assumes exactly one non-"Obs" variable so both series
        have the same length within a year -- confirm with callers.
    out : str
        Path the figure is written to (300 dpi).
    title : str, optional
        Figure-level title.
    """
    plt.clf()
    palette = ['#333333', '#2596be']
    # "%h" is the C-library alias of "%b" (abbreviated month name).
    date_form = DateFormatter("%h")

    # NaNs are replaced by inf for drawing only; the statistics below use the
    # untouched ``df``.  (Presumably this keeps rows from being dropped by the
    # grid so they render as gaps -- TODO confirm.)
    g = sns.FacetGrid(df.fillna(np.inf), row='year', sharex=False, height=2.5, aspect=3)
    g = g.map_dataframe(sns.lineplot, x='time', y='value', hue='variable',
                        palette=sns.color_palette(palette, 2), alpha=0.8)
    g.set_ylabels('Streamflow (cfs)')
    g.set_titles("{row_name}")
    g.fig.suptitle(title)

    is_obs = df.variable == "Obs"

    # Pair every axis with the facet value the grid itself assigned to it
    # (``row_names``) rather than with the order of appearance in ``df``, so
    # the annotation always belongs to the data drawn in that facet.
    for ax, d in zip(g.axes.flat, g.row_names):
        # Format the x axis first so facets skipped below are labelled too.
        ax.xaxis.set_major_formatter(date_form)

        in_facet = df.year == d
        sim = df[in_facet & ~is_obs].value.values
        obs = df[in_facet & is_obs].value.values

        # Keep only time steps where BOTH series are available; a NaN in
        # either one would otherwise turn the metrics into NaN.
        valid = ~(np.isnan(sim) | np.isnan(obs))
        if not valid.any():
            continue
        sim = sim[valid]
        obs = obs[valid]

        r = round(hev.kge(sim, obs)[1][0], 2)
        nse = round(hev.nse(sim, obs), 2)
        txt = "R = {}  NSE = {}".format(r, nse)

        # Place the label 2 % in from the left edge and 10 % below the top.
        xs = ax.get_xlim()
        ys = ax.get_ylim()
        x0 = (xs[1] - xs[0]) / 50 + xs[0]
        y1 = ys[1] - (ys[1] - ys[0]) * .1
        ax.text(x0, y1, txt, fontsize=14)

    plt.subplots_adjust(top=0.95)
    g.savefig(out, dpi=300)
    print(out)


# Create plots comparing model vs observed flows, one facet row per decade (`dec` column)
def plot_mn(df, out, title=None):
    """Plot simulated vs. observed streamflow, one facet row per decade.

    Every facet is annotated with the correlation coefficient (the ``r``
    component returned by ``hydroeval.kge``) and the Nash-Sutcliffe
    efficiency (``hydroeval.nse``) of the simulated series against the
    ``"Obs"`` series for that decade.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with the columns ``time``, ``dec``, ``variable``
        and ``value``.  Rows whose ``variable`` equals ``"Obs"`` are the
        observations; every other row is treated as the simulation.
        NOTE(review): assumes exactly one non-"Obs" variable so both series
        have the same length within a decade -- confirm with callers.
    out : str
        Path the figure is written to (300 dpi).
    title : str, optional
        Figure-level title.
    """
    plt.clf()
    palette = ['#333333', '#2596be']

    # NaNs are replaced by inf for drawing only; the statistics below use the
    # untouched ``df``.
    g = sns.FacetGrid(df.fillna(np.inf), row='dec', sharex=False, height=2.5, aspect=3)
    g = g.map_dataframe(sns.lineplot, x='time', y='value', hue='variable',
                        palette=sns.color_palette(palette, 2), alpha=0.8)
    g.set_titles('')
    g.set_ylabels('Streamflow (cfs)')
    g.fig.suptitle(title)
    #g.add_legend()

    is_obs = df.variable == "Obs"

    # Pair every axis with the facet value the grid itself assigned to it
    # (``row_names``) rather than with the order of appearance in ``df``.
    for ax, d in zip(g.axes.flat, g.row_names):
        in_facet = df.dec == d
        sim = df[in_facet & ~is_obs].value.values
        obs = df[in_facet & is_obs].value.values

        # Keep only time steps where BOTH series are available; a NaN in
        # either one would otherwise turn the metrics into NaN.
        valid = ~(np.isnan(sim) | np.isnan(obs))
        if not valid.any():
            continue
        sim = sim[valid]
        obs = obs[valid]

        r = round(hev.kge(sim, obs)[1][0], 2)
        nse = round(hev.nse(sim, obs), 2)
        txt = "R = {}  NSE = {}".format(r, nse)

        # Place the label 2 % in from the left edge and 10 % below the top.
        xs = ax.get_xlim()
        ys = ax.get_ylim()
        x0 = (xs[1] - xs[0]) / 50 + xs[0]
        y1 = ys[1] - (ys[1] - ys[0]) * .1
        ax.text(x0, y1, txt, fontsize=14)

    g.savefig(out, dpi=300)
    print(out)
    

# Plots comparing modeled and observed peak flows (1) time series plots
def plot_pf(df, out, title=None):
    """Draw a peak-flow time series for both columns of ``df`` and save it.

    The first column of ``df`` is used as the observed series and the second
    as the simulated series when computing the annotation: the ``r``
    component of ``hydroeval.kge`` and ``hydroeval.nse``, each rounded to two
    decimals.  The annotation is anchored at the second index value of ``df``
    and at the largest value found in either column.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide frame whose first two columns are plotted against its index.
    out : str
        Path the figure is written to (300 dpi).
    title : str, optional
        Axes title.
    """
    plt.clf()

    obs_col, sim_col = df.columns[0], df.columns[1]
    sim = df[sim_col].values
    obs = df[obs_col].values

    # Highest value in either series; drives the y-limit and label height.
    y_top = max(sim.max(), obs.max())

    score_nse = round(hev.nse(sim, obs), 2)
    score_r = round(hev.kge(sim, obs)[1][0], 2)
    label = "R = {}  NSE = {}".format(score_r, score_nse)

    colors = sns.color_palette(['#333333', '#2596be'], 2)
    axes = sns.lineplot(data=df, palette=colors)
    axes.set_title(title)
    axes.set_ylabel('Streamflow (cfs)')
    axes.set_xlabel('')
    axes.set_ylim((-1, y_top * 1.1))
    plt.text(df.index[1], y_top, label)
    #plt.legend(loc='lower right', bbox_to_anchor=(1,0))

    figure = axes.get_figure()
    figure.set_size_inches(5, 3)
    plt.subplots_adjust(top=0.9, left=0.12, bottom=0.1, right=.93)
    figure.savefig(out, dpi=300)
    print(out)

# Plots comparing modeled and observed peak flows (2) extreme stats plots
def plot_ps(df, out, title=None):
    """Scatter the ``value`` of each ``variable``, one facet column per return interval.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with the columns ``returnYr``, ``variable`` and
        ``value``.
    out : str
        Path the figure is written to (300 dpi).
    title : str, optional
        Accepted for signature parity with the other plot helpers; it is not
        drawn on the figure.
    """
    plt.clf()
    palette = ['#333333', '#2596be']

    # Fix the level order once for the whole frame so a given variable gets
    # the same colour in every facet.
    levels = list(df.variable.unique())

    g = sns.FacetGrid(df, col='returnYr', height=3, aspect=.4)
    # ``hue`` names the column instead of passing a raw array of labels: an
    # array is matched to each facet's rows by position and therefore only
    # works while every facet holds exactly one row per variable, in order.
    g = g.map_dataframe(sns.scatterplot, x='variable', y='value',
                        hue='variable', hue_order=levels,
                        palette=sns.color_palette(palette, 2))
    g.set_xlabels('')
    g.set_xticklabels('')
    g.set_ylabels('Streamflow (cfs)')
    g.set(xlim=(-.5, 1.5), xticks=[])
    g.set_titles("{col_name}")

    # Shared x-axis caption for all facets.
    g.fig.text(.35, 0, 'Return Intervals (Years)')
    g.savefig(out, dpi=300)
    print(out)


# Create plots comparing monthly average
def plot_ma(df, out, title=None):
    """Draw every column of ``df`` as a line against its index and save it.

    The x axis is labelled with the ``"%h"`` date format, so the index is
    expected to hold dates.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide frame; each column becomes one line.
    out : str
        Path the figure is written to (300 dpi).
    title : str, optional
        Axes title.
    """
    plt.clf()

    colors = sns.color_palette(['#333333', '#2596be'], 2)
    axes = sns.lineplot(data=df, palette=colors)

    axes.xaxis.set_major_formatter(DateFormatter("%h"))
    axes.set_ylabel('Streamflow (cfs)')
    axes.set_xlabel('')
    axes.set_title(title)

    figure = axes.get_figure()
    figure.set_size_inches(5, 3)
    plt.subplots_adjust(top=0.9, left=0.12, bottom=0.1, right=.93)
    figure.savefig(out, dpi=300)
    print(out)

    

# Creating table for 1) daily correlation and NSE 2) monthly correlation and nse 3) peak flow correlation and nse
def table1(sim, obs, odir):
    """Placeholder for the statistics table; currently it only prints ``2``.

    The arguments ``sim``, ``obs`` and ``odir`` are accepted but not used.
    """
    marker = 2
    print(marker)



# Station identifiers processed by the driver loop below.  The loop iterates
# over the dictionary KEYS and uses them in file names ("<key>.day") and for
# the lookup in ``names``; the VALUES are not referenced in the code shown
# in this file.  The LOST and WILDCAT entries are disabled.
keys = {
    'CHICOMAIN':'ChicoMS',
    'CHICOabvDKRSN':'ChicoGC',
    'DICKERSON':'Dickr',
    'KITSAP':'Kitsap_Lk_outlet'}#,
    #'LOST':'Lost',
    #'WILDCAT':'Wildc_Lk_outlet'}



# Human-readable station names, looked up by station identifier and used as
# plot titles.  Contains entries for LOST and WILDCAT even though those
# stations are disabled in ``keys``.
names = {'CHICOMAIN':'Chico Main',
         'CHICOabvDKRSN':'Chico Above Dickerson',
         'DICKERSON':'Dickerson',
         'KITSAP':'Kitsap',
         'LOST':'Lost',
         'WILDCAT':'Wildcat'}

# NOTE: this first ``gcms`` list has no effect -- it is overwritten by the
# second ``gcms`` assignment directly below before anything reads it.
gcms = [
    'access1.0_RCP85_WRF_run84_1969-2099',
    'access1.3_RCP85_WRF_run84_1969-2099',
    'bcc-csm1.1_RCP85_WRF_run84_1969-2099',
    'canesm2_RCP85_WRF_run84_1969-2099',
    'ccsm4_RCP85_WRF_run84_1969-2099',
    'csiro-mk3.6.0_RCP85_WRF_run84_1969-2099',
    'fgoals-g2_RCP85_WRF_run84_1969-2099',
    'gfdl-cm3_RCP85_WRF_run84_1969-2099',
    'giss-e2-h_RCP85_WRF_run84_1969-2099',
    'miroc5_RCP85_WRF_run84_1969-2099',
    'mri-cgcm3_RCP85_WRF_run84_1969-2099',
    'noresm1-m_RCP85_WRF_run84_1969-2099',
    #'kitsap_run84_pnnl'
]

# Effective list of simulation runs.  Each entry is used as a sub-directory
# name under ``data_dir`` and as a prefix of the output file names.
gcms = [
    'access1.0_RCP45',
    'access1.0_RCP85',    
    'access1.3_RCP85',
    'bcc-csm1.1_RCP85',
    'canesm2_RCP85',
    'ccsm4_RCP85',
    'csiro-mk3.6.0_RCP85',
    'fgoals-g2_RCP85',
    'gfdl-cm3_RCP85',
    'giss-e2-h_RCP85',
    'miroc5_RCP85',
    'mri-cgcm3_RCP85',
    'noresm1-m_RCP85',
    'kitsap_run84_pnnl']

#gcms = ['access1.0_RCP45']

# Root of the simulated daily flows: "<data_dir>/<run>/<station>.day".
data_dir = '/home/disk/rocinante/DATA/temp/chico/hyak/data_atmos/'
# Directory of the PNNL monthly / peak-flow CSV files; only read by the
# reader functions whose calls are currently commented out.
pnnl = '/home/disk/tsuga2/jswon11/workdir/2020-07_Snoho-autocalibration/WRF_runs/merge/class36_500k_pnnl_1980-2015/'
# Directory of the observed flows: "<obs_dir>/<station>.day".
obs_dir = '/home/disk/rocinante/DATA/temp/chico/hyak/data_ref/stream/'
# Directory the PNG figures are written to.
out_dir = '/home/disk/rocinante/DATA/temp/chico/hyak/plots/'



# Driver: for every simulation run ``g`` and every station ``k`` build the
# comparison figures.  The reader functions below are (re)defined on every
# iteration and read ``g`` and ``k`` from the enclosing loop; only
# ``read_daily`` and ``read_day2mon`` are actually called at the bottom.
for g in gcms:   
    for k in keys:

        print(g, k)

        def read_daily():
            """Load daily observed and simulated flow and plot them per water year."""
            obs = "{}/{}.day".format(obs_dir, k)
            sim = "{}/{}/{}.day".format(data_dir, g, k)
            
            # Observations are comma separated, the simulation is separated
            # by spaces/tabs; neither file has a header row.
            odf = pd.read_csv(obs, header=None)
            sdf = pd.read_csv(sim, sep='[ \t]+', engine='python', header=None)
            hdr = ['Year', 'Month', 'Day', 'Flow']
            odf.columns = hdr
            sdf.columns = hdr
            #odf['Flow'] *= 35.314666212661
            # Drop the first 274 simulated rows.
            # NOTE(review): presumably a spin-up period -- confirm.
            sdf = sdf[274:]
            # Build a datetime index from the Year/Month/Day columns.
            odf.index = pd.to_datetime(odf[hdr[:-1]])
            sdf.index = pd.to_datetime(sdf[hdr[:-1]])
            # Keep only the Flow column of each frame.
            odf = odf[[hdr[-1]]]
            sdf = sdf[[hdr[-1]]]
            # Keep only dates present in both series.
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df / 1000
            df.columns = ['Obs', g]
            
            # Facet key: calendar year, plus one for October-December.
            df['year'] = df.index.year + (df.index.month > 9)
            df['time'] = df.index
            # Long format: time, year, variable, value.
            df = df.melt(id_vars=['time', 'year'])
            
            out = '{}/{}_{}_Day.png'.format(out_dir, g, k)
            #plot_dy(df, out, "{}: ({})".format(names[k], k))
            plot_dy(df, out, "{}".format(names[k]))
        
        # 1)
        def read_day2mon():
            """Load daily flows, average them per month and plot them per decade."""
            obs = "{}/{}.day".format(obs_dir, k)
            sim = "{}/{}/{}.day".format(data_dir, g, k)
            
            # Same loading steps as read_daily().
            odf = pd.read_csv(obs, header=None)
            sdf = pd.read_csv(sim, sep='[ \t]+', engine='python', header=None)    
            hdr = ['Year', 'Month', 'Day', 'Flow']
            odf.columns = hdr
            sdf.columns = hdr
            #odf['Flow'] *= 35.314666212661
            # Drop the first 274 simulated rows.
            # NOTE(review): presumably a spin-up period -- confirm.
            sdf = sdf[274:]
            odf.index = pd.to_datetime(odf[hdr[:-1]])
            sdf.index = pd.to_datetime(sdf[hdr[:-1]])
            odf = odf[[hdr[-1]]]
            sdf = sdf[[hdr[-1]]]
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df / 1000
            df.columns = ['Obs', g]
            
            # Monthly means.
            # NOTE(review): the 'M' alias is deprecated in recent pandas in
            # favour of 'ME' -- check the pandas version in use.
            df = df.resample('M').mean()
            #df = df[(df.index.year >= 1995) & (df.index.year <= 2015)]
            # Facet key: int((year + 4) / 10) * 10, e.g. 1996-2005 -> 2000.
            df['dec'] = ((df.index.year + 4) / 10).astype(int) * 10
            df['time'] = df.index
            # Long format: time, dec, variable, value.
            df = df.melt(id_vars=['time', 'dec'])
            
            out = '{}/{}_{}_Month.png'.format(out_dir, g, k)
            #plot_mn(df, out, "{}: ({})".format(names[k], k))
            plot_mn(df, out, "{}".format(names[k]))
            
            
        def read_month():
            """Load monthly observed and PNNL flows and plot them per decade.

            Not called at present (see the commented-out call below).
            """
            obs = "{}/{}_DailyFlows_MonthlyFlows.csv".format(obs_dir, k)
            sim = "{}/{}_MonthlyFlows.csv".format(pnnl, k)
            odf = pd.read_csv(obs)
            sdf = pd.read_csv(sim)
            # Insert a constant day-of-month as third column so that every
            # column but the last can be parsed into a date.
            # NOTE(review): assumes the first two columns are year and
            # month -- confirm against the CSV files.
            odf.insert(2, 'Day', 1)
            sdf.insert(2, 'Day', 1)
            odf.index = pd.to_datetime(odf[odf.columns[:-1]])
            sdf.index = pd.to_datetime(sdf[sdf.columns[:-1]])
            # Keep only the last column (the flow values).
            odf = odf[[odf.columns[-1]]]
            sdf = sdf[[sdf.columns[-1]]]
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df / 1000
            df.columns = ['Obs', 'PNNL']
            # Facet key: int((year + 4) / 10) * 10, e.g. 1996-2005 -> 2000.
            df['dec'] = ((df.index.year + 4) / 10).astype(int) * 10
            df['time'] = df.index
            df = df.melt(id_vars=['time', 'dec'])
            
            out = '{}/{}_Month.png'.format(out_dir, k)
            # NOTE(review): the title labels ``k`` as a USGS number, but the
            # keys used in this file are station names -- confirm.
            plot_mn(df, out, "{}: (USGS #{})".format(names[k], k))
            
        # 2)
        def read_mavg():
            """Load monthly flows, average them per calendar month and plot them.

            Not called at present (see the commented-out call below).
            """
            obs = "{}/{}_DailyFlows_MonthlyFlows.csv".format(obs_dir, k)
            sim = "{}/{}_MonthlyFlows.csv".format(pnnl, k)
            odf = pd.read_csv(obs)
            sdf = pd.read_csv(sim)
            # Same date construction as read_month().
            odf.insert(2, 'Day', 1)
            sdf.insert(2, 'Day', 1)
            odf.index = pd.to_datetime(odf[odf.columns[:-1]])
            sdf.index = pd.to_datetime(sdf[sdf.columns[:-1]])
            odf = odf[[odf.columns[-1]]]
            sdf = sdf[[sdf.columns[-1]]]
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df / 1000
            df.columns = ['Obs', 'PNNL']
            # Mean per calendar month; the index becomes the month number.
            df = df.groupby(df.index.month).mean()
            
            # Build placeholder dates so the months can be drawn on a date
            # axis: months 1-9 are placed in 2001 and months 10-12 in 2000,
            # which orders the axis October -> September.
            dates = ((df.index<=9) + 2000).astype(str)
            dates = [ x + "-" for x in dates]
            dates = pd.to_datetime(dates + df.index.astype(str) + '-01')
            df.index = dates
        
            out = '{}/{}_MonthAvg.png'.format(out_dir, k)
            plot_ma(df, out, "{}: (USGS #{})".format(names[k], k))
                


        # 3a)
        def read_peakflow():
            """Load observed and PNNL peak flows indexed by ``WYear`` and plot them.

            Not called at present (see the commented-out call below).
            """
            obs = "{}/{}_DailyFlows_PeakFlows.csv".format(obs_dir, k)
            sim = "{}/{}_PeakFlows.csv".format(pnnl, k)
            
            odf = pd.read_csv(obs)
            sdf = pd.read_csv(sim)
            odf.index = odf.WYear
            sdf.index = sdf.WYear
            # Keep only the last column of each file.
            odf = odf[[odf.columns[-1]]]
            sdf = sdf[[sdf.columns[-1]]]
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df/1000
            df.columns = ['Obs', 'PNNL']
            
            out = '{}/{}_PeakFlow.png'.format(out_dir, k)
            plot_pf(df, out, "{}: (USGS #{})".format(names[k], k))
                
    
        # 3b)
        def read_peakstat():
            """Load observed and PNNL peak statistics indexed by ``returnYr`` and plot them.

            Not called at present (see the commented-out call below).
            """
            obs = "{}/{}_DailyFlows_PeakStats.csv".format(obs_dir, k)
            sim = "{}/{}_PeakStats.csv".format(pnnl, k)
            
            odf = pd.read_csv(obs)
            sdf = pd.read_csv(sim)
            odf.index = odf.returnYr
            sdf.index = sdf.returnYr
            # Keep only the last column of each file.
            odf = odf[[odf.columns[-1]]]
            sdf = sdf[[sdf.columns[-1]]]
            df = odf.join(sdf, lsuffix='Obs', rsuffix='Sim', how='inner')
            #df = df/1000
            # '0' is the observed column, '1' the simulated one.
            df.columns = ['0', '1']
            df = df.reset_index()
            # Long format: returnYr, variable, value.
            df = df.melt(id_vars=['returnYr'])
            
            out = '{}/{}_PeakStat.png'.format(out_dir, k)        
            plot_ps(df, out)
                
                

        # Figures produced for this run/station pair; the PNNL-based readers
        # are disabled.
        read_daily()
        read_day2mon()
        #read_month()
        #read_mavg()
        #read_peakflow()
        #read_peakstat()

    
        #sys.exit()
    

    
