import xarray as xr
import numpy as np
import rasterio
from rasterio.transform import from_bounds
from rasterio.crs import CRS
import os
import pandas as pd # Usado para criar datas no NetCDF de exemplo
import datetime # Usado para formatar o timestamp do nome do arquivo

def convert_netcdf_to_cog_sequence(
    netcdf_path,
    variable_name,
    output_base_dir, # Agora é um diretório base, não um caminho de arquivo
    level_index=0,
    start_time_index=0, # NOVO: Índice de início para o loop de tempo
    end_time_index=None, # NOVO: Índice de fim para o loop de tempo (None para ir até o final)
    compression="DEFLATE",
    tile_size=256,
    flip_lat=True
):
    """
    Converte uma variável de um arquivo NetCDF (para múltiplos passos de tempo)
    para uma sequência de GeoTIFFs otimizados para nuvem (COG).

    Args:
        netcdf_path (str): Caminho para o arquivo NetCDF de entrada.
        variable_name (str): Nome da variável a ser convertida (ex: 'temperature').
        output_base_dir (str): Caminho para o diretório onde os COGs de saída serão salvos.
        level_index (int): O índice do nível a ser extraído e convertido (se a dimensão 'level' existir).
        start_time_index (int): O índice do primeiro passo de tempo a ser convertido.
        end_time_index (int): O índice do último passo de tempo a ser convertido (inclusive).
                               Se None, converte até o final.
        compression (str): Método de compressão para o COG.
        tile_size (int): Tamanho dos blocos (tiles) internos do COG (ex: 256x256 pixels).
        flip_lat (bool): Se True, inverte a dimensão da latitude para garantir que a origem
                         seja superior esquerda (necessário para alguns NetCDFs).
    """
    print(f"\nIniciando conversão de '{netcdf_path}' para sequência de COGs...")
    print(f"Variável: '{variable_name}', Nível (índice): {level_index}")

    try:
        ds = xr.open_dataset(netcdf_path)
    except FileNotFoundError:
        print(f"Erro: Arquivo NetCDF não encontrado em '{netcdf_path}'")
        return
    except Exception as e:
        print(f"Erro ao abrir o arquivo NetCDF: {e}")
        return

    if variable_name not in ds.data_vars:
        print(f"Erro: Variável '{variable_name}' não encontrada no dataset.")
        print(f"Variáveis disponíveis: {list(ds.data_vars.keys())}")
        ds.close()
        return

    data_array = ds[variable_name]

    # Cria o diretório de saída base se não existir
    if not os.path.exists(output_base_dir):
        os.makedirs(output_base_dir)
        print(f"Diretório de saída criado: '{output_base_dir}'")

    # Determina o intervalo de tempo para o loop
    time_dim_exists = 'time' in data_array.dims
    if time_dim_exists:
        num_times = len(data_array['time'])
        if end_time_index is None:
            end_time_index = num_times - 1
        
        if not (0 <= start_time_index < num_times and 0 <= end_time_index < num_times and start_time_index <= end_time_index):
            print(f"Erro: Índices de tempo inválidos. Início: {start_time_index}, Fim: {end_time_index}, Total: {num_times}.")
            ds.close()
            return
        
        time_indices_to_process = range(start_time_index, end_time_index + 1)
        print(f"Processando passos de tempo de {start_time_index} a {end_time_index}.")
    else:
        time_indices_to_process = [0] # Processa uma única "fatia" se não houver dimensão de tempo
        print("Dataset não possui dimensão de tempo, convertendo diretamente (assumindo um único passo).")


    # LOOP PRINCIPAL SOBRE OS PASSOS DE TEMPO
    for i, time_idx in enumerate(time_indices_to_process):
        print(f"\n--- Processando passo de tempo (índice global): {time_idx} ---")
        
        current_data_slice = data_array.isel(time=time_idx) if time_dim_exists else data_array

        # Seleciona o nível
        if 'level' in current_data_slice.dims:
            if level_index >= len(current_data_slice['level']):
                print(f"Erro: Índice de nível {level_index} fora dos limites (0 a {len(current_data_slice['level']) - 1}). Pulando este passo de tempo.")
                continue # Pula para o próximo passo de tempo
            current_data_slice = current_data_slice.isel(level=level_index)
            print(f"Convertendo dados para o nível: {current_data_slice['level'].item()}")
        else:
            print("Dataset não possui dimensão de nível.")

        # Garante que as dimensões estão na ordem (lat, lon) ou (y, x)
        # rasterio espera (height, width) para 2D
        if 'lat' in current_data_slice.dims and 'lon' in current_data_slice.dims:
            if list(current_data_slice.dims) != ['lat', 'lon']:
                current_data_slice = current_data_slice.transpose('lat', 'lon')
                print("Dimensões reordenadas para (lat, lon).")
        else:
            print("Aviso: Dimensões espaciais não são 'lat' e 'lon'. A ordem pode estar incorreta.")

        # Realiza o flip vertical na dimensão 'lat' se solicitado
        if flip_lat and 'lat' in current_data_slice.dims:
            # Inverte a ordem dos valores na dimensão 'lat'
            current_data_slice = current_data_slice.isel(lat=slice(None, None, -1))
            print("Flip vertical aplicado na dimensão 'lat'.")

        # Extrai as coordenadas e metadados
        height, width = current_data_slice.shape
        min_lon = float(current_data_slice.lon.min())
        max_lon = float(current_data_slice.lon.max())
        min_lat = float(current_data_slice.lat.min())
        max_lat = float(current_data_slice.lat.max())

        # Calcula a transformação (affine transform)
        # from_bounds(left, bottom, right, top, width, height)
        transform = from_bounds(min_lon, min_lat, max_lon, max_lat, width, height)

        # Define o CRS (Coordinate Reference System)
        crs = CRS.from_epsg(4326) # WGS84 é um CRS muito comum para lat/lon

        # Prepara o perfil para o GeoTIFF de saída
        profile = {
            "driver": "GTiff",
            "height": height,
            "width": width,
            "count": 1,  # Número de bandas (uma para a variável de temperatura)
            "dtype": current_data_slice.dtype,
            "crs": crs,
            "transform": transform,
            "compress": compression,
            "tiled": True,
            "blockxsize": tile_size,
            "blockysize": tile_size,
            "BIGTIFF": "IF_NEEDED", # Para arquivos maiores que 4GB
            "nodata": current_data_slice.attrs.get("_FillValue", None) # Usar _FillValue do NetCDF como nodata
        }

        # Adiciona otimizações para COG (overviews)
        profile.update(
            {
                "driver": "GTiff",
                "tiled": True,
                "blockxsize": tile_size,
                "blockysize": tile_size,
                "compress": compression,
                "predictor": 2 if compression in ["DEFLATE", "LZW"] else 1,
            }
        )
        # Constrói o nome do arquivo de saída com o timestamp
        if time_dim_exists:
            # --- THE CRITICAL FIX FOR YOUR ERROR ---
            time_value_from_netcdf = data_array['time'][time_idx].values
            
            # Scenario A: xarray loaded it as numpy.datetime64 (this is the ideal case)
            if isinstance(time_value_from_netcdf, np.datetime64):
                current_timestamp_dt = pd.to_datetime(time_value_from_netcdf)
                print(f"DEBUG: Time value is numpy.datetime64: {current_timestamp_dt}")
            # Scenario B: xarray loaded it as an integer/float (your current problem)
            elif isinstance(time_value_from_netcdf, (int, float, np.integer, np.floating)):
                # --- YOU MUST ADAPT THIS PART BASED ON YOUR NETCDF'S 'units' ATTRIBUTE ---
                # Example 1: if 'units' is 'hours since 2015-02-02 00:00:00'
                # base_time = datetime.datetime(2015, 2, 2, 0, 0)
                # current_timestamp_dt = base_time + datetime.timedelta(hours=int(time_value_from_netcdf))
                
                # Example 2: if 'units' is 'days since 1900-01-01 00:00:00'
                # base_time = datetime.datetime(1900, 1, 1, 0, 0)
                # current_timestamp_dt = base_time + datetime.timedelta(days=int(time_value_from_netcdf))

                # Example 3: If the integer is just an *index* (0, 1, 2...)
                # and your NetCDF is, for instance, hourly data starting from a fixed date
                # You'd need to know the actual start_date of your data series
                
                # For the given traceback, your file might be 'Eta10_C00_2015020200_TP2M.nc'
                # and if its internal 'time' coordinate are just integers like 0, 3, 6, 9...
                # representing hours from 2015-02-02 00:00:00
                
                # Assuming 'hours since 2015-02-02 00:00:00' for demonstration
                # You MUST replace this with the correct base time and unit from your NetCDF's 'units' attribute!
                # If your 'units' attribute is different, this line below is the one to change.
                base_time_epoch = datetime.datetime(2015, 2, 2, 0, 0) 
                time_unit = datetime.timedelta(hours=1) # Or days=1, minutes=1, etc.
                current_timestamp_dt = base_time_epoch + time_unit * int(time_value_from_netcdf)
                print(f"DEBUG: Time value is int/float. Converted to: {current_timestamp_dt}")
            else:
                # Fallback for unexpected types
                print(f"ERROR: Unexpected type for time coordinate: {type(time_value_from_netcdf)}. Attempting to use as is.")
                current_timestamp_dt = time_value_from_netcdf # This will likely fail if not datetime-like
            
            timestamp_str = current_timestamp_dt.strftime('%Y%m%d%H')
            
            output_cog_filename = f"{variable_name}_{timestamp_str}.tif"
        else:
            output_cog_filename = f"{variable_name}_single_slice.tif"

        full_output_cog_path = os.path.join(output_base_dir, output_cog_filename)

        print(f"Escrevendo COG em: '{full_output_cog_path}'")

        try:
            with rasterio.open(full_output_cog_path, "w", **profile) as dst:
                dst.write(current_data_slice.values, 1) # Escreve a banda 1
                # Gera overviews para otimização COG
                dst.build_overviews([2, 4, 8, 16, 32], resampling=rasterio.enums.Resampling.average)
                dst.update_tags(ns='rio_overview', resampling='average')
                # Adiciona metadados do xarray como tags no GeoTIFF
                dst.update_tags(**current_data_slice.attrs)
                if time_dim_exists:
                    dst.update_tags(NETCDF_TIME=current_timestamp.isoformat()) # Adiciona o tempo como tag ISO
            print(f"Conversão para COG concluída com sucesso: '{full_output_cog_path}'")
        except Exception as e:
            print(f"Erro ao escrever o arquivo COG '{full_output_cog_path}': {e}")
            
    ds.close() # Garante que o dataset NetCDF é fechado após o loop
    print("\nProcessamento de sequência de COGs concluído.")

# --- Exemplo de Uso ---
if __name__ == "__main__":
    netcdf_input_file = "Eta10_C00_2015020200_TP2M.nc"
    output_cogs_directory = "output_cogs_sequence"
    variable_to_convert = "TP2M"

    # 2. Converta o arquivo NetCDF para uma sequência de COGs
    # Isso irá gerar um COG para cada passo de tempo no NetCDF de exemplo
    convert_netcdf_to_cog_sequence(
        netcdf_input_file,
        variable_to_convert,
        output_cogs_directory,
        level_index=0, # Converte apenas o primeiro nível (se houver)
        start_time_index=0, # Começa do primeiro passo de tempo
        end_time_index=None, # Vai até o último passo de tempo
        compression="DEFLATE",
        tile_size=512
    )

    # Exemplo: Converter apenas um subconjunto de passos de tempo (do índice 1 ao 3)
    # output_cogs_directory_subset = "output_cogs_subset"
    # convert_netcdf_to_cog_sequence(
    #     netcdf_input_file,
    #     variable_to_convert,
    #     output_cogs_directory_subset,
    #     level_index=0,
    #     start_time_index=1,
    #     end_time_index=3,
    #     compression="LZW",
    #     tile_size=256
    # )

    # Opcional: Limpar o arquivo NetCDF de exemplo e as pastas de saída
    # import shutil
    # if os.path.exists(netcdf_input_file):
    #     os.remove(netcdf_input_file)
    # if os.path.exists(output_cogs_directory):
    #     shutil.rmtree(output_cogs_directory)
    # if os.path.exists(output_cogs_directory_subset):
    #     shutil.rmtree(output_cogs_directory_subset)
    # print("\nArquivos de exemplo e diretórios de saída removidos.")
