import xarray as xr
import pandas as pd
from glob import glob


# Model names of the runs to process; every run uses the same scenario suffix.
# (An "access1.0_RCP45" run was listed but disabled, and stays excluded.)
_RCP85_MODELS = (
    "access1.0",
    "access1.3",
    "bcc-csm1.1",
    "canesm2",
    "ccsm4",
    "csiro-mk3.6.0",
    "fgoals-g2",
    "gfdl-cm3",
    "giss-e2-h",
    "miroc5",
    "mri-cgcm3",
    "noresm1-m",
)

# Run identifiers of the form "<model>_RCP85", in processing order.
gcms = ["{}_RCP85".format(model) for model in _RCP85_MODELS]


# Input (NetCDF) and output (CSV) directories share one root.
# Earlier configurations under the same root, now disabled:
#   input "bcdata/" with output "bc_csv/", and input "bcdata_test/".
_WRF_ROOT = "/home/disk/rocinante/DATA/temp/crystal_fire/wrf/"
idir = _WRF_ROOT + "bcdata_wndf/"
odir = _WRF_ROOT + "csv_wndf/"

# Create csv
def _write_count_csv(wnd, mask, valid, gcm, odir, wind_thld, window, tmax_thld):
    """Write one CSV of yearly, per-cell counts of the time steps kept by ``mask``.

    ``wnd`` is a Dataset holding only the ``WND`` variable, ``mask`` is the
    boolean selection of time steps, and ``valid`` marks the grid cells to
    keep. The remaining arguments are recorded as columns and used in the
    output file name.
    """
    # count() counts non-NaN entries, i.e. the steps that survived the mask.
    counts = wnd.where(mask).groupby('times.year').count(dim='times')['WND']

    df = counts.where(valid).to_dataframe().reset_index().drop(['x', 'y'], axis=1)
    # Cells outside ``valid`` became NaN above and are dropped here.
    df = df.dropna()
    df = df[['year', 'lat', 'lon', 'WND']].rename(columns={'WND': 'value'})
    df['tmax_window'] = window
    df['wind_thld'] = wind_thld
    df['tmax_thld'] = tmax_thld
    df['gcm'] = gcm

    out = "{}/{}_wnd{}_t{}_{}.csv".format(odir, gcm, wind_thld, window, tmax_thld)
    df.to_csv(out, index=False, float_format='%0.5f')


def create_csv(ds, gcm, odir):
    """Write per-threshold CSVs of yearly exceedance counts for one GCM run.

    For every wind threshold / temperature threshold pair (the pair
    ``(-99, -99)`` is skipped) and for both rolling windows (7 and 60 time
    steps of ``T2``), count per year and grid cell the May-September time
    steps where ``WND >= wind threshold`` and the rolling-mean ``T2`` is
    ``>= temperature threshold``. Each combination is written to
    ``<odir>/<gcm>_wnd<wind>_t<window>_<tmax>.csv`` with the columns
    ``year, lat, lon, value, tmax_window, wind_thld, tmax_thld, gcm``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``T2`` and ``WND`` along a ``times`` dimension, plus
        ``lat``/``lon`` and the ``x``/``y`` dimensions. It is not modified.
    gcm : str
        Run identifier, used in file names and in the ``gcm`` column.
    odir : str
        Existing directory that receives the CSV files.
    """
    # NOTE: WND is read from the input as-is; the former derivation from
    # U10/V10 inside this function is disabled.

    # assign() returns a new Dataset, so the caller's object is left untouched.
    ds = ds.assign(
        tmx7=ds.T2.rolling(times=7).mean(),
        tmx60=ds.T2.rolling(times=60).mean(),
    )
    print(ds)

    # Grid cells whose T2 does not sum to a positive value are excluded.
    nonzero = ds.T2.sum(dim='times') > 0

    # Loop-invariant pieces, computed once instead of once per threshold pair.
    month = ds.times.dt.month
    season = (month >= 5) & (month <= 9)
    wnd = ds[['WND']]  # only WND is counted, so only WND needs masking
    rolling_means = ((7, ds.tmx7), (60, ds.tmx60))

    # -99 presumably stands for "no threshold" -- TODO confirm.
    for wind_thld in [-99, 4, 5, 6, 7, 8, 9, 10, 11]:
        wind_ok = ds.WND >= wind_thld
        for tmax_thld in [-99, 17, 18, 24]:
            if wind_thld == -99 and tmax_thld == -99:
                continue
            print('\tWND: ', wind_thld, " tmax: ", tmax_thld)

            for window, tmean in rolling_means:
                mask = wind_ok & (tmean >= tmax_thld) & season
                _write_count_csv(
                    wnd, mask, nonzero, gcm, odir, wind_thld, window, tmax_thld
                )
            
# Create table
def create_table(idir, ofile):
    """Merge every ``*.csv`` file in ``idir`` into the single CSV ``ofile``.

    Files are read in sorted path order and stacked row-wise. The result is
    written without the index, with floats formatted as ``%0.5f``. If no CSV
    file is found, an empty table is still written.

    Parameters
    ----------
    idir : str
        Directory searched (non-recursively) for ``*.csv`` files.
    ofile : str
        Path of the combined CSV file to write.
    """
    print("\tCreating table")
    paths = sorted(glob('{}/*.csv'.format(idir)))

    # Collect first and concatenate once: growing a DataFrame inside the loop
    # re-copies all previously read rows on every iteration.
    frames = []
    for path in paths:
        print(path)
        frames.append(pd.read_csv(path))

    # pd.concat() rejects an empty list, so fall back to an empty frame.
    df = pd.concat(frames) if frames else pd.DataFrame()

    df.to_csv(ofile, index=False, float_format='%0.5f')
        

def mk_csv():
    """Run ``create_csv`` for every run listed in ``gcms``.

    Each run is read from ``<idir>/<gcm>.nc`` and its CSV files are written
    to ``odir``; ``gcms``, ``idir`` and ``odir`` are the module-level settings.
    """
    for gcm in gcms:
        print("\nWorking on : ", gcm)

        ifile = "{}/{}.nc".format(idir, gcm)
        # Close each file as soon as its CSVs are written so that open
        # handles do not accumulate over the whole list of runs.
        with xr.open_dataset(ifile) as ds:
            create_csv(ds, gcm, odir)

# Stage 1: write the per-run, per-threshold CSV files into odir.
# NOTE(review): these statements run at module level, so they also execute
# whenever this file is imported.
mk_csv()

# Stage 2: merge every CSV found in odir into one table in the current
# working directory. Must run after stage 1 so the files exist.
ofile = "./bcdata_wndf.csv"
create_table(odir, ofile)
