import argparse
import datetime
import multiprocessing
import os
import sys
import time
from typing import List, Optional

import cdsapi

def parse_yyyymmddhh_to_datetime(date_string: str) -> datetime.datetime:
    """
    Convert a date string in 'YYYYMMDDHH' format to a datetime.datetime object.

    Args:
        date_string: Timestamp such as "2024013106" (year, month, day, hour).

    Returns:
        The corresponding naive (timezone-unaware) datetime object.

    Raises:
        ValueError: If the input is not a 10-character string, or if it does
            not parse as a valid YYYYMMDDHH timestamp.
    """
    if not isinstance(date_string, str) or len(date_string) != 10:
        raise ValueError("Input date_string must be a string of 10 characters in YYYYMMDDHH format.")
    try:
        return datetime.datetime.strptime(date_string, '%Y%m%d%H')
    except ValueError as e:
        # Chain explicitly (`from e`) so the original strptime diagnostic is
        # preserved as __cause__ rather than as an implicit context.
        raise ValueError(
            f"Could not parse date string '{date_string}'. Ensure it's in YYYYMMDDHH format. Error: {e}"
        ) from e

def get_era5_data_for_year(
    year: int,
    variables: List[str],
    output_base_dir: str,
    dataset: str = "derived-era5-single-levels-daily-statistics",
    product_type: str = "reanalysis",
    daily_statistic: str = "daily_mean",
    time_zone: str = "utc-12:00",
    area: Optional[List[float]] = None,  # [North, West, South, East]
):
    """
    Download ERA5 derived daily statistics from the Copernicus Climate Data
    Store (CDS) for a single year, saving one NetCDF file per variable.

    Designed to be run in parallel for multiple years (see main()); each
    call creates its own cdsapi.Client, so it is safe per worker process.

    Args:
        year: The year for which to download data.
        variables: ERA5 variable names to download; one CDS request is
            issued per variable.
        output_base_dir: Base directory for downloaded data. A subdirectory
            is created for each variable.
        dataset: The CDS dataset ID.
        product_type: The CDS product type.
        daily_statistic: The daily statistic to retrieve
            (e.g. "daily_mean", "daily_maximum").
        time_zone: Time zone used when computing the daily statistic
            (e.g. "utc-12:00").
        area: Bounding box [North, West, South, East]. Defaults to
            [50, -150, -75, 50] when omitted.
    """
    # None sentinel instead of a mutable default list: a list literal in the
    # signature would be a single object shared across every call.
    if area is None:
        area = [50, -150, -75, 50]

    # The 'time' parameter is not used for daily statistics; we request the
    # full calendar year. These datetimes exist only for the log output.
    initial_datetime = datetime.datetime(year, 1, 1)
    end_datetime = datetime.datetime(year, 12, 31)

    print(f"--- Starting ERA5 Daily Derived Data Download for Year: {year} ---")
    print(f"  Time Range: {initial_datetime.strftime('%Y-%m-%d')} to {end_datetime.strftime('%Y-%m-%d')}")
    print(f"  Variables: {', '.join(variables)}")
    print(f"  Output Base Directory: {output_base_dir}")
    print(f"  Daily Statistic: {daily_statistic}")
    print(f"  Time Zone: {time_zone}")
    print(f"  Area: {area}")

    # Client reads credentials from ~/.cdsapirc.
    c = cdsapi.Client()

    # Request all 12 months and days 01..31; presumably the CDS service
    # ignores invalid day/month combinations (e.g. Feb 30) for this
    # dataset — TODO confirm against the CDS documentation.
    years_in_range = [str(year)]
    months_in_range = [f"{i:02d}" for i in range(1, 13)]
    days_in_range = [f"{i:02d}" for i in range(1, 32)]

    print(f"\n  Calculated request parameters for year {year}:")
    print(f"    Years: {years_in_range}")
    print(f"    Months: {months_in_range}")
    print(f"    Days: {days_in_range}")

    for variable in variables:
        # Create output directory for the current variable.
        variable_output_dir = os.path.join(output_base_dir, variable)
        os.makedirs(variable_output_dir, exist_ok=True)
        print(f"\n  Processing variable: '{variable}' for year {year}")
        print(f"    Output directory: {variable_output_dir}")

        # One file per (variable, year) pair.
        target_filename = f"era5_daily_{variable}_{year}.nc"
        target_filepath = os.path.join(variable_output_dir, target_filename)

        request_params = {
            "product_type": product_type,
            "variable": [variable],
            "year": years_in_range,
            "month": months_in_range,
            "day": days_in_range,
            "daily_statistic": daily_statistic,
            # NOTE(review): "frequency" selects the sub-daily sampling from
            # which the statistic is derived — confirm "6_hourly" is intended
            # for every requested variable/statistic.
            "frequency": "6_hourly",
            "time_zone": time_zone,
            "format": "netcdf",
            # Defensive copy so the caller's list is never aliased into
            # (or mutated through) the request dictionary.
            "area": list(area),
        }

        print(f"    Downloading to: {target_filepath}")
        try:
            c.retrieve(
                dataset,
                request_params,
                target_filepath
            )
            print(f"    Successfully downloaded '{variable}' data for year {year} to {target_filepath}")
        except Exception as e:
            # Best effort: report the failure and continue with the next
            # variable rather than aborting the whole year.
            print(f"    Error downloading '{variable}' data for year {year}: {e}")
            print(f"    Request parameters were: {request_params}")
            print("    Please ensure your CDS API key is correctly configured in ~/.cdsapirc")
            print("    and that you have accepted the terms of use for the dataset on the CDS website.")

        time.sleep(1)  # Small delay between requests to be polite to the API

    print(f"--- ERA5 Daily Derived Data Download for Year {year} Complete ---")


def main():
    """
    Command-line entry point: parse arguments, then fan out one download
    job per requested year across a multiprocessing pool.
    """
    parser = argparse.ArgumentParser(
        description="Download ERA5 derived daily statistics from CDS for multiple years in parallel.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--years", type=str, required=True,
        help="Comma-separated list of years to download data for (e.g., '2015,2016,2017')."
    )
    parser.add_argument(
        "--variables", type=str, default="2m_temperature_daily_mean",
        help="Comma-separated list of ERA5 derived variable names (e.g., '2m_temperature_daily_mean,total_precipitation_daily_sum')."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./era5_daily_downloads",
        help="Base directory to save the downloaded NetCDF files."
    )
    parser.add_argument(
        "--daily_statistic", type=str, default="daily_mean",
        choices=["daily_mean", "daily_maximum", "daily_minimum", "daily_sum", "daily_standard_deviation"],
        help="The daily statistic to retrieve."
    )
    parser.add_argument(
        "--time_zone", type=str, default="utc-12:00",
        help="The time zone for the daily statistic calculation (e.g., 'utc-12:00', 'utc+00:00')."
    )
    parser.add_argument(
        "--north", type=float, default=50.0, help="Northernmost latitude for data area."
    )
    parser.add_argument(
        "--west", type=float, default=-150.0, help="Westernmost longitude for data area."
    )
    parser.add_argument(
        "--south", type=float, default=-75.0, help="Southernmost latitude for data area."
    )
    parser.add_argument(
        "--east", type=float, default=50.0, help="Easternmost longitude for data area."
    )
    parser.add_argument(
        "--num_processes", type=int, default=multiprocessing.cpu_count(),
        help="Number of parallel processes to use for downloading. Defaults to CPU count."
    )

    opts = parser.parse_args()

    # Years: comma-separated integers; bail out on anything unparseable.
    try:
        years = [int(tok.strip()) for tok in opts.years.split(',') if tok.strip()]
        if not years:
            raise ValueError("No valid years provided.")
    except ValueError as err:
        print(f"Error parsing years: {err}")
        sys.exit(1)

    # Variables: comma-separated names; at least one is required.
    variables = [tok.strip() for tok in opts.variables.split(',') if tok.strip()]
    if not variables:
        print("Error: No variables specified for download.")
        sys.exit(1)

    # Geographic bounding box in CDS order: [North, West, South, East].
    bbox = [opts.north, opts.west, opts.south, opts.east]

    print(f"Starting parallel download for years: {years}")
    print(f"Using {opts.num_processes} processes.")

    # One positional-argument tuple per year, matching the parameter order
    # of get_era5_data_for_year.
    job_args = [
        (
            yr,
            variables,
            opts.output_dir,
            "derived-era5-single-levels-daily-statistics",  # dataset
            "reanalysis",                                   # product_type
            opts.daily_statistic,
            opts.time_zone,
            bbox,
        )
        for yr in years
    ]

    # Run the per-year downloads in parallel worker processes.
    with multiprocessing.Pool(processes=opts.num_processes) as pool:
        pool.starmap(get_era5_data_for_year, job_args)

    print("\nAll ERA5 daily derived data downloads completed across all specified years.")

if __name__ == "__main__":
    # freeze_support() is required when a multiprocessing program is frozen
    # into a Windows executable; it is a no-op otherwise.
    multiprocessing.freeze_support()
    main()

