import os
import sys

import numpy as np
import xarray as xr
import xesmf as xe


# Default destination grid; can be overridden by a second command-line argument.
TARGET_GRID = '/home/disk/rocinante/DATA/temp/TNC_stormwater/pnnl/pnnl_1981-2020.nc'

# Input dimension/variable names mapped to the names the transposes below expect.
XY_TO_LONLAT = {'x': 'lon', 'y': 'lat'}


def interp_output_path(ifile):
    """Return the output path for *ifile*: its stem with ``_interp`` appended.

    Only the final extension is touched, so ``'/d/run.nc/a.nc'`` becomes
    ``'/d/run.nc/a_interp.nc'`` and ``'b.nc4'`` becomes ``'b_interp.nc4'``.
    A path with no extension gets ``.nc``, so the result never equals *ifile*
    (writing over the input that is still open would corrupt it).
    """
    root, ext = os.path.splitext(ifile)
    return f"{root}_interp{ext or '.nc'}"


def rename_xy_to_lonlat(ds):
    """Rename ``x``/``y`` to ``lon``/``lat`` in *ds* where that is possible.

    A name is renamed only if it exists (as a variable or dimension) and its
    new name does not, so datasets already using lat/lon pass through
    unchanged instead of raising.
    """
    def present(name):
        return name in ds.variables or name in ds.dims

    renames = {old: new for old, new in XY_TO_LONLAT.items()
               if present(old) and not present(new)}
    return ds.rename(renames) if renames else ds


def main(argv=None):
    """Regrid PREC from the input file onto the target grid and write it out.

    argv[1]: input NetCDF containing ``PREC``. After renaming x/y, its dims
             are transposed to (lat, lon), or (quantile, lat, lon) when PREC
             has a ``quantile`` dimension. NOTE(review): Dataset.transpose
             without ``...`` presumably rejects any extra dims -- confirm.
    argv[2]: optional target-grid NetCDF (default ``TARGET_GRID``) whose
             ``PREC`` has ``time``, ``y`` and ``x`` dimensions.

    Writes only the regridded ``PREC`` to ``interp_output_path(argv[1])`` and
    prints that path.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit(f"usage: {argv[0] if argv else 'script'} INPUT.nc [TARGET_GRID.nc]")
    ifile = argv[1]
    target = argv[2] if len(argv) > 2 else TARGET_GRID

    # Context managers close both files once the output has been written.
    with xr.open_dataset(ifile) as src, xr.open_dataset(target) as tgt:
        ds = rename_xy_to_lonlat(src)

        # 'mask' is 1 where PREC is not NaN, else 0. xESMF uses a data
        # variable named 'mask', if present, to exclude cells (see xESMF
        # masking docs). Quantile inputs use the first quantile's footprint.
        prec = ds['PREC']
        if 'quantile' in prec.dims:
            ds['mask'] = xr.where(~np.isnan(prec.isel(quantile=0)), 1, 0)
            ds = ds.transpose('quantile', 'lat', 'lon')
        else:
            ds['mask'] = xr.where(~np.isnan(prec), 1, 0)
            ds = ds.transpose('lat', 'lon')

        # The target mask comes from the first time step of the target PREC.
        ps = tgt
        ps['mask'] = xr.where(~np.isnan(ps['PREC'].isel(time=0)), 1, 0)
        ps = ps.transpose('time', 'y', 'x')

        grid = xe.Regridder(ds, ps, 'bilinear')
        # skipna/na_thres: presumably an output cell becomes NaN when more than
        # 25% of its contributing input weight is NaN -- confirm in xESMF docs.
        dr = grid(ds, skipna=True, na_thres=0.25)[['PREC']]

        out = interp_output_path(ifile)
        print(out)
        dr.to_netcdf(out)


if __name__ == '__main__':
    main()


# Directory containing the PNNL target-grid file (pnnl_1981-2020.nc): /home/disk/rocinante/DATA/temp/TNC_stormwater/pnnl

