import argparse
import datetime
import os
import sys
import cdsapi
import time

def parse_yyyymmddhh_to_datetime(date_string):
    """
    Convert a date string in 'YYYYMMDDHH' format to a naive datetime.datetime.

    Args:
        date_string (str): Exactly 10 characters, e.g. '2015020200'.

    Returns:
        datetime.datetime: The parsed date and hour (minutes and seconds are zero).

    Raises:
        ValueError: If date_string is not a 10-character string, or does not
            describe a valid calendar date and hour. The strptime error is
            chained as the cause.
    """
    # The explicit length check matters: strptime treats the leading zero of
    # %m, %d and %H as optional, so shorter strings could otherwise parse.
    if not isinstance(date_string, str) or len(date_string) != 10:
        raise ValueError("Input date_string must be a string of 10 characters in YYYYMMDDHH format.")
    try:
        return datetime.datetime.strptime(date_string, '%Y%m%d%H')
    except ValueError as e:
        # Chain the original error so tracebacks keep the real cause.
        raise ValueError(
            f"Could not parse date string '{date_string}'. Ensure it's in YYYYMMDDHH format. Error: {e}"
        ) from e

def get_era5_data(
    initial_datetime: datetime.datetime,
    end_datetime: datetime.datetime, # CHANGED: num_hours to end_datetime
    variables: list[str],
    output_base_dir: str,
    frequency_hours: int = 6,
    dataset: str = "reanalysis-era5-single-levels",
    product_type: str = "reanalysis",
    area: list[float] = [50, -150, -75, 50] # [North, West, South, East]
):
    """
    Downloads ERA5 data from the Copernicus Climate Data Store (CDS) for a specified
    time range and list of variables, saving them as NetCDF files.

    Args:
        initial_datetime (datetime.datetime): The starting date and hour for data download.
        end_datetime (datetime.datetime): The ending date and hour for data download.
        variables (list[str]): A list of ERA5 variable names to download
                                (e.g., ["sea_surface_temperature", "2m_temperature"]).
        output_base_dir (str): The base directory where downloaded data will be saved.
                                Subdirectories will be created for each variable.
        frequency_hours (int): The frequency of data points (e.g., 6 for 6-hourly data).
        dataset (str): The CDS dataset ID (default: "reanalysis-era5-single-levels").
        product_type (str): The product type (default: "reanalysis").
        area (list[float]): Bounding box for the geographic area [North, West, South, East].
    """
    print(f"--- Starting ERA5 Data Download ---")
    print(f"Initial Date: {initial_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"End Date: {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}") # CHANGED: Print end date
    print(f"Variables: {', '.join(variables)}")
    print(f"Output Base Directory: {output_base_dir}")
    print(f"Frequency: {frequency_hours} hours")
    print(f"Area: {area}")

    # Initialize CDS API client
    c = cdsapi.Client()

    # Collect unique years, months, days, and times within the range
    years = set()
    months = set()
    days = set()
    times = set()

    current_dt = initial_datetime
    while current_dt <= end_datetime: # CHANGED: Loop until end_datetime
        years.add(current_dt.strftime('%Y'))
        months.add(current_dt.strftime('%m'))
        days.add(current_dt.strftime('%d'))
        times.add(current_dt.strftime('%H:00'))
        current_dt += datetime.timedelta(hours=frequency_hours)

    # Sort them for consistent request structure
    sorted_years = sorted(list(years))
    sorted_months = sorted(list(months))
    sorted_days = sorted(list(days))
    sorted_times = sorted(list(times))

    print(f"\nCalculated request parameters:")
    print(f"  Years: {sorted_years}")
    print(f"  Months: {sorted_months}")
    print(f"  Days: {sorted_days}")
    print(f"  Times: {sorted_times}")

    for variable in variables:
        # Create output directory for the current variable
        variable_output_dir = os.path.join(output_base_dir, variable)
        os.makedirs(variable_output_dir, exist_ok=True)
        print(f"\nProcessing variable: '{variable}'")
        print(f"  Output directory: {variable_output_dir}")

        # Construct the target filename
        # Using the initial year for the filename as per the original script's implied logic
        target_filename = f"era5_data_{variable}_{initial_datetime.year}.nc"
        target_filepath = os.path.join(variable_output_dir, target_filename)

        request_params = {
            "product_type": product_type,
            "variable": [variable],
            "year": sorted_years,
            "month": sorted_months,
            "day": sorted_days,
            "time": sorted_times,
            "data_format": "netcdf",
            "area": area
        }

        print(f"  Downloading to: {target_filepath}")
        try:
            c.retrieve(
                dataset,
                request_params,
                target_filepath
            )
            print(f"  Successfully downloaded '{variable}' data to {target_filepath}")
        except Exception as e:
            print(f"  Error downloading '{variable}' data: {e}")
            print(f"  Request parameters were: {request_params}")
            print("  Please ensure your CDS API key is correctly configured in ~/.cdsapirc")
            print("  and that you have accepted the terms of use for the dataset on the CDS website.")

        time.sleep(1) # Small delay between requests to be polite to the API

    print("\n--- ERA5 Data Download Complete ---")


def main():
    """
    Command-line entry point: parse and validate arguments, then download ERA5 data.

    Exits with status 1 in these cases:
      - the dates are invalid or out of order;
      - --frequency_hours is not positive;
      - no variables are given;
      - get_era5_data returns a non-empty list of failed variables.
    """
    parser = argparse.ArgumentParser(
        description="Download and process ERA5 data from CDS.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "yyyymmddhh_start", type=str,
        help="Initial date and hour to start data download (format: YYYYMMDDHH, e.g., 2015020200)."
    )
    parser.add_argument(
        "yyyymmddhh_end", type=str,
        help="End date and hour for data download (format: YYYYMMDDHH, e.g., 2015020300)."
    )
    parser.add_argument(
        "--variables", type=str, default="sea_surface_temperature",
        help="Comma-separated list of ERA5 variable names (e.g., '2m_temperature,total_precipitation')."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./era5_downloads",
        help="Base directory to save the downloaded NetCDF files."
    )
    parser.add_argument(
        "--frequency_hours", type=int, default=6,
        help="Frequency of data points in hours (e.g., 6 for 6-hourly data)."
    )
    parser.add_argument(
        "--north", type=float, default=50.0, help="Northernmost latitude for data area."
    )
    parser.add_argument(
        "--west", type=float, default=-150.0, help="Westernmost longitude for data area."
    )
    parser.add_argument(
        "--south", type=float, default=-75.0, help="Southernmost latitude for data area."
    )
    parser.add_argument(
        "--east", type=float, default=50.0, help="Easternmost longitude for data area."
    )

    args = parser.parse_args()

    # Parse initial and end dates
    try:
        initial_dt = parse_yyyymmddhh_to_datetime(args.yyyymmddhh_start)
        end_dt = parse_yyyymmddhh_to_datetime(args.yyyymmddhh_end)
    except ValueError as e:
        print(f"Error: {e}")
        sys.exit(1)
    if initial_dt > end_dt:
        print("Error: Initial date cannot be after end date.")
        sys.exit(1)

    # Reject non-positive steps up front: stepping by <= 0 hours can never
    # reach the end date, so the download would hang instead of failing.
    if args.frequency_hours <= 0:
        print("Error: --frequency_hours must be a positive integer.")
        sys.exit(1)

    # Convert comma-separated variables string to a list, dropping empty entries
    variables_list = [v.strip() for v in args.variables.split(',') if v.strip()]
    if not variables_list:
        print("Error: No variables specified for download.")
        sys.exit(1)

    # Geographical area in CDS order: [North, West, South, East]
    area_bbox = [args.north, args.west, args.south, args.east]

    failed_variables = get_era5_data(
        initial_datetime=initial_dt,
        end_datetime=end_dt,
        variables=variables_list,
        output_base_dir=args.output_dir,
        frequency_hours=args.frequency_hours,
        area=area_bbox
    )
    # Surface per-variable failures through the exit status so shell scripts
    # and schedulers can detect them.
    if failed_variables:
        print(f"Error: failed to download: {', '.join(failed_variables)}")
        sys.exit(1)


if __name__ == "__main__":
    main()

