import argparse
import datetime
import os
import sys
import cdsapi
import time
import multiprocessing # NEW: Import for parallel processing

def parse_yyyymmddhh_to_datetime(date_string):
    """
    Convert a date string in 'YYYYMMDDHH' format to a datetime.datetime object.

    Args:
        date_string (str): Timestamp in 'YYYYMMDDHH' form, e.g. '2020010112'.

    Returns:
        datetime.datetime: The parsed (naive) datetime.

    Raises:
        ValueError: If the input is not a 10-character string, or does not
            parse as a valid YYYYMMDDHH timestamp.
    """
    if not isinstance(date_string, str) or len(date_string) != 10:
        raise ValueError("Input date_string must be a string of 10 characters in YYYYMMDDHH format.")
    try:
        return datetime.datetime.strptime(date_string, '%Y%m%d%H')
    except ValueError as e:
        # Chain explicitly so the original strptime error is preserved as the cause.
        raise ValueError(f"Could not parse date string '{date_string}'. Ensure it's in YYYYMMDDHH format. Error: {e}") from e

def get_era5_data_for_year(
    year: int,
    variables: list[str],
    output_base_dir: str,
    frequency_hours: int = 6,
    dataset: str = "reanalysis-era5-single-levels",
    product_type: str = "reanalysis",
    area: list[float] | None = None  # [North, West, South, East]
):
    """
    Download ERA5 data from the Copernicus Climate Data Store (CDS) for one
    year and a list of variables, saving one NetCDF file per variable.
    Designed to be run in parallel (one process per year).

    Args:
        year (int): The year for which to download data.
        variables (list[str]): ERA5 variable names to download.
        output_base_dir (str): Base directory for the downloaded data.
            A subdirectory is created per variable.
        frequency_hours (int): Spacing of requested time slots in hours
            (e.g. 6 for 6-hourly data). Must be positive.
        dataset (str): The CDS dataset ID.
        product_type (str): The CDS product type.
        area (list[float] | None): Bounding box [North, West, South, East].
            Defaults to [50, -150, -75, 50] when omitted.

    Raises:
        ValueError: If frequency_hours is not positive.
    """
    # Guard against an infinite time-stepping loop below.
    if frequency_hours <= 0:
        raise ValueError("frequency_hours must be a positive integer.")
    # Avoid a shared mutable default argument; materialize the documented default here.
    if area is None:
        area = [50, -150, -75, 50]

    # Define start and end datetimes for the entire year.
    initial_datetime = datetime.datetime(year, 1, 1, 0)
    # BUGFIX: the end bound was hard-coded to Dec 31 18:00, which silently
    # dropped the 19:00-23:00 slots whenever frequency_hours < 6. Using 23:00
    # as the bound yields the correct last slot for any positive frequency.
    end_datetime = datetime.datetime(year, 12, 31, 23)

    print(f"--- Starting ERA5 Data Download for Year: {year} ---")
    print(f"  Time Range: {initial_datetime.strftime('%Y-%m-%d %H:%M:%S')} to {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"  Variables: {', '.join(variables)}")
    print(f"  Output Base Directory: {output_base_dir}")
    print(f"  Frequency: {frequency_hours} hours")
    print(f"  Area: {area}")

    # Initialize the CDS API client (reads credentials from ~/.cdsapirc).
    c = cdsapi.Client()

    # Collect the unique years/months/days/times covered by the range.
    # For a full year these sets simply contain every valid option, but
    # iterating keeps this robust for partial-year use later.
    years_in_range = set()
    months_in_range = set()
    days_in_range = set()
    times_in_range = set()

    current_dt = initial_datetime
    while current_dt <= end_datetime:
        years_in_range.add(current_dt.strftime('%Y'))
        months_in_range.add(current_dt.strftime('%m'))
        days_in_range.add(current_dt.strftime('%d'))
        times_in_range.add(current_dt.strftime('%H:00'))
        current_dt += datetime.timedelta(hours=frequency_hours)

    # Sort for a deterministic request structure ('HH:00' sorts lexicographically).
    sorted_years = sorted(years_in_range)
    sorted_months = sorted(months_in_range)
    sorted_days = sorted(days_in_range)
    sorted_times = sorted(times_in_range)

    print(f"\n  Calculated request parameters for year {year}:")
    print(f"    Years: {sorted_years}")
    print(f"    Months: {sorted_months}")
    print(f"    Days: {sorted_days}")
    print(f"    Times: {sorted_times}")

    for variable in variables:
        # Create the per-variable output directory.
        variable_output_dir = os.path.join(output_base_dir, variable)
        os.makedirs(variable_output_dir, exist_ok=True)
        print(f"\n  Processing variable: '{variable}' for year {year}")
        print(f"    Output directory: {variable_output_dir}")

        # Target filename encodes both variable and year.
        target_filename = f"era5_data_{variable}_{year}.nc"
        target_filepath = os.path.join(variable_output_dir, target_filename)

        request_params = {
            "product_type": product_type,
            "variable": [variable],
            "year": sorted_years,
            "month": sorted_months,
            "day": sorted_days,
            "time": sorted_times,
            "data_format": "netcdf",
            "area": area
        }

        print(f"    Downloading to: {target_filepath}")
        try:
            c.retrieve(
                dataset,
                request_params,
                target_filepath
            )
            print(f"    Successfully downloaded '{variable}' data for year {year} to {target_filepath}")
        except Exception as e:
            # Best-effort: report and continue with the remaining variables
            # rather than aborting the whole year's download.
            print(f"    Error downloading '{variable}' data for year {year}: {e}")
            print(f"    Request parameters were: {request_params}")
            print("    Please ensure your CDS API key is correctly configured in ~/.cdsapirc")
            print("    and that you have accepted the terms of use for the dataset on the CDS website.")

        time.sleep(1)  # Small delay between requests to be polite to the API

    print(f"--- ERA5 Data Download for Year {year} Complete ---")


def main():
    """
    Parse command-line arguments and download ERA5 data for each requested
    year, fanning the per-year downloads out across a multiprocessing pool.

    Exits with status 1 on invalid --years or --variables input.
    """
    parser = argparse.ArgumentParser(
        description="Download and process ERA5 data from CDS for multiple years in parallel.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--years", type=str, required=True,
        help="Comma-separated list of years to download data for (e.g., '2015,2016,2017')."
    )
    parser.add_argument(
        "--variables", type=str, default="sea_surface_temperature",
        help="Comma-separated list of ERA5 variable names (e.g., '2m_temperature,total_precipitation')."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./era5_downloads",
        help="Base directory to save the downloaded NetCDF files."
    )
    parser.add_argument(
        "--frequency_hours", type=int, default=6,
        help="Frequency of data points in hours (e.g., 6 for 6-hourly data)."
    )
    parser.add_argument(
        "--north", type=float, default=50.0, help="Northernmost latitude for data area."
    )
    parser.add_argument(
        "--west", type=float, default=-150.0, help="Westernmost longitude for data area."
    )
    parser.add_argument(
        "--south", type=float, default=-75.0, help="Southernmost latitude for data area."
    )
    parser.add_argument(
        "--east", type=float, default=50.0, help="Easternmost longitude for data area."
    )
    parser.add_argument(
        "--num_processes", type=int, default=multiprocessing.cpu_count(),
        help="Number of parallel processes to use for downloading. Defaults to CPU count."
    )

    args = parser.parse_args()

    # Parse and validate the year list.
    try:
        years_list = [int(y.strip()) for y in args.years.split(',') if y.strip()]
        if not years_list:
            raise ValueError("No valid years provided.")
    except ValueError as e:
        print(f"Error parsing years: {e}")
        sys.exit(1)

    # Parse and validate the variable list.
    variables_list = [v.strip() for v in args.variables.split(',') if v.strip()]
    if not variables_list:
        print("Error: No variables specified for download.")
        sys.exit(1)

    # Geographic bounding box in CDS order: [North, West, South, East].
    area_bbox = [args.north, args.west, args.south, args.east]

    # IMPROVEMENT: never spawn more workers than there are years to download —
    # the pool parallelizes per year, so extra processes would sit idle.
    num_processes = min(args.num_processes, len(years_list))

    print(f"Starting parallel download for years: {years_list}")
    print(f"Using {num_processes} processes.")

    # One argument tuple per year; everything else is shared configuration.
    pool_args = [
        (
            year,
            variables_list,
            args.output_dir,
            args.frequency_hours,
            "reanalysis-era5-single-levels",  # dataset
            "reanalysis",  # product_type
            area_bbox
        )
        for year in years_list
    ]

    # Run the per-year downloads in parallel.
    with multiprocessing.Pool(processes=num_processes) as pool:
        pool.starmap(get_era5_data_for_year, pool_args)

    print("\nAll ERA5 data downloads completed across all specified years.")

if __name__ == "__main__":
    # Guard the entry point so child processes spawned by multiprocessing
    # (required on Windows, where processes re-import this module) do not
    # re-run main(); freeze_support() additionally supports frozen executables.
    multiprocessing.freeze_support()
    main()

