# -*- coding: utf-8 -*-

import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime, timedelta
import matplotlib.dates as mdates

# --- CONFIGURAÇÕES INICIAIS ---
# Definir a lista de pontos para os meteogramas, incluindo o país
locais = [
    {"nome": "Sao Paulo", "lat": -23.55, "lon": -46.63, "sigla": "XSPL", "pais": "BR"},
    {"nome": "Rio de Janeiro", "lat": -22.90, "lon": -43.20, "sigla": "XRJ", "pais": "BR"}
    # Exemplo de uma cidade fora do Brasil, para ilustrar a flexibilidade:
    # {"nome": "Buenos Aires", "lat": -34.60, "lon": -58.38, "sigla": "XBA", "pais": "AR"}
]

label_rodada = "Eta05_M01"
RUNDATE = "2025072900" # Ex: 2023021900

path_dados = r"/dados/grpeta/Team/DiegoChagas/Eta_Ensemble_5km/CENAPAD/netcdf/2025072900/M01"
path_figuras = r"/dados/grpeta/Team/DiegoChagas/WORKETAVIII"

# Criar um subdiretório para meteogramas
path_meteogramas = os.path.join(path_figuras, "meteogramas")
os.makedirs(path_meteogramas, exist_ok=True)

# Nomes dos arquivos NetCDF (assumindo que a estrutura é consistente)
nome_base_prec = f"{label_rodada}_PREC_{RUNDATE}"
nome_base_tp2m = f"{label_rodada}_TP2M_{RUNDATE}"
nome_base_ur2m = f"{label_rodada}_UR2M_{RUNDATE}"
nome_base_ocis = f"{label_rodada}_OCIS_{RUNDATE}"
nome_base_u10m = f"{label_rodada}_U10M_{RUNDATE}"
nome_base_v10m = f"{label_rodada}_V10M_{RUNDATE}"
nome_base_lwnv = f"{label_rodada}_LWNV_{RUNDATE}" # Nuvem baixa
nome_base_mdnv = f"{label_rodada}_MDNV_{RUNDATE}" # Nuvem média
nome_base_hinv = f"{label_rodada}_HINV_{RUNDATE}" # Nuvem alta


# Construção dos caminhos completos para os arquivos
arquivo_nc_prec = os.path.join(path_dados, f"{nome_base_prec}.nc")
arquivo_nc_tp2m = os.path.join(path_dados, f"{nome_base_tp2m}.nc")
arquivo_nc_ur2m = os.path.join(path_dados, f"{nome_base_ur2m}.nc")
arquivo_nc_ocis = os.path.join(path_dados, f"{nome_base_ocis}.nc")
arquivo_nc_u10m = os.path.join(path_dados, f"{nome_base_u10m}.nc")
arquivo_nc_v10m = os.path.join(path_dados, f"{nome_base_v10m}.nc")
arquivo_nc_lwnv = os.path.join(path_dados, f"{nome_base_lwnv}.nc")
arquivo_nc_mdnv = os.path.join(path_dados, f"{nome_base_mdnv}.nc")
arquivo_nc_hinv = os.path.join(path_dados, f"{nome_base_hinv}.nc")

# Parse da data e hora da rodada
rodada_data_dt = datetime.strptime(RUNDATE, '%Y%m%d%H')

# --- LEITURA DOS DADOS (feito uma vez para todos os locais, pois os arquivos são os mesmos) ---
try:
    ds_prec = xr.open_dataset(arquivo_nc_prec)
    ds_tp2m = xr.open_dataset(arquivo_nc_tp2m)
    ds_ur2m = xr.open_dataset(arquivo_nc_ur2m)
    ds_ocis = xr.open_dataset(arquivo_nc_ocis)
    ds_u10m = xr.open_dataset(arquivo_nc_u10m)
    ds_v10m = xr.open_dataset(arquivo_nc_v10m)
    ds_lwnv = xr.open_dataset(arquivo_nc_lwnv)
    ds_mdnv = xr.open_dataset(arquivo_nc_mdnv)
    ds_hinv = xr.open_dataset(arquivo_nc_hinv)
except FileNotFoundError as e:
    print(f"Erro: Arquivo nao encontrado. Certifique-se de que todos os arquivos NetCDF necessarios (incluindo UR2M, LWNV, MDNV, HINV) existem no caminho especificado.")
    print(f"Arquivo ausente: {e.filename}")
    exit()

# --- LOOP PARA GERAR METEOGRAMAS PARA CADA LOCAL ---
for local in locais:
    LAT_POINT = local["lat"]
    LON_POINT = local["lon"]
    SIGLA_LOCAL = local["sigla"]
    NOME_CIDADE = local["nome"]
    NOME_PAIS = local["pais"] # Novo campo para o país

    print(f"\nGerando meteograma para {NOME_CIDADE}, {NOME_PAIS} ({SIGLA_LOCAL})...")

    # --- EXTRAÇÃO DOS DADOS PARA O PONTO ESPECÍFICO ---
    prec_series = ds_prec['prec'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() * 1000
    tp2m_series = ds_tp2m['tp2m'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() - 273.15 # °C
    ur2m_series = ds_ur2m['ur2m'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() # %
    
    u10m_series_vals = ds_u10m['u10m'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze().values
    v10m_series_vals = ds_v10m['v10m'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze().values
    wind_speed = np.sqrt(u10m_series_vals**2 + v10m_series_vals**2) # m/s

    ocis_series = ds_ocis['ocis'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze()

    lwnv_series = ds_lwnv['lwnv'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() * 100
    mdnv_series = ds_mdnv['mdnv'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() * 100
    hinv_series = ds_hinv['hinv'].sel(lat=LAT_POINT, lon=LON_POINT, method='nearest').squeeze() * 100

    if np.issubdtype(prec_series['time'].dtype, np.datetime64):
        times = prec_series['time'].values.astype('M8[ms]').astype('O')
    else:
        times = [rodada_data_dt + timedelta(hours=int(h)) for h in prec_series['time'].values]

    # --- PLOTAGEM DO METEOGRAMA ---
    fig, axes = plt.subplots(nrows=6, ncols=1, figsize=(14, 18), sharex=True)
    # Título do gráfico agora inclui o nome do país
    fig.suptitle(f'Eta Model - {rodada_data_dt.strftime("%d/%m/%Y %HUTC")}\n'
                 f'{NOME_CIDADE}, {NOME_PAIS}     lat:{LAT_POINT:.4f} lon:{LON_POINT:.4f}  ', fontsize=18, weight='bold') 

    if len(times) > 1:
        bar_width = (mdates.date2num(times[1]) - mdates.date2num(times[0])) * 0.8
    else:
        bar_width = 0.03

    # 1. Precipitação (barras)
    ax0 = axes[0]
    ax0.bar(times, prec_series.values, color='steelblue', width=bar_width)
    prec_max = np.ceil(prec_series.max() * 1.2) if prec_series.max() > 0 else 5
    ax0.set_ylim(0, max(prec_max, 8))
    ax0.set_yticks(np.arange(0, ax0.get_ylim()[1] + 1, 1))
    ax0.tick_params(axis='y', labelcolor='black', labelsize=12)
    ax0.set_title('Hourly Precipitation (mm/h)', fontsize=14, loc='left', weight='bold', color='red')
    ax0.grid(axis='y', linestyle='--', alpha=0.7)


    # 2. Temperatura 2m (linha)
    ax1 = axes[1]
    ax1.plot(times, tp2m_series.values, color='blue', linewidth=2)
    min_temp_val = np.floor(tp2m_series.min() - 2)
    max_temp_val = np.ceil(tp2m_series.max() + 2)
    ax1.set_ylim(min_temp_val, max_temp_val)
    ax1.set_yticks(np.arange(np.floor(min_temp_val), np.ceil(max_temp_val) + 1, 2))
    ax1.tick_params(axis='y', labelcolor='black', labelsize=12)
    ax1.set_title('2-m Temperature (C)', fontsize=14, loc='left', weight='bold', color='red')
    ax1.grid(axis='y', linestyle='--', alpha=0.7)


    # 3. Umidade Relativa 2m (linha)
    ax2 = axes[2]
    ax2.plot(times, ur2m_series.values, color='blue', linewidth=2)
    min_ur_val = np.floor(ur2m_series.min()) - 5 if ur2m_series.min() < 0 else 0
    max_ur_val = np.ceil(ur2m_series.max()) + 5 if ur2m_series.max() > 100 else 100
    ax2.set_ylim(max(0, min_ur_val), min(100, max_ur_val))
    ax2.set_yticks(np.arange(0, 101, 20))
    ax2.tick_params(axis='y', labelcolor='black', labelsize=12)
    ax2.set_title('2-m Relative Humidity (%)', fontsize=14, loc='left', weight='bold', color='red')
    ax2.grid(axis='y', linestyle='--', alpha=0.7)


    # 4. Vento 10m (linha para intensidade, setas para direção)
    ax3 = axes[3]
    ax3.plot(times, wind_speed, color='blue', linewidth=2) # wind_speed já é um array NumPy

    min_wind_val = 0
    max_wind_val = np.ceil(wind_speed.max() * 1.2) if wind_speed.max() > 0 else 5 
    ax3.set_ylim(min_wind_val, max(max_wind_val, 3)) 

    ax3.set_yticks(np.arange(0, ax3.get_ylim()[1] + 1, 1))
    ax3.tick_params(axis='y', labelcolor='black', labelsize=12)

    skip_interval = 1
    if len(times[::skip_interval]) > 0:
        y_quiver_pos = ax3.get_ylim()[1] * 0.5 

        ax3.quiver(times[::skip_interval], y_quiver_pos,
                   u10m_series_vals[::skip_interval], v10m_series_vals[::skip_interval], 
                   color='black',
                   scale=ax3.get_ylim()[1] * 10, 
                   width=0.0025, 
                   headwidth=5,  
                   headlength=3, 
                   headaxislength=2, 
                   alpha=0.8) 

    ax3.set_title('10-m Wind (m/s)', fontsize=14, loc='left', weight='bold', color='red')
    ax3.grid(axis='y', linestyle='--', alpha=0.7)


    # 5. Downward Shortwave Radiation Flux at Surface (W/m²)
    ax4 = axes[4]
    ax4.fill_between(times, 0, ocis_series.values, color='skyblue', alpha=0.7) 
    ax4.plot(times, ocis_series.values, color='blue', linewidth=1, alpha=0.9) 

    min_ocis_val = 0 
    max_ocis_val = np.ceil(ocis_series.max() * 1.1) 
    if max_ocis_val == 0:
        max_ocis_val = 100 

    ax4.set_ylim(min_ocis_val, max_ocis_val)
    ax4.set_yticks(np.arange(0, max_ocis_val + 1, max(100, int(max_ocis_val/5)))) 
    ax4.tick_params(axis='y', labelcolor='black', labelsize=12)
    ax4.set_title('Downward Shortwave Radiation Flux at Surface (W/m\u00B2)', fontsize=14, loc='left', weight='bold', color='red')
    ax4.grid(axis='y', linestyle='--', alpha=0.7)


    # 6. Cobertura de Nuvens (barras empilhadas)
    ax5 = axes[5]
    p1 = ax5.bar(times, lwnv_series.values, color='#ADD8E6', width=bar_width, label='low clouds')
    p2 = ax5.bar(times, mdnv_series.values, bottom=lwnv_series.values, color='#90EE90', width=bar_width, label='middle clouds')
    p3 = ax5.bar(times, hinv_series.values, bottom=lwnv_series.values + mdnv_series.values, color='#FFDAB9', width=bar_width, label='high clouds')

    ax5.set_ylim(0, 100)
    ax5.set_yticks(np.arange(0, 101, 20))
    ax5.tick_params(axis='y', labelcolor='black', labelsize=12)
    ax5.set_title('Cloud Cover (%)', fontsize=14, loc='left', weight='bold', color='red')
    ax5.grid(axis='y', linestyle='--', alpha=0.7)
    ax5.legend(loc='lower left', ncol=3, fontsize=12)


    # --- Formatação do Eixo X (tempo) ---
    if len(times) > 0:
        start_num = mdates.date2num(times[0])
        end_num = mdates.date2num(times[-1])
        for ax in axes:
            ax.set_xlim(start_num - bar_width/2, end_num + bar_width/2)

    for i, ax in enumerate(axes):
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%d%b\n%Y'))

        ax.xaxis.set_minor_locator(mdates.HourLocator(interval=6))
        ax.tick_params(axis='x', rotation=0, labelsize=12, labelbottom=True)
        
        ax.xaxis.grid(True, which='major', linestyle='--', alpha=0.7)

    # Ajustes finais para o layout da figura
    plt.tight_layout(rect=[0, 0.03, 1, 0.96])

    # Salvar a figura com a nova sigla do local
    fig_name = f"{SIGLA_LOCAL}_{RUNDATE}.png"
    fig_path = os.path.join(path_meteogramas, fig_name)
    plt.savefig(fig_path, dpi=200)
    plt.close(fig) 

    print(f"? Meteograma para {NOME_CIDADE}, {NOME_PAIS} gerado e salvo em:\n{fig_path}")

print("\nProcessamento de todos os locais concluido.")
