import xarray as xr
import numpy as np
import rasterio
from rasterio.transform import from_bounds
from rasterio.crs import CRS
import os
import pandas as pd # For creating dates in the example

# Accepted names for the spatial dimensions, in order of preference.
_LAT_NAMES = ("lat", "latitude")
_LON_NAMES = ("lon", "longitude")


def _find_dim(dims, candidates):
    """Return the first name from *candidates* that appears in *dims*, or None."""
    return next((name for name in candidates if name in dims), None)


def _time_label(data_slice):
    """Return a readable string for the scalar ``time`` coordinate of *data_slice*."""
    # ``.item()`` on a datetime64[ns] value returns an integer nanosecond
    # count; str() of the 0-d array yields an ISO timestamp instead (and the
    # element's own str() for object/cftime values).
    return str(data_slice["time"].values)


def _select_spatial_slice(ds, variable_name, time_index):
    """Select one 2-D (lat, lon) slice of *variable_name* from *ds*.

    Returns ``(data_slice, lat_dim, lon_dim)`` with the slice transposed to
    (lat, lon), or None after printing an error message.
    """
    if variable_name not in ds.data_vars:
        print(f"Error: Variable '{variable_name}' not found in the dataset.")
        print(f"Available variables: {list(ds.data_vars.keys())}")
        return None

    data_array = ds[variable_name]

    # Select the specific time step
    if 'time' in data_array.dims:
        n_times = data_array.sizes['time']
        # isel accepts negative indices, so only reject values outside [-n, n).
        if not -n_times <= time_index < n_times:
            print(f"Error: Time index {time_index} is out of bounds (0 to {n_times - 1}).")
            return None
        data_slice = data_array.isel(time=time_index)
        print(f"Converting data for time: {_time_label(data_slice)}")
    else:
        # If no time dimension, assume it's already 2D spatial
        data_slice = data_array
        print("Dataset does not have a time dimension, converting directly.")

    lat_dim = _find_dim(data_slice.dims, _LAT_NAMES)
    lon_dim = _find_dim(data_slice.dims, _LON_NAMES)
    if lat_dim is None or lon_dim is None:
        print(
            f"Error: Could not find latitude/longitude dimensions in {list(data_slice.dims)}; "
            f"expected one of {_LAT_NAMES} and one of {_LON_NAMES}."
        )
        return None

    extra_dims = [d for d in data_slice.dims if d not in (lat_dim, lon_dim)]
    non_singleton = [d for d in extra_dims if data_slice.sizes[d] != 1]
    if non_singleton:
        print(
            f"Error: Variable has extra non-spatial dimensions {non_singleton}; "
            "a GeoTIFF band must be a 2D (lat, lon) array."
        )
        return None
    # Drop length-1 dimensions (e.g. a single level) so the slice is truly 2-D.
    data_slice = data_slice.isel({d: 0 for d in extra_dims})
    # rasterio expects (height, width), i.e. (lat, lon).
    data_slice = data_slice.transpose(lat_dim, lon_dim)

    if data_slice.sizes[lat_dim] < 2 or data_slice.sizes[lon_dim] < 2:
        print("Error: At least 2 latitude and 2 longitude points are needed to infer the pixel size.")
        return None

    return data_slice, lat_dim, lon_dim


def _axis_origin_and_step(coords, axis_name):
    """Return ``(edge, step)`` for a 1-D array of (assumed regular) cell centres.

    *step* is the signed spacing between consecutive centres and *edge* is the
    outer edge of the first cell, half a step before its centre. Requires at
    least two coordinates.
    """
    centres = np.asarray(coords, dtype=float)
    step = (centres[-1] - centres[0]) / (centres.size - 1)
    if not np.allclose(np.diff(centres), step, rtol=1e-3, atol=0.0):
        print(
            f"Warning: '{axis_name}' coordinates are not evenly spaced; "
            "the GeoTIFF transform assumes a regular grid."
        )
    return float(centres[0] - step / 2.0), float(step)


def _cell_edge_transform(lat_values, lon_values, lat_dim, lon_dim):
    """Build the affine transform mapping (col, row) to (lon, lat).

    Coordinates are treated as cell centres, and the transform follows the
    row/column order of the arrays as given, so it is correct whether the
    latitude axis is descending (north-up) or ascending (south-up).
    """
    from rasterio.transform import Affine

    x_origin, x_step = _axis_origin_and_step(lon_values, lon_dim)
    y_origin, y_step = _axis_origin_and_step(lat_values, lat_dim)
    return Affine(x_step, 0.0, x_origin, 0.0, y_step, y_origin)


def _nodata_value(data_slice):
    """Choose the nodata value that marks missing cells in the written array."""
    if np.issubdtype(data_slice.dtype, np.floating):
        # xarray's default decoding replaces _FillValue cells with NaN and moves
        # _FillValue from .attrs to .encoding, so NaN is what marks missing data.
        return np.nan
    return data_slice.attrs.get("_FillValue", data_slice.encoding.get("_FillValue"))


def _predictor_for(compression, dtype):
    """Return a TIFF predictor suited to *compression* and *dtype*, or None."""
    if str(compression).upper() not in ("DEFLATE", "LZW", "ZSTD"):
        return None
    if np.issubdtype(dtype, np.floating):
        return 3  # floating-point prediction
    if np.issubdtype(dtype, np.integer):
        return 2  # horizontal differencing
    return None


def convert_netcdf_to_geotiff(
    netcdf_path,
    variable_name,
    output_geotiff_path,
    time_index=0,  # Index of the time step to be converted
    compression=None,  # Optional: "DEFLATE", "LZW", "JPEG", etc.
    flip_lat=True,  # If True, orient rows north-up (flip latitude when ascending)
):
    """
    Convert one 2-D time slice of a NetCDF variable to a single-band GeoTIFF.

    The slice is georeferenced in EPSG:4326 from its latitude/longitude
    coordinate values, which are treated as pixel centres (the raster extends
    half a cell beyond the outermost coordinates). Floating-point data is
    written with NaN as nodata; other dtypes use the variable's ``_FillValue``
    if one is present.

    Args:
        netcdf_path (str): Path to the input NetCDF file.
        variable_name (str): Name of the variable to convert (e.g., 'temperature').
        output_geotiff_path (str): Path for the output GeoTIFF file. Missing
            parent directories are created.
        time_index (int): Index of the time step to extract (negative values
            count from the end). Ignored when the variable has no ``time``
            dimension.
        compression (str, optional): Compression method for the GeoTIFF, e.g.
            "DEFLATE", "LZW", "JPEG". Defaults to None (no compression). For
            DEFLATE/LZW/ZSTD a predictor matching the data type is added.
        flip_lat (bool): If True (default), reverse the latitude axis when it is
            ascending so the first row is the northernmost one (north-up
            raster). If False, rows are written in file order; the transform
            still matches that order.

    Returns:
        None. Problems opening, validating or writing are reported with
        ``print`` and the function returns early. The NetCDF dataset is
        closed on every exit path.
    """
    print(f"\nStarting conversion of '{netcdf_path}' to GeoTIFF...")
    print(f"Variable: '{variable_name}', Time step (index): {time_index}")

    try:
        # Open the NetCDF dataset
        ds = xr.open_dataset(netcdf_path)
    except FileNotFoundError:
        print(f"Error: NetCDF file not found at '{netcdf_path}'")
        return
    except Exception as e:
        print(f"Error opening NetCDF file: {e}")
        return

    # The outer try/finally guarantees the dataset is closed even on the early
    # validation returns below.
    try:
        selection = _select_spatial_slice(ds, variable_name, time_index)
        if selection is None:
            return
        data_slice, lat_dim, lon_dim = selection

        if flip_lat:
            lat_values = data_slice[lat_dim].values
            # Only flip ascending (south-first) latitudes; data that is already
            # north-up must not be turned upside down.
            if lat_values[0] < lat_values[-1]:
                data_slice = data_slice.isel({lat_dim: slice(None, None, -1)})
                print("Flip vertical aplicado na dimensão 'lat'.")

        height, width = data_slice.shape
        transform = _cell_edge_transform(
            data_slice[lat_dim].values, data_slice[lon_dim].values, lat_dim, lon_dim
        )

        # Prepare the profile for the output GeoTIFF
        profile = {
            "driver": "GTiff",
            "height": height,
            "width": width,
            "count": 1,  # single band holding the selected variable
            "dtype": data_slice.dtype,
            "crs": CRS.from_epsg(4326),  # assumes the lat/lon coordinates are WGS84
            "transform": transform,
            "nodata": _nodata_value(data_slice),
        }

        if compression:
            profile["compress"] = compression
            predictor = _predictor_for(compression, data_slice.dtype)
            if predictor is not None:
                profile["predictor"] = predictor

        # Create the output directory if it doesn't exist
        output_dir = os.path.dirname(output_geotiff_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        print(f"Writing GeoTIFF to: '{output_geotiff_path}' with profile: {profile}")

        try:
            with rasterio.open(output_geotiff_path, "w", **profile) as dst:
                # A reversed (flipped) selection is a strided view; hand rasterio
                # a contiguous buffer.
                dst.write(np.ascontiguousarray(data_slice.values), 1)
                # Add xarray attributes as tags in the GeoTIFF
                dst.update_tags(**data_slice.attrs)
                # Add the time value as a tag for reference
                if 'time' in data_slice.coords:
                    dst.update_tags(NETCDF_TIME=_time_label(data_slice))
            print(f"Conversion to GeoTIFF completed successfully: '{output_geotiff_path}'")
        except Exception as e:
            print(f"Error writing the GeoTIFF file: {e}")
    finally:
        ds.close()

# --- How to use the function ---
if __name__ == "__main__":
    netcdf_input_file = "Eta20_2025061300_prec.nc"
    output_geotiff_dir = "output_geotiffs"
    variable_to_convert = "prec"

    # Each job: (output file name, time-step index, GeoTIFF compression).
    conversion_jobs = (
        ("precipitation_slice_0.tif", 0, "DEFLATE"),  # first time step, compressed
        ("precipitation_slice_3.tif", 3, None),  # time step index 3, uncompressed
    )

    for output_name, step_index, compression_method in conversion_jobs:
        convert_netcdf_to_geotiff(
            netcdf_input_file,
            variable_to_convert,
            os.path.join(output_geotiff_dir, output_name),
            time_index=step_index,
            compression=compression_method,
        )

    print("\nScript finished.")
    print(f"Check the '{output_geotiff_dir}' directory for generated GeoTIFFs.")

    # To clean up, delete only the output directory. The input NetCDF is not
    # created by this script, so it must not be removed here.
