import numpy as np
from netCDF4 import Dataset, date2num, num2date
from datetime import datetime, timedelta
import os
import sys
import re

def read_multi_field_binary_all_at_once(input_filename, field_definitions):
    """
    Reads a binary file containing multiple concatenated data fields by reading the
    entire file at once and then slicing it.

    Args:
        input_filename (str): The path to the binary file.
        field_definitions (list of dict): A list where each dictionary defines a field:
            Example:
            [
                {'name': 'temperature', 'shape': (3, 4, 5, 10), 'dtype': np.float32},
                {'name': 'pressure', 'shape': (3, 5, 10), 'dtype': np.float32}
            ]
            The order of dictionaries in this list MUST match the order of fields in the binary file.

    Returns:
        dict: A dictionary where keys are field names and values are NumPy arrays.
              Returns None if the file cannot be read or processed.
    """
    print(f"Reading data from binary file (all at once): {input_filename}...")
    if not os.path.exists(input_filename):
        print(f"Error: Binary file '{input_filename}' not found.")
        return None

    if not field_definitions:
        print("Error: No field definitions provided.")
        return None
    
    # Determine the item size used when reading the entire file.
    # np.dtype(...) accepts both type objects (np.float32) and dtype instances.
    common_dtype_size = np.dtype(field_definitions[0]['dtype']).itemsize
    total_expected_elements = 0
    for field_def in field_definitions:
        if not all(isinstance(dim, int) for dim in field_def['shape']):
            raise TypeError(f"Shape dimensions for field '{field_def['name']}' must be integers. Got: {field_def['shape']}")
        total_expected_elements += int(np.prod(field_def['shape']))
        if np.dtype(field_def['dtype']).itemsize != common_dtype_size:
            print(f"Warning: Field '{field_def['name']}' has a different itemsize "
                  f"({np.dtype(field_def['dtype']).itemsize}) than the first field "
                  f"({common_dtype_size}). This function assumes a uniform itemsize "
                  "for `np.fromfile`; adjust the logic if your data mixes item sizes.")

    expected_file_size = total_expected_elements * common_dtype_size
    actual_file_size = os.path.getsize(input_filename)

    if actual_file_size != expected_file_size:
        print(f"Warning: Binary file size mismatch for '{input_filename}'. Expected {expected_file_size} bytes, got {actual_file_size} bytes.")
        print("This might lead to incorrect data parsing.")

    try:
        raw_data = np.fromfile(input_filename, dtype=field_definitions[0]['dtype'])
        
        separated_data = {}
        current_element_offset = 0

        for field_def in field_definitions:
            field_name = field_def['name']
            field_shape = field_def['shape']
            field_dtype = field_def['dtype']
            
            num_elements_in_field = int(np.prod(field_shape))

            field_data_flat = raw_data[current_element_offset : current_element_offset + num_elements_in_field]
            # Reinterpret the raw slice as the field's dtype (a view, not a value
            # conversion) before reshaping to the declared field shape
            field_data = field_data_flat.view(field_dtype).reshape(field_shape)
            
            separated_data[field_name] = field_data
            current_element_offset += num_elements_in_field
            
            print(f"  Read field '{field_name}' (all at once method): Shape {field_data.shape}, Dtype {field_data.dtype}")

        return separated_data

    except Exception as e:
        print(f"An error occurred while reading or processing the binary file (all at once): {e}")
        return None

def read_multi_field_binary_one_by_one(input_filename, field_definitions):
    """
    Reads a binary file containing multiple concatenated data fields by reading
    each field sequentially from the file.

    Args:
        input_filename (str): The path to the binary file.
        field_definitions (list of dict): A list where each dictionary defines a field:
            Example:
            [
                {'name': 'temperature', 'shape': (3, 4, 5, 10), 'dtype': np.float32},
                {'name': 'pressure', 'shape': (3, 5, 10), 'dtype': np.float32}
            ]
            The order of dictionaries in this list MUST match the order of fields in the binary file.

    Returns:
        dict: A dictionary where keys are field names and values are NumPy arrays.
              Returns None if the file cannot be read or processed.
    """
    print(f"Reading data from binary file (one by one): {input_filename}...")
    if not os.path.exists(input_filename):
        print(f"Error: Binary file '{input_filename}' not found.")
        return None

    if not field_definitions:
        print("Error: No field definitions provided.")
        return None

    separated_data = {}
    try:
        with open(input_filename, 'rb') as f:
            for field_def in field_definitions:
                field_name = field_def['name']
                field_shape = field_def['shape']
                field_dtype = field_def['dtype']

                if not all(isinstance(dim, int) for dim in field_shape):
                    raise TypeError(f"Shape dimensions for field '{field_name}' must be integers. Got: {field_shape}")

                num_elements_in_field = int(np.prod(field_shape))
                bytes_to_read = num_elements_in_field * np.dtype(field_dtype).itemsize

                # Read the exact number of bytes for the current field
                binary_chunk = f.read(bytes_to_read)

                if len(binary_chunk) != bytes_to_read:
                    print(f"Warning: Unexpected end of file while reading field '{field_name}'. "
                          f"Expected {bytes_to_read} bytes, read {len(binary_chunk)} bytes.")
                    # An incomplete read makes the reshape below raise a ValueError,
                    # which is caught by the outer except and reported. Fill with
                    # NaNs or re-raise here if partial fields need other handling.

                # Convert the bytes to a NumPy array and reshape
                field_data = np.frombuffer(binary_chunk, dtype=field_dtype).reshape(field_shape)
                
                separated_data[field_name] = field_data
                print(f"  Read field '{field_name}' (one by one method): Shape {field_data.shape}, Dtype {field_data.dtype}")
        return separated_data

    except Exception as e:
        print(f"An error occurred while reading or processing the binary file (one by one): {e}")
        return None

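# A minimal sketch, assuming fields are written back-to-back in native byte
# order, of the create_dummy_binary_file helper referenced (commented out) in
# create_multi_variable_netcdf below. The constant per-field fill pattern is an
# illustrative assumption, not a property of real Eta output.
def create_dummy_binary_file(output_filename, field_definitions, dtype=np.float32):
    """Write one placeholder block per field so the readers above can be tested."""
    with open(output_filename, 'wb') as f:
        for i, field_def in enumerate(field_definitions):
            # Field i is filled with the constant value i
            data = np.full(field_def['shape'], float(i), dtype=dtype)
            f.write(data.tobytes())
    print(f"Wrote dummy binary file: {output_filename}")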

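# A minimal usage sketch for the two readers above, assuming two small float32
# fields laid out back-to-back (the names, shapes, and filename here are
# illustrative, not real Eta fields):
#
#     demo_defs = [
#         {'name': 'temperature', 'shape': (1, 2, 4, 5), 'dtype': np.float32},
#         {'name': 'pressure', 'shape': (1, 4, 5), 'dtype': np.float32},
#     ]
#     create_dummy_binary_file('demo_fields.bin', demo_defs, np.float32)
#     data_a = read_multi_field_binary_all_at_once('demo_fields.bin', demo_defs)
#     data_b = read_multi_field_binary_one_by_one('demo_fields.bin', demo_defs)
#     assert all(np.array_equal(data_a[k], data_b[k]) for k in data_a)
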
def create_multi_variable_netcdf(output_filename="multi_variable_data.nc",
                                 input_binary_filename="input_multi_field.bin",
                                 Eta_variables="Eta_variables_data",
                                 Eta_levels="Eta_levels_data",
                                 fct_int=1):
    """
    Creates a NetCDF file with multiple variables read from a multi-field
    binary file, together with their dimensions and attributes.

    Args:
        output_filename (str): The name of the NetCDF file to create.
        input_binary_filename (str): The name of the binary file to read from.
        Eta_variables (str): Name of an importable module providing `variable_data`.
        Eta_levels (str): Name of an importable module providing `num_levels` and `levels`.
        fct_int (int): Forecast interval in hours between time steps.
    """
    from EtaLibrary import read_specific_grid_info
    import importlib

    try:
        # Dynamically import the modules holding the variable table and levels
        eta_variables_module = importlib.import_module(Eta_variables)
        variable_data = eta_variables_module.variable_data
        eta_levels_module = importlib.import_module(Eta_levels)
        num_levels = eta_levels_module.num_levels
        levels = eta_levels_module.levels

        print(f"Successfully imported variable_data from {Eta_variables}.py")
        print(variable_data)

    except ImportError as e:
        print(f"Error: Could not import '{Eta_variables}' or '{Eta_levels}'. "
              f"Make sure the corresponding .py files are in your Python path. ({e})")
        return
    except AttributeError as e:
        print(f"Error: Expected attribute missing from '{Eta_variables}' or '{Eta_levels}': {e}")
        return
    except Exception as e:
        print(f"An unexpected error occurred during module import: {e}")
        return

    # --- 1. Define Dimensions and Coordinate Data ---
    # These dimensions must match how your binary data was originally structured
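    # Expected filename pattern (hypothetical example):
    #   <model>_<run>_<YYYYMMDDHH>+<YYYYMMDDHH>_<type>.<ext>
    #   e.g. Eta40_C00_2025070400+2025070412_FF.bin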
    match = re.match(r'(.+?)_(.+?)_(\d{10})\+(\d{10})_(.+?)\.(.+)', input_binary_filename)
    if match:
        model_part = match.group(1)
        run_code = match.group(2)
        start_time = match.group(3)
        end_time = match.group(4)
        data_type = match.group(5)
        extension = match.group(6)
    else:
        print(f"Error: Input filename '{input_binary_filename}' does not match the "
              "expected pattern; cannot extract the validity time.")
        return
    print(end_time)
    num_time_steps = 1
    start_date = datetime.strptime(end_time, '%Y%m%d%H')
    times = [start_date + timedelta(hours=i * int(fct_int)) for i in range(num_time_steps)]
    # Grid geometry comes from the external grid description file;
    # read_specific_grid_info is expected to return a dict with the keys
    # 'nlats', 'south', 'north', 'nlons', 'west', 'east' (as used below)
    grid_info_file = "grid_info.txt"
    parsed_grid_data = read_specific_grid_info(grid_info_file)
    num_lats = parsed_grid_data['nlats']
    lat_South = parsed_grid_data['south']
    lat_North = parsed_grid_data['north']
    lats = np.linspace(lat_South, lat_North, num_lats)

    num_lons = parsed_grid_data['nlons']
    lon_West = parsed_grid_data['west']
    lon_East = parsed_grid_data['east']
    lons = np.linspace(lon_West, lon_East, num_lons)

    print(levels) 
    print(num_levels) 

    # Define the data type for reading the binary file
    binary_data_type = np.float32 # Assuming your binary data is float32

    # --- Define the properties of each variable to be written to NetCDF ---
    # This list will be looped through to create variables dynamically.
    # The 'shape_dims' should correspond to the NetCDF dimensions defined below.
    # The 'binary_shape' should match how the data is structured in the binary file.
    # Data for each variable (excluding common 'shape_dims', 'binary_shape', 'dtype')

    variable_definitions_for_netcdf = []

    for var_info in variable_data:
        variable_definition = {
                'name': var_info['name'],
                'shape_dims': ('time','level', 'lat', 'lon'),
                'binary_shape': (num_time_steps, num_levels, num_lats, num_lons), # Shape as read from binary
                'dtype': binary_data_type,
                'long_name': var_info['long_name'],
                'units': var_info['units'],
                'standard_name': var_info['standard_name'],
                'fill_value': var_info['fill_value'],
                'chunksizes': (1, 1, num_lats, num_lons), # Chunk by time and level, full spatial slice
                'zlib':True, 
                'complevel':1
        }
        variable_definitions_for_netcdf.append(variable_definition)

    # You can now print to verify
    for var_def in variable_definitions_for_netcdf:
        print(var_def)

    # --- Create a dummy binary file for demonstration (optional) ---
    # In normal use the multi-field binary file already exists; for testing, one
    # can be generated with the create_dummy_binary_file sketch defined above:
    #create_dummy_binary_file(input_binary_filename,
    #                         [ {'name': v['name'], 'shape': v['binary_shape']} for v in variable_definitions_for_netcdf ],
    #                         binary_data_type)

    # --- 2. Read Data from the Binary File with Multiple Fields (using the one-by-one method) ---
    # You can switch between read_multi_field_binary_all_at_once or read_multi_field_binary_one_by_one
    # Ensure the field_definitions passed here match the structure of your binary file.
    # We use the 'binary_shape' from our variable definitions for reading.
    binary_field_definitions = [
        {'name': v['name'], 'shape': v['binary_shape'], 'dtype': v['dtype']}
        for v in variable_definitions_for_netcdf
    ]
    read_data = read_multi_field_binary_one_by_one(input_binary_filename, binary_field_definitions)
    
    if read_data is None:
        print("Failed to read data from binary file. Exiting.")
        return # Exit if reading failed

    # --- 3. Create the NetCDF file ---
    with Dataset(output_filename, 'w', format='NETCDF4') as nc_file:
        # --- 4. Create Dimensions in the NetCDF file ---
        nc_file.createDimension('time', None)
        nc_file.createDimension('lat', num_lats)
        nc_file.createDimension('lon', num_lons)
        nc_file.createDimension('level', num_levels)

        # --- 5. Create Coordinate Variables ---
        time_var = nc_file.createVariable('time', 'f8', ('time',))
        lat_var = nc_file.createVariable('lat', 'f4', ('lat',))
        lon_var = nc_file.createVariable('lon', 'f4', ('lon',))
        level_var = nc_file.createVariable('level', 'i4', ('level',))

        # Add attributes to coordinate variables
        time_var.units = 'hours since ' + start_date.strftime('%Y-%m-%d %H:%M:%S')
        time_var.calendar = 'gregorian'
        time_var.long_name = 'Time'

        lat_var.units = 'degrees_north'
        lat_var.long_name = 'Latitude'
        lat_var.standard_name = 'latitude'
        lat_var.axis = 'Y'

        lon_var.units = 'degrees_east'
        lon_var.long_name = 'Longitude'
        lon_var.standard_name = 'longitude'
        lon_var.axis = 'X'

        level_var.units = 'hPa'
        level_var.long_name = 'Pressure Level'
        level_var.standard_name = 'air_pressure'
        level_var.axis = 'Z'
        level_var.positive = 'down'

        # --- 6. Create Data Variables using a Loop ---
        created_nc_vars = {} # To store references to the created NetCDF variables
        for var_def in variable_definitions_for_netcdf:
            var_name = var_def['name']
            var_shape_dims = var_def['shape_dims']
            var_dtype = var_def['dtype']
            var_fill_value = var_def['fill_value']
            var_chunksizes = var_def.get('chunksizes', None) # Get chunksizes, default to None

            nc_var = nc_file.createVariable(
                var_name, var_dtype, var_shape_dims,
                fill_value=var_fill_value,
                chunksizes=var_chunksizes,              # chunk by time/level, full spatial slice
                zlib=var_def.get('zlib', True),         # enable zlib compression
                complevel=var_def.get('complevel', 1),  # 0-9; 1 is fastest, 9 compresses most
            )
            nc_var.long_name = var_def.get('long_name', var_name)
            nc_var.units = var_def.get('units', '1') # Default to dimensionless
            nc_var.standard_name = var_def.get('standard_name', var_name)
            nc_var.coordinates = ' '.join(var_shape_dims)  # auto-assign coordinates based on dimensions

            created_nc_vars[var_name] = nc_var

        # --- 7. Write Data to Variables ---
        time_var[:] = date2num(times, time_var.units, time_var.calendar)
        lat_var[:] = lats
        lon_var[:] = lons
        level_var[:] = levels

        # Write data from the 'read_data' dictionary to the NetCDF variables using the loop
        for var_name, nc_var_obj in created_nc_vars.items():
            if var_name in read_data:
                nc_var_obj[:] = read_data[var_name]
            else:
                print(f"Warning: Data for variable '{var_name}' not found in read binary data. Skipping.")


        # --- 8. Add Global Attributes ---
        nc_file.title = 'Eta Model - Version 1.4.4'
        nc_file.institution = 'CPTEC/INPE'
        nc_file.source = 'Data read from binary file and converted to NetCDF by Python script'
        nc_file.history = f'Created on {datetime.now().isoformat()}'
        nc_file.Conventions = 'CF-1.8'

    print(f"Successfully created NetCDF file: {output_filename}")
    print(f"File size: {os.path.getsize(output_filename) / (1024*1024):.2f} MB")

    # --- 9. Verify the file (Optional) ---
    print("\n--- Verifying file content ---")
    with Dataset(output_filename, 'r') as nc_file_read:
        print("Dimensions:", nc_file_read.dimensions.keys())
        print("Variables:", nc_file_read.variables.keys())
        print("\nGlobal Attributes:")
        for attr_name in nc_file_read.ncattrs():
            print(f"  {attr_name}: {getattr(nc_file_read, attr_name)}")
        
        # Loop through the created variables to print info
        for var_def in variable_definitions_for_netcdf:
            var_name = var_def['name']
            if var_name in nc_file_read.variables:
                read_var = nc_file_read.variables[var_name]
                print(f"\n{var_name.capitalize()} Variable Info:")
                print(f"  Shape: {read_var.shape}")
                print(f"  Units: {read_var.units}")
                # Print first value, handling different dimensions
                if read_var.ndim == 4:
                    print(f"  First value: {read_var[0, 0, 0, 0]:.2f}")
                elif read_var.ndim == 3:
                    print(f"  First value: {read_var[0, 0, 0]:.2f}")
                elif read_var.ndim == 2:
                    print(f"  First value: {read_var[0, 0]:.2f}")
                elif read_var.ndim == 1:
                    print(f"  First value: {read_var[0]:.2f}")
            else:
                print(f"\nVariable '{var_name}' not found in the created NetCDF file.")

        times_read_num = nc_file_read.variables['time'][:]
        times_read_dt = num2date(times_read_num, nc_file_read.variables['time'].units, nc_file_read.variables['time'].calendar)
        print(f"\nTime values (first and last): {times_read_dt[0]}, {times_read_dt[-1]}")


if __name__ == "__main__":
    print(f"Script name: {sys.argv[0]}")

    if len(sys.argv) < 6:
        print(f"Usage: python {sys.argv[0]} <input_binary> <output_netcdf> "
              "<variables_module> <levels_module> <forecast_interval_hours>")
        sys.exit(1)

    input_file = sys.argv[1]
    print(f"First argument: {input_file}")
    output_file = sys.argv[2]
    print(f"Second argument: {output_file}")
    variable_data_file = sys.argv[3]
    print(f"Third argument: {variable_data_file}")
    levels_file = sys.argv[4]
    print(f"Fourth argument: {levels_file}")
    ftc_time_int = sys.argv[5]
    print(f"Fifth argument: {ftc_time_int}")
    print("All arguments:", sys.argv)
    #output_file = "Eta40_C00_2025070400_FF.nc"
    #input_file  = "Eta40_C00_2025070400_FF.bin"
    create_multi_variable_netcdf(output_filename=output_file,
                                 input_binary_filename=input_file,
                                 Eta_variables=variable_data_file,
                                 Eta_levels=levels_file,
                                 fct_int=ftc_time_int)
