import xarray as xr
import numpy as np
import rasterio
from rasterio.transform import from_bounds
from rasterio.crs import CRS
import os
import sys
import pandas as pd # Usado para criar datas no NetCDF de exemplo
import datetime # Usado para formatar o timestamp do nome do arquivo
import argparse # Import argparse

def convert_netcdf_to_cog_sequence(
    netcdf_path,
    variable_name,
    arq_prefix,
    output_base_dir, # Agora é um diretório base, não um caminho de arquivo
    level_index=0,
    level_value=None,
    start_time_index=0, # NOVO: Índice de início para o loop de tempo
    end_time_index=None, # NOVO: Índice de fim para o loop de tempo (None para ir até o final)
    compression="DEFLATE",
    tile_size=256,
    flip_lat=True
):
    """
    Converte uma variável de um arquivo NetCDF (para múltiplos passos de tempo)
    para uma sequência de GeoTIFFs otimizados para nuvem (COG).

    Args:
        netcdf_path (str): Caminho para o arquivo NetCDF de entrada.
        variable_name (str): Nome da variável a ser convertida (ex: 'temperature').
        output_base_dir (str): Caminho para o diretório onde os COGs de saída serão salvos.
        level_index (int): O índice do nível a ser extraído e convertido (se a dimensão 'level' existir).
        start_time_index (int): O índice do primeiro passo de tempo a ser convertido.
        end_time_index (int): O índice do último passo de tempo a ser convertido (inclusive).
                               Se None, converte até o final.
        compression (str): Método de compressão para o COG.
        tile_size (int): Tamanho dos blocos (tiles) internos do COG (ex: 256x256 pixels).
        flip_lat (bool): Se True, inverte a dimensão da latitude para garantir que a origem
                         seja superior esquerda (necessário para alguns NetCDFs).
    """
    print(f"\nIniciando conversão de '{netcdf_path}' para sequência de COGs...")
    print(f"Variável: '{variable_name}', Nível (índice): {level_index}")

    try:
        ds = xr.open_dataset(netcdf_path)
    except FileNotFoundError:
        print(f"Erro: Arquivo NetCDF não encontrado em '{netcdf_path}'")
        return
    except Exception as e:
        print(f"Erro ao abrir o arquivo NetCDF: {e}")
        return

    if variable_name not in ds.data_vars:
        print(f"Erro: Variável '{variable_name}' não encontrada no dataset.")
        print(f"Variáveis disponíveis: {list(ds.data_vars.keys())}")
        ds.close()
        return

    data_array = ds[variable_name]

    # Cria o diretório de saída base se não existir
    if not os.path.exists(output_base_dir):
        os.makedirs(output_base_dir)
        print(f"Diretório de saída criado: '{output_base_dir}'")

    # Determina o intervalo de tempo para o loop
    time_dim_exists = 'time' in data_array.dims
    if time_dim_exists:
        num_times = len(data_array['time'])
        if end_time_index is None:
            end_time_index = num_times - 1
        
        if not (0 <= start_time_index < num_times and 0 <= end_time_index < num_times and start_time_index <= end_time_index):
            print(f"Erro: Índices de tempo inválidos. Início: {start_time_index}, Fim: {end_time_index}, Total: {num_times}.")
            ds.close()
            return
        
        time_indices_to_process = range(start_time_index, end_time_index + 1)
        print(f"Processando passos de tempo de {start_time_index} a {end_time_index}.")
    else:
        time_indices_to_process = [0] # Processa uma única "fatia" se não houver dimensão de tempo
        print("Dataset não possui dimensão de tempo, convertendo diretamente (assumindo um único passo).")


    # LOOP PRINCIPAL SOBRE OS PASSOS DE TEMPO
    for i, time_idx in enumerate(time_indices_to_process):
        print(f"\n--- Processando passo de tempo (índice global): {time_idx} ---")
        
        current_data_slice = data_array.isel(time=time_idx) if time_dim_exists else data_array

        # Seleciona o nível
        if 'level' in current_data_slice.dims:
            if level_index >= len(current_data_slice['level']):
                print(f"Erro: Índice de nível {level_index} fora dos limites (0 a {len(current_data_slice['level']) - 1}). Pulando este passo de tempo.")
                continue # Pula para o próximo passo de tempo
            current_data_slice = current_data_slice.isel(level=level_index)
            print(f"Convertendo dados para o nível: {current_data_slice['level'].item()}")
        else:
            print("Dataset não possui dimensão de nível.")

        # Garante que as dimensões estão na ordem (lat, lon) ou (y, x)
        # rasterio espera (height, width) para 2D
        if 'lat' in current_data_slice.dims and 'lon' in current_data_slice.dims:
            if list(current_data_slice.dims) != ['lat', 'lon']:
                current_data_slice = current_data_slice.transpose('lat', 'lon')
                print("Dimensões reordenadas para (lat, lon).")
        else:
            print("Aviso: Dimensões espaciais não são 'lat' e 'lon'. A ordem pode estar incorreta.")

        # Realiza o flip vertical na dimensão 'lat' se solicitado
        if flip_lat and 'lat' in current_data_slice.dims:
            # Inverte a ordem dos valores na dimensão 'lat'
            current_data_slice = current_data_slice.isel(lat=slice(None, None, -1))
            print("Flip vertical aplicado na dimensão 'lat'.")

        # Extrai as coordenadas e metadados
        height, width = current_data_slice.shape
        min_lon = float(current_data_slice.lon.min())
        max_lon = float(current_data_slice.lon.max())
        min_lat = float(current_data_slice.lat.min())
        max_lat = float(current_data_slice.lat.max())

        # Calcula a transformação (affine transform)
        # from_bounds(left, bottom, right, top, width, height)
        transform = from_bounds(min_lon, min_lat, max_lon, max_lat, width, height)

        # Define o CRS (Coordinate Reference System)
        crs = CRS.from_epsg(4326) # WGS84 é um CRS muito comum para lat/lon

        # Prepara o perfil para o GeoTIFF de saída
        profile = {
            "driver": "GTiff",
            "height": height,
            "width": width,
            "count": 1,  # Número de bandas (uma para a variável de temperatura)
            "dtype": current_data_slice.dtype,
            "crs": crs,
            "transform": transform,
            "compress": compression,
            "tiled": True,
            "blockxsize": tile_size,
            "blockysize": tile_size,
            "BIGTIFF": "IF_NEEDED", # Para arquivos maiores que 4GB
            "nodata": current_data_slice.attrs.get("_FillValue", None) # Usar _FillValue do NetCDF como nodata
        }

        # Adiciona otimizações para COG (overviews)
        profile.update(
            {
                "driver": "GTiff",
                "tiled": True,
                "blockxsize": tile_size,
                "blockysize": tile_size,
                "compress": compression,
                "predictor": 2 if compression in ["DEFLATE", "LZW"] else 1,
            }
        )
        # Constrói o nome do arquivo de saída com o timestamp
        current_timestamp_dt = None
        if time_dim_exists:
            time_value_from_netcdf = data_array['time'][time_idx].values
            
            # Scenario A: xarray loaded it as numpy.datetime64 (this is the ideal case)
            if isinstance(time_value_from_netcdf, np.datetime64):
                current_timestamp_dt = pd.to_datetime(time_value_from_netcdf)
                print(f"DEBUG: Time value is numpy.datetime64: {current_timestamp_dt}")
            # Scenario B: xarray loaded it as an integer/float (your current problem)
            elif isinstance(time_value_from_netcdf, (int, float, np.integer, np.floating)):
                # --- YOU MUST ADAPT THIS PART BASED ON YOUR NETCDF'S 'units' ATTRIBUTE ---
                # Example 1: if 'units' is 'hours since 2015-02-02 00:00:00'
                # base_time = datetime.datetime(2015, 2, 2, 0, 0)
                # current_timestamp_dt = base_time + datetime.timedelta(hours=int(time_value_from_netcdf))
                
                # Example 2: if 'units' is 'days since 1900-01-01 00:00:00'
                # base_time = datetime.datetime(1900, 1, 1, 0, 0)
                # current_timestamp_dt = base_time + datetime.timedelta(days=int(time_value_from_netcdf))

                # Example 3: If the integer is just an *index* (0, 1, 2...)
                # and your NetCDF is, for instance, hourly data starting from a fixed date
                # You'd need to know the actual start_date of your data series
                
                # For the given traceback, your file might be 'Eta10_C00_2015020200_TP2M.nc'
                # and if its internal 'time' coordinate are just integers like 0, 3, 6, 9...
                # representing hours from 2015-02-02 00:00:00
                
                # Assuming 'hours since 2015-02-02 00:00:00' for demonstration
                # You MUST replace this with the correct base time and unit from your NetCDF's 'units' attribute!
                # If your 'units' attribute is different, this line below is the one to change.
                base_time_epoch = datetime.datetime(2015, 2, 2, 0, 0) 
                time_unit = datetime.timedelta(hours=1) # Or days=1, minutes=1, etc.
                current_timestamp_dt = base_time_epoch + time_unit * int(time_value_from_netcdf)
                print(f"DEBUG: Time value is int/float. Converted to: {current_timestamp_dt}")
            else:
                # Fallback for unexpected types
                print(f"ERROR: Unexpected type for time coordinate: {type(time_value_from_netcdf)}. Attempting to use as is.")
                current_timestamp_dt = time_value_from_netcdf # This will likely fail if not datetime-like
            
            timestamp_str = current_timestamp_dt.strftime('%Y%m%d%H')
            
            output_cog_filename = f"{arq_prefix}_{timestamp_str}.tif"
        else:
            output_cog_filename = f"{arq_prefix}_single_slice.tif"

        full_output_cog_path = os.path.join(output_base_dir, output_cog_filename)

        print(f"Escrevendo COG em: '{full_output_cog_path}'")

        try:
            with rasterio.open(full_output_cog_path, "w", **profile) as dst:
                dst.write(current_data_slice.values, 1) # Escreve a banda 1
                # Gera overviews para otimização COG
                dst.build_overviews([2, 4, 8, 16, 32], resampling=rasterio.enums.Resampling.average)
                dst.update_tags(ns='rio_overview', resampling='average')
                # Adiciona metadados do xarray como tags no GeoTIFF
                dst.update_tags(**current_data_slice.attrs)
                if time_dim_exists and current_timestamp_dt is not None:
                    dst.update_tags(NETCDF_TIME=current_timestamp_dt.isoformat()) # Adiciona o tempo como tag ISO
            print(f"Conversão para COG concluída com sucesso: '{full_output_cog_path}'")
        except Exception as e:
            print(f"Erro ao escrever o arquivo COG '{full_output_cog_path}': {e}")
            
    ds.close() # Garante que o dataset NetCDF é fechado após o loop
    print("\nProcessamento de sequência de COGs concluído.")

# --- Exemplo de Uso ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Converte uma variável de um arquivo NetCDF para uma sequência temporal de COGs.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required positional arguments
    parser.add_argument(
        "exp_prefix",
        type=str,
        help="Prefixo da experiência (e.g., 'Eta10_C00')."
    )
    parser.add_argument(
        "exp_init_condition",
        type=str,
        help="Condição inicial da experiência (e.g., '2015020200' para o diretório de data/hora)."
    )
    parser.add_argument(
        "variable",
        type=str,
        help="Nome da variável a ser convertida (e.g., 'TP2M', 'TEMP')."
    )
    parser.add_argument(
        "input_netcdf_base_dir",
        type=str,
        help="Diretório base onde o arquivo NetCDF de entrada está localizado."
    )
    parser.add_argument(
        "output_cogs_base_dir",
        type=str,
        help="Diretório base onde os COGs de saída serão salvos."
    )

    # Optional arguments
    parser.add_argument(
        "--level", # Changed from --level_index to just --level for clarity with string values
        type=str, # Keep as string to match '000750'
        default=None,
        help="Valor do nível a ser extraído do NetCDF (e.g., '000750'). Opcional para variáveis 2D."
    )
    parser.add_argument(
        "--start_date",
        type=str,
        default="2015020200",
        help="Data e hora de início para o processamento (formato YYYYMMDDHH)."
    )
    parser.add_argument(
        "--end_date",
        type=str,
        default="2015020300",
        help="Data e hora de fim para o processamento (formato YYYYMMDDHH)."
    )
    parser.add_argument(
        "--time_step_hours",
        type=int,
        default=1,
        help="Intervalo de tempo em horas entre os COGs gerados."
    )
    parser.add_argument(
        "--compression",
        type=str,
        default="DEFLATE",
        choices=["DEFLATE", "LZW", "JPEG", "ZSTD", "LZMA"],
        help="Método de compressão para os COGs de saída."
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        default=512,
        help="Tamanho dos blocos (tiles) para os COGs de saída."
    )
    parser.add_argument(
        "--no_flip_lat",
        action="store_true",
        help="Não aplicar flip vertical na dimensão da latitude."
    )
    parser.add_argument(
        "--create_sample",
        action="store_true",
        help="Cria um arquivo NetCDF de exemplo para teste."
    )

    args = parser.parse_args()

    # Construct the full path for the input NetCDF file
    # This logic depends on your exact NetCDF file naming convention.
    # Adjust as needed.
    print(args.level)
    if args.level is None:
        # For 2D variables (TSFC)
        netcdf_input_filename = f"{args.exp_prefix}_{args.exp_init_condition}_{args.variable}.nc"
        cog_var_geotif_dir = f"{args.exp_prefix}/{args.exp_init_condition}/{args.variable}"
        cog_arq_prefix = f"{args.exp_prefix}_{args.variable}"
    else:
        # For 3D variables (TEMP)
        netcdf_input_filename = f"{args.exp_prefix}_{args.exp_init_condition}_{args.variable}_{args.level}.nc"
        cog_var_geotif_dir = f"{args.exp_prefix}/{args.exp_init_condition}/{args.variable}_{args.level}"
        cog_arq_prefix = f"{args.exp_prefix}_{args.variable}_{args.level}"
    
    netcdf_input_file = os.path.join(args.input_netcdf_base_dir, netcdf_input_filename)
    
    output_cogs_base_dir = os.path.join(args.output_cogs_base_dir, cog_var_geotif_dir)

    # Call the main conversion function with parsed arguments
    convert_netcdf_to_cog_sequence(
        netcdf_path=netcdf_input_file,
        variable_name=args.variable,
        arq_prefix=cog_arq_prefix,
        level_value=args.level, # Pass the optional level value
        output_base_dir=output_cogs_base_dir,
        compression=args.compression,
        tile_size=args.tile_size,
        flip_lat=not args.no_flip_lat
    )

    print("\nScript finished. Check the output directory for generated GeoTIFFs.")    
