import pandas as pd
from glob import glob
import os

# grab all forcings
# calculate summary

# Metrics
# Average annual temp, Average seasonal temp, Total Annual precip, Total Seasonal precip all for specified time period
# Time periods: 1980s, 2020s, 2050s, 2080s
# Datasets: MACA, WRF, WRF-NARR
# Domain: Lake Tapps Grid Cell, Puyallup basin, White Basin

def get_data(ifile):
    """Load one forcing file into a date-indexed temperature/precipitation table.

    The file is read as space-separated text without a header row.  Column 0
    holds timestamps written as ``MM/DD/YYYY-HH:MM:SS``; columns 1 and 6 are
    kept and renamed to ``temp`` and ``prec``.  All other columns are dropped.

    Parameters
    ----------
    ifile : str
        Path of the forcing file to read.

    Returns
    -------
    pandas.DataFrame
        Columns ``temp`` and ``prec`` on a datetime index named ``date``.
    """
    raw = pd.read_csv(ifile, sep=' ', header=None)
    stamps = pd.to_datetime(raw[0], format='%m/%d/%Y-%H:%M:%S')

    # Keep only the two variables of interest, under readable names.
    out = raw[[1, 6]].rename(columns={1: 'temp', 6: 'prec'})
    out.index = pd.DatetimeIndex(stamps, name='date')
    return out

# Season label -> (first month, last month), in output-column order.
_SEASONS = (
    ('JFM', (1, 3)),
    ('AMJ', (4, 6)),
    ('JAS', (7, 9)),
    ('OND', (10, 12)),
)


def _mean_over_years(series, how):
    """Aggregate *series* per calendar year with *how*, then average the years.

    Grouping on ``index.year`` (rather than resampling with a year alias)
    behaves the same on every pandas version and only produces groups for
    years that actually occur in the data, so a missing year is not counted
    as a zero total.
    """
    if series.empty:
        return float('nan')
    return series.groupby(series.index.year).agg(how).mean()


def get_stats(df, styr, edyr, ifile, gcm):
    """Summarise temperature and precipitation for one file and one period.

    Temperature statistics are the mean over years of the per-year mean;
    precipitation statistics are the mean over years of the per-year sum.
    Each is computed for the whole year (``*_ann``) and for the seasons
    JFM, AMJ, JAS and OND.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``temp`` and ``prec`` columns on a datetime index.
    styr, edyr : int
        First and last calendar year of the period, both inclusive.
    ifile : str
        Path of the source file.  The second and third ``_``-separated
        fields of its base name are reported as ``lat`` and ``lon``
        (e.g. ``data_47.21875_-122.15625``).
    gcm : str
        Label stored in the ``gcm`` column.

    Returns
    -------
    pandas.DataFrame
        A single row with columns ``lat, lon, styr, edyr, gcm``, then
        ``temp_ann, temp_JFM, temp_AMJ, temp_JAS, temp_OND`` and the same
        five for ``prec``.  Statistics are NaN when the period has no data.

    Raises
    ------
    IndexError
        If the base name of *ifile* has fewer than three ``_`` fields.
    """
    name_parts = os.path.basename(ifile).split('_')
    lat = name_parts[1]
    lon = name_parts[2]

    years = df.index.year
    cf = df[(years >= styr) & (years <= edyr)]
    months = cf.index.month

    stats = {'lat': lat, 'lon': lon, 'styr': styr, 'edyr': edyr, 'gcm': gcm}
    # Insertion order below fixes the column order of the output row.
    for var, how in (('temp', 'mean'), ('prec', 'sum')):
        stats['{}_ann'.format(var)] = _mean_over_years(cf[var], how)
        for label, (first, last) in _SEASONS:
            in_season = (months >= first) & (months <= last)
            stats['{}_{}'.format(var, label)] = _mean_over_years(
                cf[var][in_season], how)

    return pd.DataFrame({key: [value] for key, value in stats.items()})
    
def run_maca(maca_dir="/home/disk/becassine/jswon11/laketapps/maca/"):
    """Build the summary table for every MACA forcing file.

    Files are discovered as ``<maca_dir>/*/*/*``; the base name of the
    second-level directory is used as the ``gcm`` label.  Directories whose
    label contains ``historical`` are summarised for 1970-1999 (the 1980s);
    all others for 2010-2039, 2040-2069 and 2070-2099 (2020s, 2050s, 2080s).
    Each file path is printed as progress.

    Parameters
    ----------
    maca_dir : str, optional
        Root directory of the MACA forcings.

    Returns
    -------
    pandas.DataFrame
        One row per (file, period) as produced by ``get_stats``; an empty
        frame when no files are found.
    """
    historical_periods = [(1970, 1999)]                             # 1980s
    future_periods = [(2010, 2039), (2040, 2069), (2070, 2099)]     # 2020s, 2050s, 2080s

    # Collect rows and concatenate once: concatenating inside the loop
    # re-copies the whole table on every iteration.
    rows = []
    for gcm in sorted(glob('{}/*/*'.format(maca_dir))):
        g = os.path.basename(gcm)
        periods = historical_periods if 'historical' in g else future_periods
        for d in sorted(glob('{}/*'.format(gcm))):
            print(d)
            cf = get_data(d)
            for styr, edyr in periods:
                rows.append(get_stats(cf, styr, edyr, d, g))

    if not rows:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame()
    return pd.concat(rows)
    

# WRF data
def run_wrf(wrf_dir="/home/disk/becassine/jswon11/forcing/8.wrf_final/puyallup/"):
    """Build the summary table for every WRF forcing file.

    Scenario directories are discovered as ``<wrf_dir>/*RCP*`` and their
    base name is used as the ``gcm`` label.  Every file inside them is
    summarised for all four periods: 1970-1999, 2010-2039, 2040-2069 and
    2070-2099.  Each file path is printed as progress.

    Parameters
    ----------
    wrf_dir : str, optional
        Root directory of the WRF forcings.

    Returns
    -------
    pandas.DataFrame
        One row per (file, period) as produced by ``get_stats``; an empty
        frame when no files are found.
    """
    yrs = [(1970, 1999), (2010, 2039), (2040, 2069), (2070, 2099)]

    # Collect rows and concatenate once: concatenating inside the loop
    # re-copies the whole table on every iteration.
    rows = []
    for gcm in sorted(glob('{}/*RCP*'.format(wrf_dir))):
        g = os.path.basename(gcm)
        for d in sorted(glob('{}/*'.format(gcm))):
            print(d)
            cf = get_data(d)
            for styr, edyr in yrs:
                rows.append(get_stats(cf, styr, edyr, d, g))

    if not rows:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame()
    return pd.concat(rows)

    

def main():
    """Compute the MACA and WRF summaries and write each to a CSV file.

    Output goes to ``maca_summary.csv`` and ``wrf_summary.csv`` in the
    current working directory.
    """
    # Lake Tapps grid-cell file names, kept for reference (not used by the
    # summaries below):
    #   MACA: data_47.21875_-122.15625
    #   WRF:  data_47.23265_-122.14883

    d1 = run_maca()
    d1.to_csv('maca_summary.csv')
    d2 = run_wrf()
    d2.to_csv('wrf_summary.csv')


# Only run the full job when executed as a script, not on import.
if __name__ == "__main__":
    main()
