import pandas as pd
from glob import glob
import seaborn as sns
import numpy as np


# Air-temperature input files in the current directory. glob() returns names in
# file-system order, which is not guaranteed, so sort for a reproducible run order.
data = sorted(glob('*Air_CLIP.csv'))

def jd_plot(df, out, title=None):
    """Save a hexbin joint plot of air vs. stream temperature with KDE contours.

    The joint panel is a hexbin of ``air_temp`` (x) against ``stream_temp`` (y),
    overlaid with red KDE contours at 5 levels. The figure is saved at 300 dpi.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``air_temp`` and ``stream_temp`` columns. Rows missing one
        of the two values are still passed to seaborn, but nothing is plotted
        unless at least one row has both values.
    out : str
        Output image path, passed to ``JointGrid.savefig``.
    title : str, optional
        Figure title. Defaults to the file name of *out* without its
        extension, with underscores replaced by spaces.

    Returns
    -------
    None
    """
    print(df)
    # A joint plot (and its KDE overlay) needs paired observations. An outer
    # join can yield rows where one side is NaN, so count complete pairs
    # rather than raw rows.
    if df[['air_temp', 'stream_temp']].dropna().empty:
        return

    # Imported here so the plotting dependencies are only needed when a
    # figure is actually drawn.
    from pathlib import Path
    import matplotlib.pyplot as plt

    if title is None:
        title = ' '.join(Path(out).stem.split('_'))

    g = sns.jointplot(data=df, x='air_temp', y='stream_temp', kind='hex')
    try:
        g.plot_joint(sns.kdeplot, color='r', levels=5)
        g.fig.suptitle(title)
        g.savefig(out, dpi=300)
    finally:
        # jointplot opens a new figure on every call. Close it so a loop over
        # many sites/seasons does not accumulate open figures.
        plt.close(g.fig)
    

# Seasons plotted separately, keyed by the suffix added to the output name.
SEASONS = {
    'DJF': (12, 1, 2),  # Dec-Jan-Feb: spans the year boundary
    'JAS': (7, 8, 9),   # Jul-Aug-Sep
}

for d in data:
    print(d)

    # Air temperature: headerless CSV whose first column is discarded, followed
    # by a "%m/%d/%Y %H:%M" timestamp and the temperature value.
    df_air = pd.read_csv(d, header=None)
    df_air.columns = ['x', 'date', 'air_temp']
    df_air.index = pd.to_datetime(df_air.date, format='%m/%d/%Y %H:%M')
    df_air = df_air[['air_temp']]

    # "<prefix>_<rest>_Air_CLIP.csv" -> lbl "<rest>.csv". This label names both
    # the matching stream file and the output images.
    name_parts = d.replace('_Air_CLIP.csv', '.csv').split('_')
    lbl = '_'.join(name_parts[1:])

    # Stream temperature: headerless CSV of Year, Month, Day, Hour, Minute, value.
    stream_path = '../CLEAN_' + lbl
    df_stream = pd.read_csv(stream_path, header=None)
    df_stream.columns = ['Year', 'Month', 'Day', 'Hour', 'Minute', 'stream_temp']
    df_stream.index = pd.to_datetime(df_stream[df_stream.columns[:-1]])
    df_stream = df_stream[['stream_temp']]

    # Outer join keeps timestamps present in only one series (NaN on the other).
    df = df_air.join(df_stream, how='outer')
    jd_plot(df, lbl.replace('.csv', '.png'))

    # Per-season plots. Membership tests are used rather than a >=/<= range so
    # that DJF, which wraps from December to February, is selected correctly.
    for season, months in SEASONS.items():
        season_df = df[df.index.month.isin(months)]
        jd_plot(season_df, lbl.replace('.csv', f'_{season}.png'))
