# -*- coding: utf-8 -*-

import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
import matplotlib as mpl
from matplotlib.colors import BoundaryNorm
import cartopy.io.shapereader as shpreader
import os
import pandas as pd
import warnings
import unicodedata
import re

# Silence every Python warning process-wide (matplotlib, cartopy, xarray, ...),
# presumably so the console shows only this script's own progress messages.
# NOTE(review): this also hides deprecation/runtime warnings that may matter.
warnings.simplefilter("ignore")

def remover_acentos(texto):
    """Return *texto* with all diacritical (combining) marks removed.

    The text is first NFKD-normalised, which splits accented letters into a
    base letter followed by combining marks and also applies compatibility
    decompositions (e.g. superscript two becomes "2"). Every combining mark is
    then discarded and the remaining characters are joined back together.
    """
    decomposto = unicodedata.normalize('NFKD', texto)
    sem_marcas = filter(lambda ch: not unicodedata.combining(ch), decomposto)
    return u"".join(sem_marcas)

# Input and output locations for the run handled by this script. The basename
# of run_directory_2d ("2025031500") is what process_run_for_monthly_averages()
# uses as the run identifier and as the output sub-folder name.
run_directory_2d = r"/dados/grpeta/S2S/season/forecast/3hr_2d/raw/2025031500"
run_directory_3d = r"/dados/grpeta/S2S/season/forecast/6hr_3d/raw/2025031500"
path_shapefile = r"/dados/grpeta/Team/DiegoChagas/WORKETAVIII/RECURSOS/SHAPES/contornos_brasil/estados_2010.shp"
path_figuras = r"/dados/grpeta/Team/DiegoChagas/WORKETAVIII/figuras_mensais"

# PALETTES AND LEVELS
#
# 29-colour sequence shared by the precipitation, sea-level-pressure and
# geopotential palettes. Each ListedColormap below receives its own list copy,
# so the three colormaps remain fully independent objects.
_cores_sequencia_29 = (
    '#cccccc', '#b3b3b3', '#999999', '#80807e',
    '#00007c', '#0000cb', '#010afd', '#045af3', '#6297fd', '#a1bcff',
    '#054e03', '#006600', '#007501', '#058405', '#009700', '#00b000',
    '#00c800', '#00dd00', '#01ff01', '#fcff4d', '#ffd263', '#ffb463',
    '#fd9846', '#ff6e3b', '#ff511a', '#fd1500', '#dd0a03', '#c60200',
    '#b80000',
)

# --- Precipitation ----------------------------------------------------------
cmap_prec = mpl.colors.ListedColormap(list(_cores_sequencia_29))
cmap_prec.set_under('#fefefe')
cmap_prec.set_over('#7c0104')

# Levels for MEAN precipitation (mm): 30 boundaries -> 29 intervals.
levels_prec_media = [
    0.2, 0.4, 0.6, 0.8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
    12, 14, 16, 18, 20, 25, 30, 40, 50, 60, 70, 80, 90, 100, 125, 150,
]
# Levels for ACCUMULATED precipitation (mm/month): 30 boundaries -> 29 intervals.
levels_prec_acc = (
    [5, 10, 15, 20]
    + list(range(30, 101, 10))
    + list(range(120, 301, 20))
    + list(range(350, 701, 50))
)

# --- 2 m temperature (also used for min/max temperature) --------------------
cmap_tp2m = mpl.colors.ListedColormap([
    '#326799', '#3F70AC', '#4D85BE', '#5CA2D3',
    '#6FC0E3', '#80D3E5', '#9CE4E8', '#CCF5F0',
    '#F7FEE9', '#FCEEA3', '#FBD058', '#F0B33A',
    '#E08F22', '#CC6E18', '#B44F14', '#973511',
    '#7D200E',
])
cmap_tp2m.set_under('#255C92')
cmap_tp2m.set_over('#5C0E0B')
levels_tp2m = np.arange(4, 40, 2)

# --- OCIS -------------------------------------------------------------------
levels_ocis = list(range(50, 401, 50))
colors_ocis_custom = [
    '#ffff66', '#ffd700', '#ffbf00', '#ff8c00', '#ff4500',
    '#cc0000', '#a30000', '#800000', '#660000',
]
cmap_ocis = mpl.colors.ListedColormap(colors_ocis_custom)
cmap_ocis.set_under('#ffff66')
cmap_ocis.set_over('#4d0000')

# --- 10 m wind speed --------------------------------------------------------
# The first and last entries are reserved for the under/over extensions; the
# seven in between fill the seven intervals of levels_vento10m.
colors_vento10m = [
    '#FFFFFF', '#FFFFB3', '#FFE066', '#FFC200',
    '#FFA500', '#FF8C00', '#FF4500', '#DC143C',
    '#A02020',
]
levels_vento10m = np.arange(2, 18, 2)
cmap_vento10m = mpl.colors.ListedColormap(colors_vento10m[1:8])
cmap_vento10m.set_under(colors_vento10m[0])
cmap_vento10m.set_over(colors_vento10m[-1])

# --- Sensible / latent heat fluxes ------------------------------------------
levels_cssf = np.arange(-200, 140, 20)
levels_clsf = np.arange(-200, 140, 20)
colors_fluxos_distinct = [
    '#b866be', '#cf99d4', '#272ae6', '#6d89df', '#4368D6',
    '#00BFFF', '#00CED1', '#20B2AA', '#32CD32', '#ADFF2F',
    '#FFD700', '#FFC107', '#FF8C00', '#FF4500', '#DC143C',
    '#B22222',
]
cmap_fluxos = mpl.colors.ListedColormap(colors_fluxos_distinct)
cmap_fluxos.set_under('#880093')
cmap_fluxos.set_over('#8B0000')

# --- Sea-level pressure and geopotential height (contour lines) -------------
levels_pslm = np.arange(990, 1020, 1)
cmap_pslm = mpl.colors.ListedColormap(list(_cores_sequencia_29))
cmap_zgeo = mpl.colors.ListedColormap(list(_cores_sequencia_29))

# --- Pressure-level temperature (continuous colormap) -----------------------
colors_temp_suave = [
    '#0000CD', '#1E90FF', '#6495ED', '#ADD8E6', '#F0F8FF',
    '#FFFACD', '#FFD700', '#FF8C00', '#FF4500', '#DC143C',
    '#8B0000',
]
cmap_temp = mpl.colors.LinearSegmentedColormap.from_list(
    "temp_custom_suave", colors_temp_suave, N=256
)
cmap_temp.set_under(colors_temp_suave[0])
cmap_temp.set_over(colors_temp_suave[-1])

def plot_monthly_variable(data_array, var_name, month_label, cmap, levels, extend, unit_cbar, output_dir, fig_name, title_prefix, scale_correction=0, plot_type='contourf'):
    """Draw one monthly 2-D field on a PlateCarree map and save it as an image.

    Parameters
    ----------
    data_array : xarray.DataArray
        Field with ``lat`` and ``lon`` coordinates. ``data_array.values`` is
        handed directly to matplotlib, so it is assumed to be ordered
        (lat, lon) -- TODO confirm against the input files.
    var_name : str
        Variable key. ``'prec'`` is multiplied by 1000 and ``'cssf'`` /
        ``'clsf'`` are sign-flipped. For any other variable,
        ``scale_correction`` is subtracted when it is non-zero; this file's
        callers pass 273.15 for the Kelvin fields they label in degC.
    month_label : str
        Text appended to the map title.
    cmap : matplotlib.colors.Colormap
        Palette for the filled contours or the contour lines.
    levels : sequence of numbers
        Contour boundaries. They are also used as colour-bar ticks.
    extend : str
        matplotlib ``extend`` value ('neither', 'min', 'max' or 'both').
    unit_cbar : str
        Unit shown in parentheses in the title.
    output_dir, fig_name : str
        The image is written to ``os.path.join(output_dir, fig_name)``.
    title_prefix : str
        Leading title text; accents are stripped with ``remover_acentos``.
    scale_correction : float, optional
        Offset subtracted from the data. It is ignored for prec, cssf and clsf.
    plot_type : str, optional
        ``'contourf'`` draws filled contours plus a horizontal colour bar.
        Any other value draws labelled contour lines without a colour bar.
    """
    if var_name == 'prec':
        data = data_array * 1000
    elif var_name in ('cssf', 'clsf'):
        data = data_array * -1
    elif scale_correction != 0:
        data = data_array - scale_correction
    else:
        data = data_array

    def _rotulo(nivel):
        # Show whole-number levels without a decimal part (5.0 -> "5").
        return str(int(nivel)) if nivel == int(nivel) else str(nivel)

    fig = plt.figure(figsize=(12, 12))
    try:
        ax = fig.add_subplot(111, projection=ccrs.PlateCarree())

        if plot_type == 'contourf':
            # BoundaryNorm maps each interval between consecutive levels onto
            # the colormap. The palette is therefore spread over the intervals
            # even when its colour count differs from len(levels) - 1.
            # BoundaryNorm raises if there are fewer colours than intervals.
            norm = mpl.colors.BoundaryNorm(levels, cmap.N)
            img = ax.contourf(
                data['lon'], data['lat'], data.values,
                levels=levels,
                cmap=cmap,
                norm=norm,
                extend=extend,
                transform=ccrs.PlateCarree()
            )
        else:
            # Listed colormaps expose their discrete colours as ``.colors``.
            # Continuous ones, such as LinearSegmentedColormap, do not, so use
            # ``cmap=`` for them instead of failing with AttributeError.
            if hasattr(cmap, 'colors'):
                estilo_cor = {'colors': cmap.colors}
            else:
                estilo_cor = {'cmap': cmap}
            img = ax.contour(
                data['lon'], data['lat'], data.values,
                levels=levels,
                linewidths=1.2,
                transform=ccrs.PlateCarree(),
                **estilo_cor
            )
            ax.clabel(img, inline=True, fontsize=10, fmt='%i', colors='black')

        title_map = f'{remover_acentos(title_prefix)} ({unit_cbar}) - {month_label}'
        ax.set_title(title_map, fontsize=20)

        shapefile_geom = list(shpreader.Reader(path_shapefile).geometries())
        ax.add_geometries(shapefile_geom, ccrs.PlateCarree(), edgecolor='black', facecolor='none', linewidth=0.5)

        if plot_type == 'contourf':
            cbar = plt.colorbar(
                img,
                ax=ax,
                pad=0.05,
                fraction=0.25 if var_name == 'prec' else 0.15,
                orientation='horizontal',
                extend=extend,
                spacing='uniform',
            )
            cbar.set_ticks(levels)
            rotulos = [_rotulo(l) for l in levels]
            # The accumulated-precipitation scale has many wide labels; rotate them.
            if var_name == 'prec' and 'acumulada' in remover_acentos(title_prefix.lower()):
                cbar.ax.set_xticklabels(rotulos, rotation=45, ha='right')
            else:
                cbar.ax.set_xticklabels(rotulos)

        gl = ax.gridlines(crs=ccrs.PlateCarree(), color='black', alpha=1.0,
                          linestyle='dotted', linewidth=0.50, draw_labels=True)
        gl.top_labels = False
        gl.right_labels = False
        gl.yformatter = LATITUDE_FORMATTER
        gl.xformatter = LONGITUDE_FORMATTER

        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS)

        fig_path = os.path.join(output_dir, fig_name)
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails, so
        # figures do not accumulate across the many calls of a run.
        plt.close(fig)

def plot_monthly_contour_variable(data_array, var_name, month_label, levels, unit_cbar, output_dir, fig_name, title_prefix, cmap=None):
    """Draw labelled contour lines of a monthly field and save the figure.

    Lines are coloured with *cmap* when it is truthy and drawn in black
    otherwise. Every line is labelled inline using the '%1.0f' format, and no
    colour bar is added. The map is decorated with the state shapefile,
    dotted gridlines, coastlines and borders. It is then written to
    ``os.path.join(output_dir, fig_name)`` and a progress message is printed.

    Parameters
    ----------
    data_array : xarray.DataArray
        Field with ``lat`` and ``lon`` coordinates. Its ``.values`` are passed
        to matplotlib unchanged; no unit conversion happens here.
    var_name : str
        Used only in the progress message.
    month_label : str
        Appended to the title and the progress message.
    levels : sequence of numbers
        Contour levels.
    unit_cbar : str
        Unit shown in parentheses in the title.
    output_dir, fig_name : str
        Destination directory and file name of the image.
    title_prefix : str
        Leading title text; accents are stripped with ``remover_acentos``.
    cmap : matplotlib.colors.Colormap, optional
        Line colormap; black lines are used when omitted.
    """
    plate = ccrs.PlateCarree()
    figura = plt.figure(figsize=(12, 12))
    eixo = figura.add_subplot(111, projection=plate)

    cor_linhas = {'cmap': cmap} if cmap else {'colors': 'black'}
    linhas = eixo.contour(
        data_array['lon'], data_array['lat'], data_array.values,
        levels=levels,
        linewidths=0.8,
        transform=plate,
        **cor_linhas
    )
    eixo.clabel(linhas, inline=True, fontsize=10, fmt='%1.0f')

    titulo = f'{remover_acentos(title_prefix)} ({unit_cbar}) - {month_label}'
    eixo.set_title(titulo, fontsize=20)

    estados = list(shpreader.Reader(path_shapefile).geometries())
    eixo.add_geometries(estados, plate, edgecolor='black', facecolor='none', linewidth=0.5)

    grade = eixo.gridlines(
        crs=plate, color='black', alpha=1.0,
        linestyle='dotted', linewidth=0.50, draw_labels=True,
    )
    grade.top_labels = False
    grade.right_labels = False
    grade.yformatter = LATITUDE_FORMATTER
    grade.xformatter = LONGITUDE_FORMATTER

    for feicao in (cfeature.COASTLINE, cfeature.BORDERS):
        eixo.add_feature(feicao)

    figura.savefig(os.path.join(output_dir, fig_name), bbox_inches='tight')
    plt.close(figura)
    print(f"Plot de contorno para {var_name} em {remover_acentos(month_label)} gerado.")

def plot_monthly_wind10m(u10m_mean, v10m_mean, month_label, cmap, levels, output_dir, label_rodada, month_label_pt_upper, year_num):
    """Map monthly-mean 10 m wind speed with overlaid wind vectors and save it.

    The speed ``sqrt(u**2 + v**2)`` is drawn as filled contours, with the
    extension at both ends. Arrows use every 10th ``lat`` and ``lon``
    coordinate and every 10th element along both array axes. The figure is
    saved as
    ``<label_rodada>_vento10m_media_mensal_<month>_<year_num>.png`` in
    *output_dir*.

    Parameters
    ----------
    u10m_mean, v10m_mean : xarray.DataArray
        Monthly-mean wind components with ``lat`` and ``lon`` coordinates.
        They are indexed as ``[::10, ::10]`` against 1-D lat/lon, so they are
        assumed to be (lat, lon) -- TODO confirm.
    month_label : str
        Appended to the title.
    cmap : matplotlib.colors.Colormap
        Palette for the speed field.
    levels : sequence of numbers
        Contour levels, also used as colour-bar ticks (labelled as ints).
    output_dir : str
        Destination directory.
    label_rodada : str
        Run label used as the file-name prefix.
    month_label_pt_upper : str
        Month name used in the file name; accents are stripped.
    year_num : int
        Year used in the file name.
    """
    plate = ccrs.PlateCarree()
    velocidade = np.sqrt(u10m_mean**2 + v10m_mean**2)

    figura = plt.figure(figsize=(12, 12))
    eixo = figura.add_subplot(111, projection=plate)

    preenchido = eixo.contourf(
        velocidade['lon'], velocidade['lat'], velocidade.values,
        cmap=cmap,
        levels=levels,
        extend='both',
        transform=plate
    )

    # Thin the vector field: keep one arrow every `passo` grid points.
    passo = 10
    lon_setas = u10m_mean['lon'].values[::passo]
    lat_setas = u10m_mean['lat'].values[::passo]
    u_setas = u10m_mean.values[::passo, ::passo]
    v_setas = v10m_mean.values[::passo, ::passo]
    eixo.quiver(
        lon_setas, lat_setas, u_setas, v_setas,
        transform=plate,
        color='black',
        scale=300,
        width=0.002,
        alpha=0.5
    )

    titulo = f'Media Mensal da Velocidade do Vento 10m (m/s) - {month_label}'
    eixo.set_title(remover_acentos(titulo), fontsize=20)

    estados = list(shpreader.Reader(path_shapefile).geometries())
    eixo.add_geometries(estados, plate, edgecolor='black', facecolor='none', linewidth=0.5)

    barra = plt.colorbar(
        preenchido,
        ax=eixo,
        pad=0.10,
        fraction=0.15,
        orientation='horizontal',
        extend='both',
        spacing='proportional',
    )
    barra.set_ticks(levels)
    barra.ax.set_xticklabels([str(int(nivel)) for nivel in levels])

    grade = eixo.gridlines(
        crs=plate, color='black', alpha=1.0,
        linestyle='dotted', linewidth=0.50, draw_labels=True,
    )
    grade.top_labels = False
    grade.right_labels = False
    grade.yformatter = LATITUDE_FORMATTER
    grade.xformatter = LONGITUDE_FORMATTER

    for feicao in (cfeature.COASTLINE, cfeature.BORDERS):
        eixo.add_feature(feicao)

    nome_figura = f"{label_rodada}_vento10m_media_mensal_{remover_acentos(month_label_pt_upper)}_{year_num}.png"
    figura.savefig(os.path.join(output_dir, nome_figura), bbox_inches='tight')
    plt.close(figura)

def _fechar_datasets(*colecoes):
    """Close every value in the given dicts that is not None and has ``close``.

    Parameters
    ----------
    *colecoes : dict
        Mappings whose values are open datasets or ``None``. They are visited
        in argument order and, within each mapping, in insertion order.
    """
    for colecao in colecoes:
        for ds_obj in colecao.values():
            if ds_obj is not None and hasattr(ds_obj, 'close'):
                ds_obj.close()


def process_run_for_monthly_averages(run_directory_2d, run_directory_3d):
    """Create monthly maps for every complete calendar month of one Eta run.

    The run identifier ``<run>`` is the basename of *run_directory_2d*.

    * Input files are read as ``Eta20_<run>_<var>.nc`` from the 2-D and 3-D
      directories.
    * The time axis of the first loaded dataset that has a ``time``
      coordinate decides which calendar months are complete.
    * For each complete month, PNG maps are saved under
      ``path_figuras/<run>``.

    The maps produced per month are:

    * 2-D fields: the monthly mean of each variable (``mntp`` and ``mxtp`` use
      the monthly minimum and maximum instead), plus monthly accumulated
      precipitation. ``pslm`` is drawn as contour lines.
    * 3-D fields at the levels nearest 200, 500 and 850 hPa:
      * geopotential height as contour lines;
      * temperature as filled contours, with levels taken from the
        whole-run min/max at that level (step 2 degC).
    * 10 m wind speed with vectors from ``u10m`` and ``v10m``.

    Missing input files (FileNotFoundError) are reported and skipped. Every
    dataset that was opened is closed before the function exits. This covers
    the early returns and any exception that propagates.

    Parameters
    ----------
    run_directory_2d : str
        Directory with the 2-D files; its basename names the run.
    run_directory_3d : str
        Directory with the 3-D (pressure-level) files.

    Notes
    -----
    NOTE(review): the ``mntp`` and ``mxtp`` plots are titled and named
    "media mensal" but show the monthly extreme. Confirm which statistic is
    intended.

    NOTE(review): ``Period.end_time`` is the last nanosecond of a month, so
    the month containing the final timestamp is normally considered
    incomplete and skipped.
    """
    run_date_str = os.path.basename(run_directory_2d)
    label_rodada = "Eta20"

    meses_pt = {
        1: 'Janeiro', 2: 'Fevereiro', 3: 'Marco', 4: 'Abril',
        5: 'Maio', 6: 'Junho', 7: 'Julho', 8: 'Agosto',
        9: 'Setembro', 10: 'Outubro', 11: 'Novembro', 12: 'Dezembro'
    }

    print(f"Iniciando calculo de dados mensais para a rodada: {run_date_str}")

    run_output_dir = os.path.join(path_figuras, run_date_str)
    os.makedirs(run_output_dir, exist_ok=True)
    print(f"Diretorio de saida criado/verificado: {run_output_dir}")

    # Per-variable plotting configuration. 'corr_escala' is the offset that
    # plot_monthly_variable subtracts (273.15 turns K into degC).
    variables_2d_to_process = {
        'prec': {'label': 'Precipitacao', 'unit_media': 'mm', 'unit_acc': 'mm/mes', 'cmap': cmap_prec, 'levels_media': levels_prec_media, 'levels_acc': levels_prec_acc, 'extend': 'both', 'corr_escala': 0},
        'tp2m': {'label': 'Temperatura 2m', 'unit': '\u00B0C', 'cmap': cmap_tp2m, 'levels': levels_tp2m, 'extend': 'both', 'corr_escala': 273.15},
        'mntp': {'label': 'Temperatura Minima', 'unit': '\u00B0C', 'cmap': cmap_tp2m, 'levels': levels_tp2m, 'extend': 'both', 'corr_escala': 273.15},
        'mxtp': {'label': 'Temperatura Maxima', 'unit': '\u00B0C', 'cmap': cmap_tp2m, 'levels': levels_tp2m, 'extend': 'both', 'corr_escala': 273.15},
        'ocis': {'label': 'OCIS', 'unit': 'W/m\u00B2', 'cmap': cmap_ocis, 'levels': levels_ocis, 'extend': 'both', 'corr_escala': 0},
        'pslm': {'label': 'Pressao ao Nivel Medio do Mar', 'unit': 'hPa', 'levels': levels_pslm, 'cmap': cmap_pslm},
        'u10m': {'label': 'Vento U-componente 10m', 'unit': 'm/s'},
        'v10m': {'label': 'Vento V-componente 10m', 'unit': 'm/s'},
        'cssf': {'label': 'Fluxo de Calor Sensivel', 'unit': 'W/m\u00B2', 'cmap': cmap_fluxos, 'levels': levels_cssf, 'extend': 'both', 'corr_escala': 0},
        'clsf': {'label': 'Fluxo de Calor Latente', 'unit': 'W/m\u00B2', 'cmap': cmap_fluxos, 'levels': levels_clsf, 'extend': 'both', 'corr_escala': 0},
    }

    # 'temp' levels are replaced below by per-pressure-level arrays computed
    # from the data.
    variables_3d_to_process = {
        'zgeo': {'label': 'Altura Geopotencial', 'unit': 'm', 'cmap': cmap_zgeo, 'levels': [200, 500, 850]},
        'temp': {'label': 'Temperatura', 'unit': '\u00B0C', 'cmap': cmap_temp, 'extend': 'both', 'corr_escala': 273.15, 'levels': {200: None, 500: None, 850: None}},
    }

    datasets_2d = {}
    datasets_3d = {}

    # Everything that opens or uses datasets runs inside try/finally, so the
    # files are closed on every exit path: the early returns below and any
    # exception raised while loading or plotting.
    try:
        for var in variables_2d_to_process:
            if var in ['u10m', 'v10m']:
                continue
            file_path = os.path.join(run_directory_2d, f"{label_rodada}_{run_date_str}_{var}.nc")
            try:
                datasets_2d[var] = xr.open_dataset(file_path)
                print(f"Dataset 2D para {var} carregado: {file_path}")
            except FileNotFoundError as e:
                print(f"Aviso: Arquivo 2D nao encontrado para {var} em {file_path}. Pulando. Erro: {e}")
                datasets_2d[var] = None

        u10m_file = os.path.join(run_directory_2d, f"{label_rodada}_{run_date_str}_u10m.nc")
        v10m_file = os.path.join(run_directory_2d, f"{label_rodada}_{run_date_str}_v10m.nc")
        try:
            datasets_2d['u10m_ds'] = xr.open_dataset(u10m_file)
            datasets_2d['v10m_ds'] = xr.open_dataset(v10m_file)
            print(f"Datasets para u10m e v10m carregados: {u10m_file}, {v10m_file}")
        except FileNotFoundError as e:
            print(f"Aviso: Arquivos de vento nao encontrados para {run_date_str}. Pulando plots de vento. Erro: {e}")
            # If only v10m is missing, u10m is already open: release it
            # before its handle is discarded.
            if datasets_2d.get('u10m_ds') is not None:
                datasets_2d['u10m_ds'].close()
            datasets_2d['u10m_ds'] = None
            datasets_2d['v10m_ds'] = None

        for var in variables_3d_to_process:
            file_path = os.path.join(run_directory_3d, f"{label_rodada}_{run_date_str}_{var}.nc")
            try:
                datasets_3d[var] = xr.open_dataset(file_path)
                print(f"Dataset 3D para {var} carregado: {file_path}")
            except FileNotFoundError as e:
                print(f"Aviso: Arquivo 3D nao encontrado para {var} em {file_path}. Pulando. Erro: {e}")
                datasets_3d[var] = None

        # The first loaded dataset with a time axis defines the period covered.
        first_ds_key = None
        for key, ds in list(datasets_2d.items()) + list(datasets_3d.items()):
            if ds is not None and 'time' in ds.coords:
                first_ds_key = key
                ref_ds = ds
                break

        if first_ds_key is None:
            print(f"Nenhum dataset valido encontrado para a rodada {run_date_str}. Pulando esta rodada.")
            return

        times = pd.to_datetime(ref_ds['time'].values)
        start_time = times.min()
        end_time = times.max()

        print(f"Periodo de dados disponivel para {run_date_str}: de {start_time} a {end_time}")

        # A month counts as complete only when the data cover it from its
        # first to its last instant.
        first_full_month_start = start_time.to_period('M').start_time
        if start_time > first_full_month_start:
            first_full_month_start = (start_time + pd.DateOffset(months=1)).to_period('M').start_time

        last_full_month_end = end_time.to_period('M').end_time
        if end_time < last_full_month_end:
            last_full_month_end = (end_time - pd.DateOffset(months=1)).to_period('M').end_time

        if first_full_month_start > last_full_month_end:
            print(f"Nenhum mes completo encontrado para calculo de media na rodada {run_date_str}.")
            return

        monthly_periods = pd.date_range(start=first_full_month_start, end=last_full_month_end, freq='MS')

        if not monthly_periods.empty:
            print("\nMeses completos identificados para plotagem:")
            for month_dt in monthly_periods:
                print(f"- {remover_acentos(meses_pt.get(month_dt.month, ''))} {month_dt.year}")
        else:
            print("Nenhum mes completo encontrado no intervalo de dados fornecido.")
            return

        # Temperature levels per pressure level, shared by all months: they
        # span the whole-run min..max (in degC) in steps of 2.
        limites_temp_nivel = {}
        if datasets_3d.get('temp') is not None:
            ds_temp = datasets_3d['temp']
            niveis_temp = [200, 500, 850]

            for nivel in niveis_temp:
                try:
                    temp_data_all_times_at_level = ds_temp['temp'].sel(lev=nivel, method='nearest') - 273.15
                    min_val = int(np.floor(temp_data_all_times_at_level.min()))
                    max_val = int(np.ceil(temp_data_all_times_at_level.max()))

                    step = 2
                    levels = np.arange(min_val, max_val + step, step)
                    limites_temp_nivel[nivel] = levels

                except Exception as e:
                    print(f"Erro ao calcular limites para o nivel {nivel} hPa: {e}. Pulando.")
                    limites_temp_nivel[nivel] = None

        variables_3d_to_process['temp']['levels'] = limites_temp_nivel

        for month_dt in monthly_periods:
            month_label_pt = f"{meses_pt.get(month_dt.month, '')} {month_dt.year}"
            month_label_pt_upper = f"{meses_pt.get(month_dt.month, '')}"
            month_num = month_dt.month
            year_num = month_dt.year

            print(f"\nProcessando plots para {remover_acentos(month_label_pt)}...")

            # ---- 2-D surface fields ------------------------------------------
            for var_name, config in variables_2d_to_process.items():
                if var_name in ['u10m', 'v10m']:
                    continue
                if datasets_2d.get(var_name) is None:
                    continue

                ds = datasets_2d[var_name]
                monthly_data_slice = ds[var_name].sel(time=(ds['time'].dt.month == month_num) & (ds['time'].dt.year == year_num))

                if monthly_data_slice.sizes['time'] < 20:
                    print(f"Aviso: Dados incompletos para {var_name} em {remover_acentos(month_label_pt)}. Pulando.")
                    continue

                if var_name == 'mntp':
                    monthly_data_mean = monthly_data_slice.min(dim='time')
                elif var_name == 'mxtp':
                    monthly_data_mean = monthly_data_slice.max(dim='time')
                else:
                    monthly_data_mean = monthly_data_slice.mean(dim='time')

                if monthly_data_mean.isnull().all():
                    print(f"Aviso: Dados mensais para {var_name} em {remover_acentos(month_label_pt)} sao todos NaNs. Pulando plot.")
                    continue

                final_fig_name = f"{label_rodada}_{var_name}_media_mensal_{remover_acentos(month_label_pt_upper)}_{year_num}.png"

                if var_name == 'pslm':
                    plot_monthly_contour_variable(
                        data_array=monthly_data_mean,
                        var_name=var_name,
                        month_label=remover_acentos(month_label_pt),
                        levels=config['levels'],
                        unit_cbar=config['unit'],
                        output_dir=run_output_dir,
                        fig_name=final_fig_name,
                        title_prefix=f'Media Mensal de {remover_acentos(config["label"])}',
                        cmap=config.get('cmap')
                    )
                elif var_name == 'prec':
                    monthly_data_acc = monthly_data_slice.sum(dim='time')
                    final_fig_name_acc = f"{label_rodada}_prec_acumulada_mensal_{remover_acentos(month_label_pt_upper)}_{year_num}.png"

                    plot_monthly_variable(
                        data_array=monthly_data_acc,
                        var_name='prec',
                        month_label=remover_acentos(month_label_pt),
                        cmap=config['cmap'],
                        levels=config['levels_acc'],
                        extend='both',
                        unit_cbar='mm/mes',
                        output_dir=run_output_dir,
                        fig_name=final_fig_name_acc,
                        title_prefix='Precipitacao Mensal Acumulada',
                        scale_correction=0
                    )
                    print(f"Precipitacao Acumulada para {remover_acentos(month_label_pt)} gerada.")

                    final_fig_name_mean = f"{label_rodada}_prec_media_mensal_{remover_acentos(month_label_pt_upper)}_{year_num}.png"

                    plot_monthly_variable(
                        data_array=monthly_data_mean,
                        var_name='prec',
                        month_label=remover_acentos(month_label_pt),
                        cmap=config['cmap'],
                        levels=config['levels_media'],
                        extend='both',
                        unit_cbar='mm',
                        output_dir=run_output_dir,
                        fig_name=final_fig_name_mean,
                        title_prefix='Media Mensal Precipitacao',
                        scale_correction=0
                    )
                    print(f"Media Mensal Precipitacao para {remover_acentos(month_label_pt)} gerada.")

                else:
                    plot_monthly_variable(
                        data_array=monthly_data_mean,
                        var_name=var_name,
                        month_label=remover_acentos(month_label_pt),
                        cmap=config['cmap'],
                        levels=config['levels'],
                        extend=config['extend'],
                        unit_cbar=config['unit'],
                        output_dir=run_output_dir,
                        fig_name=final_fig_name,
                        title_prefix=f'Media Mensal de {remover_acentos(config["label"])}',
                        scale_correction=config.get('corr_escala', 0)
                    )
                    print(f"Plot para {var_name} em {remover_acentos(month_label_pt)} gerado.")

            # ---- 3-D pressure-level fields -----------------------------------
            for var_name, config in variables_3d_to_process.items():
                if datasets_3d.get(var_name) is None:
                    continue

                ds = datasets_3d[var_name]

                if var_name == 'temp':
                    level_list = sorted(config['levels'].keys())
                else:
                    level_list = config['levels']

                for level_val in level_list:
                    monthly_data_slice = ds[var_name].sel(time=(ds['time'].dt.month == month_num) & (ds['time'].dt.year == year_num), lev=level_val, method='nearest')

                    if monthly_data_slice.sizes['time'] < 10:
                        print(f"Aviso: Dados incompletos para {var_name} no nivel {level_val} hPa em {remover_acentos(month_label_pt)}. Pulando.")
                        continue

                    monthly_data_mean = monthly_data_slice.mean(dim='time')

                    if monthly_data_mean.isnull().all():
                        print(f"Aviso: Dados mensais para {var_name} no nivel {level_val} hPa em {remover_acentos(month_label_pt)} sao todos NaNs. Pulando plot.")
                        continue

                    final_fig_name = f"{label_rodada}_{var_name}_{level_val}hPa_media_mensal_{remover_acentos(month_label_pt_upper)}_{year_num}.png"

                    if var_name == 'zgeo':
                        min_val = float(monthly_data_mean.min())
                        max_val = float(monthly_data_mean.max())
                        # A constant field would yield 15 identical levels, and
                        # matplotlib's contour rejects non-increasing levels.
                        if not max_val > min_val:
                            print(f"Aviso: Campo constante para {var_name} no nivel {level_val} hPa em {remover_acentos(month_label_pt)}. Pulando plot.")
                            continue
                        levels_zgeo_plot = np.linspace(min_val, max_val, 15)

                        plot_monthly_contour_variable(
                            data_array=monthly_data_mean,
                            var_name=var_name,
                            month_label=remover_acentos(month_label_pt),
                            levels=levels_zgeo_plot,
                            unit_cbar=config['unit'],
                            output_dir=run_output_dir,
                            fig_name=final_fig_name,
                            title_prefix=f'Media Mensal da {remover_acentos(config["label"])} {level_val} hPa',
                            cmap=config['cmap']
                        )

                    elif var_name == 'temp':
                        levels_temp_plot = config['levels'][level_val]

                        if levels_temp_plot is not None:
                            plot_monthly_variable(
                                data_array=monthly_data_mean,
                                var_name=var_name,
                                month_label=remover_acentos(month_label_pt),
                                cmap=config['cmap'],
                                levels=levels_temp_plot,
                                extend=config['extend'],
                                unit_cbar=config['unit'],
                                output_dir=run_output_dir,
                                fig_name=final_fig_name,
                                title_prefix=f'Media Mensal da {remover_acentos(config["label"])} {level_val} hPa',
                                scale_correction=config.get('corr_escala', 0)
                            )
                            print(f"Plot para {var_name} no nivel {level_val} hPa em {remover_acentos(month_label_pt)} gerado.")
                        else:
                            print(f"Aviso: Nao foi possivel gerar os levels de temperatura para o nivel {level_val} hPa. Pulando plot.")

            # ---- 10 m wind -----------------------------------------------------
            if datasets_2d.get('u10m_ds') is not None and datasets_2d.get('v10m_ds') is not None:
                ds_u10m = datasets_2d['u10m_ds']
                ds_v10m = datasets_2d['v10m_ds']

                u10m_monthly_slice = ds_u10m['u10m'].sel(time=(ds_u10m['time'].dt.month == month_num) & (ds_u10m['time'].dt.year == year_num))
                v10m_monthly_slice = ds_v10m['v10m'].sel(time=(ds_v10m['time'].dt.month == month_num) & (ds_v10m['time'].dt.year == year_num))

                if u10m_monthly_slice.sizes['time'] < 20 or v10m_monthly_slice.sizes['time'] < 20:
                    print(f"Aviso: Dados incompletos de vento em {remover_acentos(month_label_pt)}. Pulando.")
                else:
                    u10m_monthly_mean = u10m_monthly_slice.mean(dim='time')
                    v10m_monthly_mean = v10m_monthly_slice.mean(dim='time')

                    if u10m_monthly_mean.isnull().all() or v10m_monthly_mean.isnull().all():
                        print(f"Aviso: Media mensal de vento em {remover_acentos(month_label_pt)} e todos NaNs. Pulando plot.")
                    else:
                        plot_monthly_wind10m(
                            u10m_mean=u10m_monthly_mean,
                            v10m_mean=v10m_monthly_mean,
                            month_label=remover_acentos(month_label_pt),
                            cmap=cmap_vento10m,
                            levels=levels_vento10m,
                            output_dir=run_output_dir,
                            label_rodada=label_rodada,
                            month_label_pt_upper=remover_acentos(month_label_pt_upper),
                            year_num=year_num
                        )
                        print(f"Media mensal de vento em {remover_acentos(month_label_pt)} plotada.")
            else:
                print(f"Aviso: Datasets de vento nao disponiveis para {run_date_str}. Pulando plots de vento para {remover_acentos(month_label_pt)}.")
    finally:
        _fechar_datasets(datasets_2d, datasets_3d)
    print(f"Processamento da rodada finalizado: {run_date_str}")

if __name__ == "__main__":
    # Process the single configured run, then report where the figures were saved.
    process_run_for_monthly_averages(
        run_directory_2d=run_directory_2d,
        run_directory_3d=run_directory_3d,
    )
    mensagem_final = f"\nTodas as figuras mensais foram geradas e salvas nas subpastas em:\n{path_figuras}"
    print(mensagem_final)
