import xarray as xr
import numpy as np
import rasterio
from rasterio.transform import from_bounds
from rasterio.crs import CRS
import os
import pandas as pd # Para criar datas no exemplo

def convert_netcdf_to_cog(
    netcdf_path,
    variable_name,
    output_cog_path,
    time_index=0, # Índice do passo de tempo a ser convertido
    compression="DEFLATE", # LZW, DEFLATE, JPEG, ZSTD, LZMA
    tile_size=256, # Tamanho dos blocos (tiles) para o COG 
    flip_lat=True # NOVO PARÂMETRO: Se True, inverte a dimensão da latitude
):
    """
    Converte uma variável de um arquivo NetCDF (para um passo de tempo específico)
    para um GeoTIFF otimizado para nuvem (COG).

    Args:
        netcdf_path (str): Caminho para o arquivo NetCDF de entrada.
        variable_name (str): Nome da variável a ser convertida (ex: 'temperature').
        output_cog_path (str): Caminho para o arquivo COG GeoTIFF de saída.
        time_index (int): O índice do passo de tempo a ser extraído e convertido.
                          (Um COG representa um único slice espacial).
        compression (str): Método de compressão para o COG.
        tile_size (int): Tamanho dos blocos (tiles) internos do COG (ex: 256x256 pixels).
    """
    print(f"\nIniciando conversão de '{netcdf_path}' para COG...")
    print(f"Variável: '{variable_name}', Passo de tempo (índice): {time_index}")

    try:
        # Abre o dataset NetCDF
        ds = xr.open_dataset(netcdf_path)
    except FileNotFoundError:
        print(f"Erro: Arquivo NetCDF não encontrado em '{netcdf_path}'")
        return
    except Exception as e:
        print(f"Erro ao abrir o arquivo NetCDF: {e}")
        return

    if variable_name not in ds.data_vars:
        print(f"Erro: Variável '{variable_name}' não encontrada no dataset.")
        print(f"Variáveis disponíveis: {list(ds.data_vars.keys())}")
        return

    data_array = ds[variable_name]

    # Verifica se a dimensão 'time' existe e seleciona o passo de tempo
    if 'time' in data_array.dims:
        if time_index >= len(data_array['time']):
            print(f"Erro: Índice de tempo {time_index} fora dos limites (0 a {len(data_array['time']) - 1}).")
            return
        data_slice = data_array.isel(time=time_index)
        print(f"Convertendo dados para o tempo: {data_slice['time'].item()}")
    else:
        # Se não houver dimensão de tempo, assume que já é 2D espacial
        data_slice = data_array
        print("Dataset não possui dimensão de tempo, convertendo diretamente.")

    # Garante que as dimensões estão na ordem (lat, lon) ou (y, x)
    # rasterio espera (height, width) para 2D
    if 'lat' in data_slice.dims and 'lon' in data_slice.dims:
        # Se a ordem não for (lat, lon), pode precisar de transposição
        if list(data_slice.dims) != ['lat', 'lon']:
            # Tenta transpor para a ordem esperada por rasterio para georreferenciamento
            # Assumindo que lat é Y e lon é X
            if 'lon' in data_slice.dims and 'lat' in data_slice.dims:
                # Reordena para (lat, lon) se necessário
                data_slice = data_slice.transpose('lat', 'lon')
            else:
                print("Aviso: Dimensões espaciais não são 'lat' e 'lon'. A ordem pode estar incorreta.")
    else:
        print("Aviso: Dimensões espaciais não são 'lat' e 'lon'. Verifique a ordem das dimensões.")

    # NOVO: Realiza o flip vertical na dimensão 'lat' se solicitado
    if flip_lat and 'lat' in data_slice.dims:
        data_slice = data_slice.isel(lat=slice(None, None, -1))
        print("Flip vertical aplicado na dimensão 'lat'.")
    # Extrai as coordenadas e metadados
    height, width = data_slice.shape
    min_lon = float(data_slice.lon.min())
    max_lon = float(data_slice.lon.max())
    min_lat = float(data_slice.lat.min())
    max_lat = float(data_slice.lat.max())

    # Calcula a transformação (affine transform)
    # from_bounds(left, bottom, right, top, width, height)
    transform = from_bounds(min_lon, min_lat, max_lon, max_lat, width, height)

    # Define o CRS (Coordinate Reference System)
    # Assumindo WGS84 (latitude/longitude) - EPSG:4326 é muito comum
    # Se seu NetCDF tiver um CRS diferente, você precisará extraí-lo
    # do ds.rio.crs ou de atributos específicos do NetCDF.
    # Ex: crs = CRS.from_wkt(ds.rio.crs.wkt) se você usou rioxarray para carregar
    # ou crs = CRS.from_epsg(4326)
    crs = CRS.from_epsg(4326) # WGS84

    # Prepara o perfil para o GeoTIFF de saída
    profile = {
        "driver": "GTiff",
        "height": height,
        "width": width,
        "count": 1,  # Número de bandas (uma para a variável de temperatura)
        "dtype": data_slice.dtype,
        "crs": crs,
        "transform": transform,
        "compress": compression,
        "tiled": True,
        "blockxsize": tile_size,
        "blockysize": tile_size,
        "BIGTIFF": "IF_NEEDED", # Para arquivos maiores que 4GB
        "nodata": data_slice.attrs.get("_FillValue", None) # Usar _FillValue do NetCDF como nodata
    }

    # Adiciona otimizações para COG
    # 'CLOUD_OPTIMIZED' é um perfil predefinido em rasterio que aplica muitas otimizações COG
    # Isso inclui overviews e outras configurações
    profile.update(
        {
            "driver": "GTiff",
            "tiled": True,
            "blockxsize": tile_size,
            "blockysize": tile_size,
            "compress": compression,
            "predictor": 2 if compression in ["DEFLATE", "LZW"] else 1, # Otimização para compressão
            "overview_level": 1, # Nível de overviews a serem gerados
            "num_threads": "ALL_CPUS" # Usar todos os CPUs para criação
        }
    )

    # Cria o diretório de saída se não existir
    output_dir = os.path.dirname(output_cog_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    print(f"Escrevendo COG em: '{output_cog_path}' com perfil: {profile}")

    # Escreve o arquivo GeoTIFF
    try:
        with rasterio.open(output_cog_path, "w", **profile) as dst:
            dst.write(data_slice.values, 1) # Escreve a banda 1
            # Gera overviews para otimização COG
            dst.build_overviews([2, 4, 8, 16, 32], resampling=rasterio.enums.Resampling.average)
            dst.update_tags(ns='rio_overview', resampling='average')
            # Adiciona metadados do xarray como tags no GeoTIFF
            dst.update_tags(**data_slice.attrs)
            dst.update_tags(NETCDF_TIME=str(data_slice['time'].item())) # Adiciona o tempo como tag
        print(f"Conversão para COG concluída com sucesso: '{output_cog_path}'")
    except Exception as e:
        print(f"Erro ao escrever o arquivo COG: {e}")
    finally:
        ds.close() # Garante que o dataset NetCDF é fechado

# --- Exemplo de Uso ---
if __name__ == "__main__":
    netcdf_input_file = "Eta20_2025061300_tp2m.nc"
    output_cog_file = "output_cogs/temperature_slice_0.tif"
    variable_to_convert = "tp2m"

    # 1. Crie um arquivo NetCDF de exemplo (se você não tiver um)
    #create_sample_netcdf(netcdf_input_file)

    # 2. Converta o arquivo NetCDF para COG
    # Converter o primeiro passo de tempo (índice 0)
    convert_netcdf_to_cog(
        netcdf_input_file,
        variable_to_convert,
        output_cog_file,
        time_index=0,
        compression="DEFLATE",
        tile_size=512 # Pode ajustar para 256, 512, etc.
    )

    # Exemplo: Converter um passo de tempo diferente (índice 2)
    output_cog_file_2 = "output_cogs/temperature_slice_2.tif"
    convert_netcdf_to_cog(
        netcdf_input_file,
        variable_to_convert,
        output_cog_file_2,
        time_index=2,
        compression="LZW", # Experimente diferentes compressões
        tile_size=256
    )

    # Opcional: Limpar o arquivo NetCDF de exemplo e a pasta de saída
    # os.remove(netcdf_input_file)
    # import shutil
    # if os.path.exists("output_cogs"):
    #     shutil.rmtree("output_cogs")
    # print("\nArquivos de exemplo e diretório de saída removidos.")
Shelter_Temperature
