import pandas as pd
from glob import glob
import os
#from multiprocessing import Pool
#import multiprocessing as mp
import time
from multiprocessing import Process, Queue
from multiprocessing import Pool
import random
from tqdm import tqdm


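# Read every forcing file and compute climate summary statistics.

# Metrics (each averaged over the years of the specified time period):
#   average annual temp, average seasonal temp,
#   total annual precip, total seasonal precip
# Seasons: JFM, AMJ, JAS, OND (calendar quarters)
# Time periods: 1980s (1970-1999), 2020s (2010-2039), 2050s (2040-2069), 2080s (2070-2099)
# Datasets: MACA, WRF, WRF-NARR
# Domain: Lake Tapps grid cell, Puyallup basin, White basin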

def get_data(ifile):
    """Load one forcing file into a time-indexed DataFrame.

    The file is read as single-space-separated text with no header row.
    Column 0 holds timestamps formatted ``%m/%d/%Y-%H:%M:%S``; column 1 is
    kept as ``temp`` and column 6 as ``prec``. All other columns are dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``temp`` and ``prec`` with a DatetimeIndex named ``date``.
    """
    raw = pd.read_csv(ifile, sep=' ', header=None)
    stamps = pd.to_datetime(raw[0], format='%m/%d/%Y-%H:%M:%S')

    # Keep only the two variables of interest, renamed to their short names.
    table = raw[[1, 6]].rename(columns={1: 'temp', 6: 'prec'})
    table.index = pd.DatetimeIndex(stamps, name='date')
    return table

# Calendar-quarter seasons used for the seasonal summaries, keyed by the
# suffix that appears in the output column names (insertion order = column order).
_SEASONS = {
    'JFM': (1, 2, 3),
    'AMJ': (4, 5, 6),
    'JAS': (7, 8, 9),
    'OND': (10, 11, 12),
}


def _average_over_years(series, how):
    """Aggregate *series* per calendar year with *how* ('mean' or 'sum'), then average the years.

    Grouping on ``index.year`` only creates groups for years that actually
    contain rows, so a missing year is skipped rather than counted as a zero
    precipitation total (which ``resample('Y').sum()`` would do).
    """
    return series.groupby(series.index.year).agg(how).mean()


def get_stats(df, styr, edyr, ifile, gcm):
    """Summarise temperature and precipitation for one grid cell and period.

    Parameters
    ----------
    df : pandas.DataFrame
        Time series with a DatetimeIndex and ``temp`` / ``prec`` columns
        (as returned by ``get_data``).
    styr, edyr : int
        First and last calendar year of the period, both inclusive.
    ifile : str
        Path of the forcing file. Its basename is split on ``_`` and parts 1
        and 2 are reported as ``lat`` / ``lon`` (kept as strings), matching
        names like ``data_47.21875_-122.15625``.
    gcm : str
        Label written to the ``gcm`` column.

    Returns
    -------
    pandas.DataFrame
        A single row with columns lat, lon, styr, edyr, gcm, temp_ann,
        temp_JFM, temp_AMJ, temp_JAS, temp_OND, prec_ann, prec_JFM, prec_AMJ,
        prec_JAS, prec_OND.

        Temperature statistics are the mean of per-year means. Precipitation
        statistics are the mean of per-year totals. Values are NaN when the
        period (or season) has no data.
    """
    parts = os.path.basename(ifile).split('_')
    lat, lon = parts[1], parts[2]

    cf = df[(df.index.year >= styr) & (df.index.year <= edyr)]

    stats = {'lat': [lat], 'lon': [lon], 'styr': [styr], 'edyr': [edyr], 'gcm': [gcm]}
    # Temperature is averaged within each year; precipitation is accumulated.
    for var, how in (('temp', 'mean'), ('prec', 'sum')):
        stats[var + '_ann'] = [_average_over_years(cf[var], how)]
        for season, months in _SEASONS.items():
            in_season = cf[var][cf.index.month.isin(months)]
            stats[var + '_' + season] = [_average_over_years(in_season, how)]

    return pd.DataFrame(stats)

# MACA Processing
def get_stats_maca(i):
    """Summarise one MACA forcing file for the periods that match its scenario.

    Parameters
    ----------
    i : tuple
        ``(d, gcm)``: ``d`` is the forcing-file path and ``gcm`` is the
        directory it was found in. The arguments are packed into one tuple so
        the function can be mapped with ``Pool.imap``.

    Returns
    -------
    pandas.DataFrame
        One ``get_stats`` row per period. If the basename of ``gcm`` contains
        ``'historical'``, only the baseline window is produced. Otherwise the
        three future windows are produced.
    """
    (d, gcm) = i
    cf = get_data(d)
    g = os.path.basename(gcm)

    if 'historical' in g:
        periods = [(1970, 1999)]  # 1980s
    else:
        periods = [(2010, 2039),  # 2020s
                   (2040, 2069),  # 2050s
                   (2070, 2099)]  # 2080s

    # Build every row first and concatenate once, instead of growing a
    # DataFrame seeded with an empty frame.
    return pd.concat([get_stats(cf, styr, edyr, d, g) for styr, edyr in periods])

def run_maca(maca_dir="/home/disk/becassine/jswon11/laketapps/maca/", processes=10):
    """Summarise every MACA forcing file found under *maca_dir*.

    Parameters
    ----------
    maca_dir : str
        Root directory. Forcing files are expected at ``<maca_dir>/*/*/*``,
        and each second-level directory is treated as one GCM run.
        Defaults to the original hard-coded location.
    processes : int
        Number of worker processes in the pool (default 10, as before).

    Returns
    -------
    pandas.DataFrame
        Concatenated ``get_stats_maca`` results for all files. An empty
        DataFrame is returned when nothing is found.
    """
    print('Processing MACA data')
    gcms = sorted(glob(os.path.join(maca_dir, '*', '*')))
    frames = []

    if gcms:
        # One pool for the whole run instead of re-spawning workers per GCM.
        with Pool(processes=processes) as pool:
            for gcm in gcms:
                print(gcm)
                tasks = [(d, gcm) for d in sorted(glob(os.path.join(gcm, '*')))]
                frames.extend(tqdm(pool.imap(get_stats_maca, tasks), total=len(tasks)))

    # pd.concat([]) raises ValueError; keep the original empty-result behaviour.
    return pd.concat(frames) if frames else pd.DataFrame()
    

# WRF Processing
def get_stats_wrf(i):
    """Summarise one WRF forcing file for the baseline and all future periods.

    Parameters
    ----------
    i : tuple
        ``(d, gcm)``: ``d`` is the forcing-file path and ``gcm`` is the
        directory it was found in. The arguments are packed into one tuple so
        the function can be mapped with ``Pool.imap``.

    Returns
    -------
    pandas.DataFrame
        Four ``get_stats`` rows, one per period, in chronological order.
    """
    (d, gcm) = i
    periods = ((1970, 1999),  # 1980s
               (2010, 2039),  # 2020s
               (2040, 2069),  # 2050s
               (2070, 2099))  # 2080s
    cf = get_data(d)
    g = os.path.basename(gcm)

    # Concatenate once rather than growing a DataFrame seeded with an empty frame.
    return pd.concat([get_stats(cf, styr, edyr, d, g) for styr, edyr in periods])
    

    
def run_wrf(wrf_dir="/home/disk/becassine/jswon11/forcing/8.wrf_final/puyallup/", processes=10):
    """Summarise every WRF forcing file found under *wrf_dir*.

    Parameters
    ----------
    wrf_dir : str
        Root directory. Only entries whose name contains ``RCP`` are treated
        as GCM runs, and every file directly inside each run is processed.
        Defaults to the original hard-coded location.
    processes : int
        Number of worker processes in the pool (default 10, as before).

    Returns
    -------
    pandas.DataFrame
        Concatenated ``get_stats_wrf`` results for all files. An empty
        DataFrame is returned when nothing is found.
    """
    print('Processing WRF data')
    gcms = sorted(glob(os.path.join(wrf_dir, '*RCP*')))
    frames = []

    if gcms:
        # One pool for the whole run instead of re-spawning workers per GCM.
        with Pool(processes=processes) as pool:
            for gcm in gcms:
                print(gcm)
                tasks = [(d, gcm) for d in sorted(glob(os.path.join(gcm, '*')))]
                frames.extend(tqdm(pool.imap(get_stats_wrf, tasks), total=len(tasks)))

    # pd.concat([]) raises ValueError; keep the original empty-result behaviour.
    return pd.concat(frames) if frames else pd.DataFrame()

    
if __name__ == '__main__':
    # Forcing-file names for the grid cell nearest Lake Tapps in each dataset.
    # They are kept for reference and are not used by the run below.
    maca_tapps = "data_47.21875_-122.15625"
    wrf_tapps = "data_47.23265_-122.14883"

    # To inspect a single forcing file instead of running everything:
    #   frame = get_data(path_to_forcing_file)
    #   print(get_stats(frame, 1980, 2020, path_to_forcing_file, gcm_label))

    # Each dataset is processed and written out before the next one starts.
    jobs = (
        (run_maca, 'pool_maca_summary.csv'),
        (run_wrf, 'pool_wrf_summary.csv'),
    )
    for runner, out_csv in jobs:
        summary = runner()
        summary.to_csv(out_csv, index=False, float_format='%0.5f')
