import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from pyart.graph import cm
import metpy.calc as mpcalc
from metpy.units import units
import cmaps
from datetime import datetime
from scipy.ndimage import gaussian_filter
from metpy.plots import ctables
import matplotlib.colors as mcolors
import argparse

# Register the custom colour tables with MetPy's colortable registry. The .tbl
# files are opened relative to the current working directory; the `with`
# blocks make sure the file handles are closed once they have been read.
with open('dewpoint2.tbl', 'rt') as tbl_file:
    ctables.registry.add_colortable(tbl_file, 'td')
td_cmap = ctables.registry.get_colortable('td')

# MLCAPE contour boundaries [J/kg]; also used for the CAPE colorbar ticks.
cape_levels = [0, 100, 200, 350, 500, 750, 1000, 1300, 1600, 2000, 2400, 2800, 3300, 3800, 4400, 5000]
with open('marco_cape.tbl', 'rt') as tbl_file:
    ctables.registry.add_colortable(tbl_file, 'cape')
cape_norm, cape_cmap = ctables.registry.get_with_boundaries('cape', cape_levels)



def plot_background(lon_0, lon_f, lat_0, lat_f):
    """Create a 12x10 inch map figure with borders, states and coastlines.

    The map uses a Lambert cylindrical projection clipped to the box
    ``[lon_0, lon_f] x [lat_0, lat_f]``.

    Returns
    -------
    (fig, ax)
        The matplotlib figure and its single cartopy GeoAxes.
    """
    projection = ccrs.LambertCylindrical()
    fig, ax = plt.subplots(
        1, 1,
        figsize=[12, 10],
        subplot_kw={'projection': projection},
        constrained_layout=True,
    )
    ax.set_extent([lon_0, lon_f, lat_0, lat_f])

    # Each overlay uses the 10m Natural Earth scale; the ocean layer only
    # contributes its outline (no fill) and sits above the data (zorder 2).
    overlays = (
        (cfeature.BORDERS, {'edgecolor': 'k'}),
        (cfeature.STATES, {'edgecolor': 'gray', 'linewidth': 0.5}),
        (cfeature.OCEAN, {'facecolor': 'none', 'edgecolor': 'k', 'zorder': 2}),
    )
    for feature, style in overlays:
        ax.add_feature(feature.with_scale('10m'), **style)

    return fig, ax

def bwd(ui, vi, uf, vf):
    """Return the bulk wind difference magnitude between two levels.

    ``ui``/``vi`` are the wind components at the lower level and ``uf``/``vf``
    at the upper level, all as unitless values in m/s. The result is whatever
    ``mpcalc.wind_speed`` returns for the component differences, tagged m/s.
    """
    delta_u = (uf - ui) * units('m/s')
    delta_v = (vf - vi) * units('m/s')
    return mpcalc.wind_speed(delta_u, delta_v)

def calc_sblcl(t, td):
    """Approximate the surface-based LCL height as ``125 * (t - td)``.

    ``t`` and ``td`` must share units (callers in this file pass degC
    quantities); the result scales the dewpoint depression by 125, which is
    plotted in metres.
    """
    dewpoint_depression = t - td
    return 125 * dewpoint_depression

def scp_fix(sbcape, srh03, bwd06):
    """Fixed-layer Supercell Composite Parameter.

    Product of three normalised terms: surface-based CAPE / 1000,
    0-3 km SRH / 150 and 0-6 km bulk wind difference / 20.
    """
    cape_term = sbcape / 1000
    srh_term = srh03 / 150
    shear_term = bwd06 / 20
    return cape_term * srh_term * shear_term

def stp_fix(sbcape, bwd06, srh01, sblcl, sbcin):
    """Fixed-layer Significant Tornado Parameter.

    Product of five normalised terms: surface-based CAPE / 1000,
    0-6 km bulk wind difference / 20, 0-1 km SRH / 100, the LCL term
    ``(2000 - sblcl) / 1500`` and the CIN term ``(150 - sbcin) / 125``.

    The LCL and CIN terms are clipped at 0. Without the clip, an LCL above
    2000 m or CIN above 150 makes the term negative. Two such negative terms
    multiply to a positive factor, and the product then shows a spurious
    tornado signal.

    NOTE(review): assumes ``sbcin`` is a positive CIN magnitude [J/kg] —
    confirm the sign convention of the MPAS field.
    """
    cape_term = sbcape / 1000
    shear_term = bwd06 / 20
    srh_term = srh01 / 100
    lcl_term = np.maximum((2000 - sblcl) / 1500, 0)
    cin_term = np.maximum((150 - sbcin) / 125, 0)
    return cape_term * shear_term * srh_term * lcl_term * cin_term

def _as_datetime(value):
    """Return ``value`` (a numpy datetime64) as a naive ``datetime`` truncated to whole seconds."""
    return datetime.strptime(value.astype('datetime64[s]').astype('str'), '%Y-%m-%dT%H:%M:%S')


def _wind_speed_kt(u, v):
    """Return the wind speed [kt] of components ``u``/``v`` given in m/s."""
    return (mpcalc.wind_speed(u * units('m/s'), v * units('m/s')).values * units('m/s')).to('kt')


def _add_barbs(ax, lons, lats, u, v, skip=40, **barb_kw):
    """Draw every ``skip``-th wind barb from plain m/s arrays, converted to knots.

    ``lons``/``lats`` are geographic coordinates, so the barbs are always drawn
    with a PlateCarree transform.
    """
    u_kt = (u * units('m/s')).to('kt')
    v_kt = (v * units('m/s')).to('kt')
    ax.barbs(lons[::skip, ::skip], lats[::skip, ::skip],
             u_kt[::skip, ::skip], v_kt[::skip, ::skip],
             length=6, flip_barb=True, zorder=4, transform=ccrs.PlateCarree(), **barb_kw)


def _finish_plot(fig, ax, left_title, dt0, dt, out_file, dpi=300):
    """Add the standard left/right titles to ``ax``, save ``fig`` to ``out_file`` and close it."""
    ax.set_title(left_title + '\n' + 'MPAS 3 km', loc='left')
    ax.set_title(f'Iniciado: {dt0:%Y-%m-%d %H:%M:%S}' + '\n' + f'Valido: {dt:%Y-%m-%d %H:%M:%S}', loc='right')
    fig.savefig(out_file, bbox_inches='tight', dpi=dpi)
    plt.close(fig)


def plot_all(dir_files, dir_figures, lon_0, lon_f, lat_0, lat_f):
    """Render every diagnostic map for each time step of an MPAS run.

    Parameters
    ----------
    dir_files : str
        Directory holding the ``*.nc`` diagnostic files. They are opened with
        ``xr.open_mfdataset`` and concatenated along ``Time``.
    dir_figures : str
        Output directory. One PNG per product and time step is written as
        ``<prefix>_<time>.png``. The prefixes are ref, cape, srh03, srh01,
        uv250, uv500, uv850, t2m, td2m, sblcl, scp, stp and rainc.
    lon_0, lon_f, lat_0, lat_f : float
        Map extent in degrees, forwarded to :func:`plot_background`.
    """
    # All fields are on a lon/lat grid, so every draw call needs this transform.
    pc = ccrs.PlateCarree()
    cbar_kw = dict(orientation='vertical', aspect=30, shrink=0.8, pad=0.05)
    mslp_levels = np.arange(900, 1073, 2)
    rain_levels = [0, 2.5, 5, 12.5, 20, 30, 40, 50, 65, 80, 100, 130, 160, 200, 250, 500]
    # (pressure level [hPa], filled wind-speed levels [kt], height contours [dm], colorbar ticks [kt])
    upper_air = (
        (250, np.arange(40, 162.5, 2.5), np.arange(970, 1300, 8), np.arange(40, 170, 10)),
        (500, np.arange(20, 122.5, 2.5), np.arange(400, 600, 4), np.arange(20, 130, 10)),
        (850, np.arange(10, 82.5, 2.5), np.arange(100, 500, 2), np.arange(10, 85, 5)),
    )

    # The context manager closes the underlying files once all plots are done.
    with xr.open_mfdataset(dir_files + '/*.nc', concat_dim='Time', combine='nested') as ds:
        lons, lats = np.meshgrid(ds.longitude.values, ds.latitude.values)
        for i in ds.Time.values:
            ds_index = ds.sel(Time=i)
            dt0 = _as_datetime(ds_index.Time_Start.values[0])
            dt = _as_datetime(i)

            # Bulk wind differences relative to the 10 m wind [m/s].
            u10 = ds_index.u10.values
            v10 = ds_index.v10.values
            ubwd0_6 = ds_index.uzonal_6km.values - u10
            vbwd0_6 = ds_index.umeridional_6km.values - v10
            ubwd0_1 = ds_index.uzonal_1km.values - u10
            vbwd0_1 = ds_index.umeridional_1km.values - v10

            t_sfc = (ds_index.temperature_surface.values * units.K).to('degC')
            td_sfc = (ds_index.dewpoint_surface.values * units.K).to('degC')
            sblcl = calc_sblcl(t_sfc, td_sfc)
            mslp_hpa = (ds_index.mslp.values * units('Pa')).to('hPa')
            # bwd() expects (u_low, v_low, u_high, v_high).
            bwd06 = bwd(ds_index.uzonal_surface, ds_index.umeridional_surface,
                        ds_index.uzonal_6km, ds_index.umeridional_6km)

            # Composite reflectivity + updraft helicity.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            ref = ax.contourf(lons, lats, ds_index.refl10cm_max, cmap=cm.LangRainbow12,
                              levels=np.arange(0, 77.5, 2.5), extend='max', transform=pc)
            # NOTE(review): the title advertises UH < -50 but the contour is drawn
            # at +75 — confirm the intended threshold.
            ax.contour(lons, lats, ds_index.updraft_helicity_min, colors=['purple'],
                       levels=[75], linestyles='-', transform=pc)
            fig.colorbar(ref, ticks=np.arange(0, 80, 5), label='[dBZ]', **cbar_kw)
            _finish_plot(fig, ax, 'Composite Reflectivity [dBZ] | Minimum Updraft Helicity < -50 [m$^{2}$ s$^{-2}$]',
                         dt0, dt, f'{dir_figures}/ref_{i}.png')

            # MLCAPE + MLCIN hatching + 0-6 km shear.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            cape = ax.contourf(lons, lats, ds_index.mlcape.values, levels=cape_levels, cmap=cape_cmap,
                               norm=cape_norm, extend='max', transform=pc)
            # NOTE(review): hatching starts at 25 while the title says < -30; MLCIN
            # appears to be stored as a positive magnitude — confirm.
            ax.contourf(lons, lats, ds_index.mlcin, levels=[25, 50, 75, 100, 125, 150], cmap='Greys',
                        hatches=['--'], alpha=0.5, extend='max', transform=pc)
            _add_barbs(ax, lons, lats, ubwd0_6, vbwd0_6, barbcolor='darkblue')
            fig.colorbar(cape, ticks=cape_levels, label='[J/kg]', **cbar_kw)
            _finish_plot(fig, ax, 'MLCAPE [J kg$^{-1}$] | MLCIN < -30 [J kg$^{-1}$] | 0-6 km BWD [kt]',
                         dt0, dt, f'{dir_figures}/cape_{i}.png')

            # 0-3 km SRH + 0-6 km shear.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            srh0_3 = ax.contourf(lons, lats, ds_index.srh_0_3km, cmap=cmaps.GMT_seis,
                                 levels=np.arange(-700, -45, 5), extend='min', transform=pc)
            _add_barbs(ax, lons, lats, ubwd0_6, vbwd0_6)
            fig.colorbar(srh0_3, label='[m²/s²]', ticks=np.arange(-700, 0, 50), **cbar_kw)
            _finish_plot(fig, ax, '0-3 km SRH [m$^{2}$ s$^{-2}$] | 0-6 km BWD [kt]',
                         dt0, dt, f'{dir_figures}/srh03_{i}.png')

            # 0-1 km SRH + 0-1 km shear.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            srh0_1 = ax.contourf(lons, lats, ds_index.srh_0_1km, cmap=cmaps.GMT_seis,
                                 levels=np.arange(-500, -45, 5), extend='min', transform=pc)
            _add_barbs(ax, lons, lats, ubwd0_1, vbwd0_1)
            fig.colorbar(srh0_1, label='[m²/s²]', **cbar_kw)
            _finish_plot(fig, ax, '0-1 km SRH [m$^{2}$ s$^{-2}$] | 0-1 km BWD [kt]',
                         dt0, dt, f'{dir_figures}/srh01_{i}.png')

            # Upper-air wind speed, geopotential height and barbs.
            for level, speed_levels, hgt_levels, speed_ticks in upper_air:
                u = ds_index[f'uzonal_{level}hPa']
                v = ds_index[f'umeridional_{level}hPa']
                fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
                speed = ax.contourf(lons, lats, _wind_speed_kt(u, v), levels=speed_levels,
                                    cmap=cmaps.cmp_haxby, extend='max', transform=pc)
                hgt = ax.contour(lons, lats, ds_index[f'height_{level}hPa'] / 10, levels=hgt_levels,
                                 colors='k', transform=pc)
                ax.clabel(hgt, inline=True)
                _add_barbs(ax, lons, lats, u.values, v.values)
                fig.colorbar(speed, ticks=speed_ticks, label='[kt]', **cbar_kw)
                # dpi=None keeps Matplotlib's default savefig dpi for these maps.
                _finish_plot(fig, ax, f'{level} hPa Wind [kt] | Geopotential Heights [dm]',
                             dt0, dt, f'{dir_figures}/uv{level}_{i}.png', dpi=None)

            # Surface temperature / dewpoint with MSLP and surface winds.
            surface_maps = (
                ('t2m', t_sfc, cmaps.NCV_bright, np.arange(-20, 41, 1.), np.arange(-20, 45, 5), '2 m Temperature'),
                ('td2m', td_sfc, td_cmap, np.arange(-40, 31.25, 0.55), np.arange(-40, 35, 5), '2 m Dewpoint'),
            )
            for prefix, field, cmap, levels, ticks, name in surface_maps:
                fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
                fill = ax.contourf(lons, lats, field, cmap=cmap, levels=levels, extend='both', transform=pc)
                isobars = ax.contour(lons, lats, mslp_hpa, levels=mslp_levels, colors='k', transform=pc)
                ax.clabel(isobars, inline=True)
                _add_barbs(ax, lons, lats, ds_index.uzonal_surface.values, ds_index.umeridional_surface.values)
                fig.colorbar(fill, label='[°C]', ticks=ticks, **cbar_kw)
                _finish_plot(fig, ax, f'{name} [°C] | MSLP [hPa] | 10 m Winds [kt]',
                             dt0, dt, f'{dir_figures}/{prefix}_{i}.png')

            # Surface-based LCL + 0-1 km shear.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            lcl = ax.contourf(lons, lats, sblcl, transform=pc, levels=np.arange(0, 5100, 100),
                              cmap=cmaps.WhiteBlueGreenYellowRed, extend='max')
            _add_barbs(ax, lons, lats, ubwd0_1, vbwd0_1)
            fig.colorbar(lcl, ticks=np.arange(0, 5100, 500), **cbar_kw)
            _finish_plot(fig, ax, r'SBLCL [m] | BWD 0-1 km [kt]', dt0, dt, f'{dir_figures}/sblcl_{i}.png')

            # Supercell Composite Parameter (negative values: Southern Hemisphere SRH sign).
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            scp_fill = ax.contourf(lons, lats, scp_fix(ds_index.sbcape, ds_index.srh_0_3km, bwd06),
                                   levels=np.arange(-20, 0.5, 0.5), transform=pc,
                                   cmap=cmaps.GMT_wysiwygcont, extend='min')
            fig.colorbar(scp_fill, ticks=np.arange(-20, 1, 1), **cbar_kw)
            _finish_plot(fig, ax, r'Supercell Composite Parameter (FIX)', dt0, dt, f'{dir_figures}/scp_{i}.png')

            # Significant Tornado Parameter.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            stp_fill = ax.contourf(lons, lats,
                                   stp_fix(ds_index.sbcape, bwd06, ds_index.srh_0_1km,
                                           sblcl.magnitude, ds_index.sbcin),
                                   levels=np.arange(-10, -0.25, 0.25), transform=pc,
                                   cmap=cmaps.GMT_wysiwygcont, extend='min')
            fig.colorbar(stp_fill, ticks=np.arange(-10, 1, 1), **cbar_kw)
            _finish_plot(fig, ax, r'Significant Tornado Parameter (FIX)', dt0, dt, f'{dir_figures}/stp_{i}.png')

            # Accumulated grid-scale precipitation.
            fig, ax = plot_background(lon_0, lon_f, lat_0, lat_f)
            rain = ax.contourf(lons, lats, ds_index.rainnc, levels=rain_levels, transform=pc,
                               cmap=cmaps.prcp_1,
                               norm=mcolors.BoundaryNorm(rain_levels, len(rain_levels) - 1),
                               extend='max')
            fig.colorbar(rain, ticks=rain_levels, **cbar_kw)
            _finish_plot(fig, ax, r'Rain (mm)', dt0, dt, f'{dir_figures}/rainc_{i}.png')

def main():
    """Parse the command line and render all MPAS plots for the requested box.

    Each option is declared with ``nargs='+'`` (kept for backward
    compatibility); only the first value given is used.
    """
    parser = argparse.ArgumentParser(description='Plot MPAS output')
    parser.add_argument('--dir_diag', metavar='DIR', nargs='+', required=True, help='path to diag files')
    parser.add_argument('--dir_fig', metavar='DIR', nargs='+', required=True, help='path to save plots')
    parser.add_argument('--lon_0', metavar='LON0', type=float, nargs='+', required=True, help='start longitude of box')
    parser.add_argument('--lon_f', metavar='LONF', type=float, nargs='+', required=True, help='end longitude of box')
    parser.add_argument('--lat_0', metavar='LAT0', type=float, nargs='+', required=True, help='start latitude of box')
    parser.add_argument('--lat_f', metavar='LATF', type=float, nargs='+', required=True, help='end latitude of box')
    args = parser.parse_args()

    # type=float already converts the coordinates; no further casting needed.
    plot_all(args.dir_diag[0], args.dir_fig[0], args.lon_0[0], args.lon_f[0], args.lat_0[0], args.lat_f[0])


if __name__ == '__main__':
    main()

