import xarray as xr
import pandas as pd



def getdata(dataset, metric,
            base_dir='/home/disk/rocinante/DATA/temp/TNC_stormwater'):
    """Tabulate a product's map, its fitted WRF-NARR map, and their difference.

    Reads two NetCDF files from ``<base_dir>/<dataset lower-cased>/maps``:
    ``<dataset>_<metric>_interp.nc`` and ``WRF-NARR_<dataset>-time_<metric>.nc``.

    Parameters
    ----------
    dataset : str
        Product name, e.g. 'IMERG', 'PRISM', 'QPE', 'PRISM-daily'. It is
        lower-cased to select the sub-directory and used verbatim in file names.
    metric : str
        Metric tag embedded in both file names.
    base_dir : str, optional
        Root directory holding the per-product ``maps`` folders.

    Returns
    -------
    pandas.DataFrame
        Three frames stacked in order: the product map (``dataset`` column =
        *dataset*), the fitted WRF-NARR map (``'<dataset> fitted WRF-NARR'``)
        and the difference (``'<dataset> - WRF-NARR'``). Every row carries a
        ``metric`` column.
    """
    idir = '{}/{}/maps'.format(base_dir, dataset.lower())
    obs_path = '{}/{}_{}_interp.nc'.format(idir, dataset, metric)
    wrf_path = '{}/WRF-NARR_{}-time_{}.nc'.format(idir, dataset, metric)

    # Close both files once the data has been copied into pandas; the original
    # left every handle open for the life of the process.
    with xr.open_dataset(obs_path) as d1, xr.open_dataset(wrf_path) as d2:
        d3 = d1 - d2
        df1 = d1.to_dataframe()
        df2 = d2.to_dataframe()
        df3 = d3.to_dataframe()

    # Recompute PREC with pandas index alignment, as the original did.
    # Assumes both files carry a PREC variable; TODO confirm.
    df3['PREC'] = df1.PREC - df2.PREC

    df1['dataset'] = dataset
    df1['metric'] = metric
    df2['dataset'] = '{} fitted WRF-NARR'.format(dataset)
    df2['metric'] = metric
    df3['dataset'] = '{} - WRF-NARR'.format(dataset)
    df3['metric'] = metric

    return pd.concat([df1, df2, df3])

def getbc(dataset, metric,
          base_dir='/home/disk/rocinante/DATA/temp/TNC_stormwater'):
    """Tabulate a bias-corrected map and its difference from PRISM-daily.

    Reads ``<base_dir>/bias_correction/maps/PRISM-<dataset>_<metric>_interp.nc``
    and ``<base_dir>/prism-daily/maps/PRISM-daily_<metric>_interp.nc``.

    Parameters
    ----------
    dataset : str
        Bias-correction tag, e.g. 'bc0.25', 'bc0.5', 'bc1.0'.
    metric : str
        Metric tag embedded in both file names.
    base_dir : str, optional
        Root directory containing ``bias_correction`` and ``prism-daily``.

    Returns
    -------
    pandas.DataFrame
        Two frames stacked in order: the bias-corrected map (``dataset`` column =
        *dataset*) and the difference (``'<dataset> - PRISM-daily'``). Every row
        carries a ``metric`` column. The PRISM-daily reference rows are
        deliberately not included.
    """
    idir = '{}/bias_correction/maps'.format(base_dir)
    pdir = '{}/prism-daily/maps'.format(base_dir)
    bc_path = '{}/PRISM-{}_{}_interp.nc'.format(idir, dataset, metric)
    ref_path = '{}/PRISM-daily_{}_interp.nc'.format(pdir, metric)

    # Close both files once the data has been copied into pandas.
    with xr.open_dataset(bc_path) as d1, xr.open_dataset(ref_path) as d2:
        d3 = d1 - d2
        df1 = d1.to_dataframe()
        df2 = d2.to_dataframe()
        df3 = d3.to_dataframe()

    # Recompute PREC with pandas index alignment, as the original did.
    # Assumes both files carry a PREC variable; TODO confirm.
    df3['PREC'] = df1.PREC - df2.PREC

    df1['dataset'] = dataset
    df1['metric'] = metric
    df3['dataset'] = '{} - PRISM-daily'.format(dataset)
    df3['metric'] = metric

    return pd.concat([df1, df3])


def loop_set(metric):
    """Gather every product and bias-correction table for one metric.

    Products ('IMERG', 'PRISM', 'QPE', 'PRISM-daily') are loaded with
    ``getdata``. The bias-correction variants ('bc0.25', 'bc0.5', 'bc1.0') are
    loaded with ``getbc``. All frames are concatenated in that order.
    """
    products = ('IMERG', 'PRISM', 'QPE', 'PRISM-daily')
    corrections = ('bc0.25', 'bc0.5', 'bc1.0')

    frames = [getdata(name, metric) for name in products]
    frames.extend(getbc(tag, metric) for tag in corrections)
    return pd.concat(frames)



# Quantile metric: written on its own to quantile.csv.
d1 = loop_set('Quantile_OCT-MAR')

# Seasonal totals for the four quarters, stacked into one table.
seasons = ('JAN-MAR', 'APR-JUN', 'JUL-SEP', 'OCT-DEC')
df = pd.concat([loop_set('SEASONAL_TTL_{}'.format(s)) for s in seasons])

# Clean the frames that are actually written. The original cleaned only d2,
# and it did so after d2 had already been copied into the seasonal concat,
# so seasonal.csv never received any cleaning.
# NOTE(review): drop_duplicates() compares column values only, not the index.
# If grid coordinates live in the index, identical rows at different grid
# points are also dropped. Confirm this is intended.
d1 = d1.dropna().drop_duplicates()
df = df.dropna().drop_duplicates()

d1.to_csv('quantile.csv')
df.to_csv('seasonal.csv')
