import pandas as pd
from glob import glob
import os
#from multiprocessing import Pool
#import multiprocessing as mp
import time
from multiprocessing import Process, Queue
from multiprocessing import Pool
import random
from tqdm import tqdm


# grab all forcings
# calculate summary

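# Metrics (each computed per year, then averaged over the years of the selected time period):
#   average annual temperature, average seasonal temperature,
#   total annual precipitation, total seasonal precipitation
#   (seasons are calendar quarters: JFM, AMJ, JAS, OND)
# Time periods: 1980s (1970-1999), 2020s (2010-2039), 2050s (2040-2069), 2080s (2070-2099)
# Datasets: MACA, WRF, WRF-NARR
# Domain: Lake Tapps grid cell, Puyallup basin, White basin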

def get_data(ifile):
    """Load a space-delimited, headerless forcing file into a date-indexed frame.

    Column 0 holds timestamps formatted ``%m/%d/%Y-%H:%M:%S``; columns 1 and 6
    are kept and renamed to ``temp`` and ``prec``.

    Parameters
    ----------
    ifile : str
        Path to the forcing file.

    Returns
    -------
    pandas.DataFrame
        Columns ``temp`` and ``prec`` indexed by a DatetimeIndex named ``date``.
    """
    raw = pd.read_csv(ifile, sep=' ', header=None)
    when = pd.to_datetime(raw[0], format='%m/%d/%Y-%H:%M:%S')
    out = raw[[1, 6]].rename(columns={1: 'temp', 6: 'prec'})
    out.index = pd.DatetimeIndex(when, name='date')
    return out

def _annual_stat(series, how):
    """Aggregate *series* per calendar year with *how* ('mean' or 'sum'), then average the yearly values."""
    # groupby(year) only creates groups for years that actually contain data, so a
    # year with no records is skipped instead of being counted as 0 (which is what
    # resample('Y').sum() does for empty bins, biasing precipitation totals low).
    return series.groupby(series.index.year).agg(how).mean()


def get_stats(df, styr, edyr, ifile, gcm):
    """Summarise temperature and precipitation for one grid cell and period.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``temp`` and ``prec`` columns and a DatetimeIndex
        (as produced by ``get_data``).
    styr, edyr : int
        First and last calendar year (inclusive) of the period.
    ifile : str
        Source file path; its basename is split on ``_`` and the 2nd/3rd
        fields are reported as ``lat``/``lon`` (strings, e.g. ``data_<lat>_<lon>``).
    gcm : str
        Label stored in the ``gcm`` column.

    Returns
    -------
    pandas.DataFrame
        One row with columns ``lat, lon, styr, edyr, gcm``, then
        ``temp_ann, temp_JFM, temp_AMJ, temp_JAS, temp_OND`` (mean of yearly
        means) and ``prec_ann, prec_JFM, prec_AMJ, prec_JAS, prec_OND``
        (mean of yearly totals).
    """
    parts = os.path.basename(ifile).split('_')
    lat, lon = parts[1], parts[2]

    cf = df[(df.index.year >= styr) & (df.index.year <= edyr)]
    months = cf.index.month

    # Calendar-quarter seasons: (label, first month, last month).
    seasons = (('JFM', 1, 3), ('AMJ', 4, 6), ('JAS', 7, 9), ('OND', 10, 12))

    row = {'lat': [lat], 'lon': [lon], 'styr': [styr], 'edyr': [edyr], 'gcm': [gcm]}
    # Temperature is averaged within each year; precipitation is summed.
    for var, how in (('temp', 'mean'), ('prec', 'sum')):
        row[f'{var}_ann'] = [_annual_stat(cf[var], how)]
        for label, first, last in seasons:
            in_season = (months >= first) & (months <= last)
            row[f'{var}_{label}'] = [_annual_stat(cf[var][in_season], how)]

    return pd.DataFrame(row)

def get_stats_maca(d, gcm):
    """Summarise one MACA forcing file for the periods matching its scenario.

    A directory whose basename contains ``'historical'`` is summarised for
    1970-1999 (the "1980s"); any other is summarised for 2010-2039, 2040-2069
    and 2070-2099 (the 2020s, 2050s and 2080s).

    Parameters
    ----------
    d : str
        Path to the forcing file (read with ``get_data``).
    gcm : str
        Path of the GCM/scenario directory; its basename is stored as ``gcm``.

    Returns
    -------
    pandas.DataFrame
        One ``get_stats`` row per period, concatenated.
    """
    print(d)
    cf = get_data(d)
    g = os.path.basename(gcm)

    if 'historical' in g:
        periods = [(1970, 1999)]  # 1980s
    else:
        periods = [(2010, 2039), (2040, 2069), (2070, 2099)]  # 2020s, 2050s, 2080s

    # Build all rows first and concatenate once, rather than repeatedly
    # concatenating onto an initially empty DataFrame.
    rows = [get_stats(cf, styr, edyr, d, g) for styr, edyr in periods]
    return pd.concat(rows)
    
    
def get_stats_wrf(d, gcm):
    """Summarise one WRF forcing file for all four study periods.

    Unlike ``get_stats_maca``, every file is summarised for 1970-1999,
    2010-2039, 2040-2069 and 2070-2099 regardless of scenario.

    Parameters
    ----------
    d : str
        Path to the forcing file (read with ``get_data``).
    gcm : str
        Path of the GCM/scenario directory; its basename is stored as ``gcm``.

    Returns
    -------
    pandas.DataFrame
        One ``get_stats`` row per period, concatenated.
    """
    yrs = [(1970, 1999), (2010, 2039), (2040, 2069), (2070, 2099)]

    print(d)
    cf = get_data(d)
    g = os.path.basename(gcm)
    # Collect rows and concatenate once instead of growing an empty frame.
    rows = [get_stats(cf, styr, edyr, d, g) for styr, edyr in yrs]
    return pd.concat(rows)
    

    
def run_maca(maca_dir="/home/disk/becassine/jswon11/laketapps/maca/"):
    """Summarise every MACA forcing file found under *maca_dir*.

    Entries two directory levels below *maca_dir* are treated as GCM/scenario
    directories, and every entry inside each is passed to ``get_stats_maca``.

    Parameters
    ----------
    maca_dir : str, optional
        Root of the MACA data tree (defaults to the original hard-coded path).

    Returns
    -------
    pandas.DataFrame
        All per-file summaries concatenated; an empty frame if nothing matched.
    """
    frames = []
    for gcm in sorted(glob('{}/*/*'.format(maca_dir))):
        for d in sorted(glob('{}/*'.format(gcm))):
            # Keep each file's summary so it can be concatenated below.
            frames.append(get_stats_maca(d, gcm))
    # pd.concat raises on an empty list, so fall back to an empty frame.
    return pd.concat(frames) if frames else pd.DataFrame()
    

# WRF data
def run_wrf(wrf_dir="/home/disk/becassine/jswon11/forcing/8.wrf_final/puyallup/"):
    """Summarise every WRF forcing file under the ``*RCP*`` entries of *wrf_dir*.

    Parameters
    ----------
    wrf_dir : str, optional
        Root of the WRF data tree (defaults to the original hard-coded path).

    Returns
    -------
    pandas.DataFrame
        All per-file summaries from ``get_stats_wrf`` concatenated; an empty
        frame if nothing matched.
    """
    frames = []
    for gcm in sorted(glob('{}/*RCP*'.format(wrf_dir))):
        for d in sorted(glob('{}/*'.format(gcm))):
            frames.append(get_stats_wrf(d, gcm))
    # pd.concat raises on an empty list, so fall back to an empty frame.
    return pd.concat(frames) if frames else pd.DataFrame()

    

def main(include_wrf=False):
    """Write the MACA summary (and optionally the WRF summary) to CSV.

    Parameters
    ----------
    include_wrf : bool, optional
        When True, also run ``run_wrf`` and write ``pool_wrf_summary.csv``.
        Defaults to False, matching the original behaviour.
    """
    # Lake Tapps grid-cell file names (not used yet):
    #   MACA: data_47.21875_-122.15625
    #   WRF:  data_47.23265_-122.14883
    d1 = run_maca()
    d1.to_csv('pool_maca_summary.csv')
    if include_wrf:
        d2 = run_wrf()
        d2.to_csv('pool_wrf_summary.csv')
            
#main()


def do_something(a, r, b, pbar):
    """Demo worker: sleep *r* seconds, put *a* on queue *b*, print, and tick *pbar*.

    The progress bar is advanced by 2 per call.
    """
    delay, item = r, a
    time.sleep(delay)
    b.put(item)
    print(item, delay)
    step = 2
    pbar.update(step)

def crunch(a, delay=1):
    """Demo pool task: sleep *delay* seconds (default 1), then return *a* unchanged.

    Defined at module level rather than inside the ``__main__`` guard so that
    worker processes can unpickle it under the "spawn" start method, where the
    module is re-imported with ``__name__ != '__main__'`` and the guarded code
    does not run.
    """
    time.sleep(delay)
    return a


if __name__ == '__main__':
    maca_dir = "/home/disk/becassine/jswon11/laketapps/maca/"
    gcms = sorted(glob('{}/*/*'.format(maca_dir)))

    # imap yields results lazily and in input order, so tqdm advances as they arrive.
    with Pool(processes=4) as p:
        r = list(tqdm(p.imap(crunch, gcms), total=len(gcms)))
    print(r)
