# Source code for unidep._conflicts

"""unidep - Unified Conda and Pip requirements management.

Version conflict detection and resolution.
"""

from __future__ import annotations

import sys
from collections import defaultdict
from typing import TYPE_CHECKING

from packaging import version

from unidep.platform_definitions import Platform, Spec
from unidep.utils import defaultdict_to_dict, warn

if sys.version_info >= (3, 8):
    from typing import get_args
else:  # pragma: no cover
    from typing_extensions import get_args


if TYPE_CHECKING:
    from unidep.platform_definitions import CondaPip

# Comparison operators accepted at the start of a version pinning.
# `_parse_pinning` tries them in this order with `str.startswith`, so the
# two-character operators "<=" and ">=" must come before "<" and ">".
VALID_OPERATORS = ["<=", ">=", "<", ">", "=", "!="]
# Project repository; used to build the README link shown in the
# "Invalid version pinning" error raised by `combine_version_pinnings`.
_REPO_URL = "https://github.com/basnijholt/unidep"


def _prepare_specs_for_conflict_resolution(
    requirements: dict[str, list[Spec]],
) -> dict[str, dict[Platform | None, dict[CondaPip, list[Spec]]]]:
    """Group the specs of every package by platform and by source.

    Each ``Spec`` is filed under every platform returned by
    ``spec.platforms()`` (or under the ``None`` key when that call returns
    ``None``) and, within that platform, under ``spec.which``.

    :param requirements: Dictionary mapping package names to a list of Spec objects.
    :return: ``{package: {platform: {which: [specs]}}}``; the nested
        ``defaultdict`` structure is passed through ``defaultdict_to_dict``
        before it is returned.
    """
    by_package = {}
    for name, candidates in requirements.items():
        by_platform: dict[Platform | None, dict[CondaPip, list[Spec]]] = defaultdict(
            lambda: defaultdict(list),
        )
        for candidate in candidates:
            targets = candidate.platforms()
            # A spec without platform information is stored under `None`.
            for target in [None] if targets is None else targets:
                by_platform[target][candidate.which].append(candidate)
        by_package[name] = by_platform
    return defaultdict_to_dict(by_package)


def _pop_unused_platforms_and_maybe_expand_none(
    platform_data: dict[Platform | None, dict[CondaPip, list[Spec]]],
    platforms: list[Platform] | None,
) -> None:
    """Expand the ``None`` entry to concrete platforms and drop unwanted ones.

    ``platform_data`` is modified in place.  The allowed platforms are
    ``platforms`` when it is non-empty, otherwise every value of ``Platform``.

    * When ``platform_data`` has a ``None`` key *and* at least one other key,
      the specs stored under ``None`` are appended to every allowed platform
      and the ``None`` key is removed.
    * Afterwards every key that is neither ``None`` nor an allowed platform is
      removed.
    """
    allowed_platforms = platforms if platforms else get_args(Platform)

    has_several_entries = len(platform_data) > 1
    if has_several_entries and None in platform_data:
        generic = platform_data.pop(None)
        for target in allowed_platforms:
            for which, specs in generic.items():
                per_source = platform_data.setdefault(target, {})
                per_source.setdefault(which, []).extend(specs)

    # A lone `None` entry is kept as-is, so it is never scheduled for removal.
    disallowed = set(platform_data) - set(allowed_platforms)
    disallowed.discard(None)
    for target in disallowed:
        del platform_data[target]


def _maybe_new_spec_with_combined_pinnings_and_origins(
    specs: list[Spec],
) -> Spec:
    pinned_specs = [m for m in specs if m.pin is not None]
    combined_origin = tuple(sorted({p for s in specs for p in s.origin}))
    if len(pinned_specs) == 1:
        if len(combined_origin) == 1:
            return pinned_specs[0]
        # If there is only one pinned spec, but the origins are different,
        # we need to create a new spec with the combined origin.
        return pinned_specs[0]._replace(origin=combined_origin)

    if len(pinned_specs) > 1:
        first = pinned_specs[0]
        pins = [m.pin for m in pinned_specs]
        pin = combine_version_pinnings(pins, name=first.name)  # type: ignore[arg-type]
        return Spec(
            name=first.name,
            which=first.which,
            pin=pin,
            identifier=first.identifier,  # should I create a new one?
            origin=combined_origin,
        )

    # Flatten the list
    assert len(pinned_specs) == 0
    if len(combined_origin) > 1:
        # If there are no pinned specs, but the origins are different,
        # we need to create a new spec with the combined origin.
        return specs[0]._replace(origin=combined_origin)
    return specs[0]


def _combine_pinning_within_platform(
    data: dict[Platform | None, dict[CondaPip, list[Spec]]],
) -> dict[Platform | None, dict[CondaPip, Spec]]:
    reduced_data: dict[Platform | None, dict[CondaPip, Spec]] = {}
    for _platform, packages in data.items():
        reduced_data[_platform] = {}
        for which, specs in packages.items():
            spec = _maybe_new_spec_with_combined_pinnings_and_origins(specs)
            reduced_data[_platform][which] = spec
    return reduced_data


def _resolve_conda_pip_conflicts(sources: dict[CondaPip, Spec]) -> dict[CondaPip, Spec]:
    conda_spec = sources.get("conda")
    pip_spec = sources.get("pip")
    if not conda_spec or not pip_spec:  # If either is missing, there is no conflict
        return sources

    # Compare version pins to resolve conflicts
    if conda_spec.pin and not pip_spec.pin:
        return {"conda": conda_spec}  # Prefer conda if it has a pin
    if pip_spec.pin and not conda_spec.pin:
        return {"pip": pip_spec}  # Prefer pip if it has a pin
    if conda_spec.pin == pip_spec.pin:
        return {"conda": conda_spec, "pip": pip_spec}  # Keep both if pins are identical

    # Handle conflict where both conda and pip have different pins
    warn(
        "Version Pinning Conflict:\n"
        f"Different version specifications for Conda ('{conda_spec.pin}') and Pip"
        f" ('{pip_spec.pin}'). Both versions are retained.",
        stacklevel=2,
    )
    return {"conda": conda_spec, "pip": pip_spec}


class VersionConflictError(ValueError):
    """Raised when a version conflict is detected.

    In this module it is raised for version pinnings that cannot be parsed or
    combined (see ``_parse_pinning`` and ``combine_version_pinnings``) and by
    ``resolve_conflicts`` when it is given an invalid platform.
    """


def _add_optional_dependencies(
    requirements: dict[str, list[Spec]],
    optional_dependencies: dict[str, dict[str, list[Spec]]] | None,
) -> None:
    """Add optional dependencies to the requirements dictionary."""
    if optional_dependencies is None:
        return
    for dependencies in optional_dependencies.values():
        for pkg, specs in dependencies.items():
            requirements.setdefault(pkg, []).extend(specs)


def resolve_conflicts(
    requirements: dict[str, list[Spec]],
    platforms: list[Platform] | None = None,
    optional_dependencies: dict[str, dict[str, list[Spec]]] | None = None,
) -> dict[str, dict[Platform | None, dict[CondaPip, Spec]]]:
    """Resolve conflicts in a dictionary of requirements.

    Parameters
    ----------
    requirements
        Dictionary mapping package names to a list of Spec objects.
        Typically ``ParsedRequirements.requirements`` is passed here, which is
        returned by `parse_requirements`.
    platforms
        List of platforms to resolve conflicts for.
        Typically ``ParsedRequirements.platforms`` is passed here, which is
        returned by `parse_requirements`.
    optional_dependencies
        Dictionary mapping package names to a dictionary of optional dependencies.
        Typically ``ParsedRequirements.optional_dependencies`` is passed here, which
        is returned by `parse_requirements`. If passing this argument, all optional
        dependencies will be added to the requirements dictionary. Pass `None` to
        ignore optional dependencies.

    Returns
    -------
    Dictionary mapping package names to a dictionary of resolved metadata.
    The resolved metadata is a dictionary mapping platforms to a dictionary
    mapping sources to a single `Spec` object.

    Raises
    ------
    VersionConflictError
        If ``platforms`` contains a value that is not a valid ``Platform``, or
        if the version pinnings of a package cannot be combined.

    Notes
    -----
    When ``optional_dependencies`` is given, ``requirements`` is modified in
    place: the optional specs are appended to the lists it contains.

    """
    if platforms and not set(platforms).issubset(get_args(Platform)):
        msg = f"Invalid platform: {platforms}, must contain only {get_args(Platform)}"
        raise VersionConflictError(msg)

    _add_optional_dependencies(requirements, optional_dependencies)
    prepared = _prepare_specs_for_conflict_resolution(requirements)
    for data in prepared.values():
        _pop_unused_platforms_and_maybe_expand_none(data, platforms)
    resolved = {
        pkg: _combine_pinning_within_platform(data) for pkg, data in prepared.items()
    }

    # Only existing keys are reassigned, so mutating while iterating is safe.
    for _platforms in resolved.values():
        for _platform, sources in _platforms.items():
            _platforms[_platform] = _resolve_conda_pip_conflicts(sources)
    return resolved
def _parse_pinning(pinning: str) -> tuple[str, version.Version]:
    """Separates the operator and the version number.

    The first entry of ``VALID_OPERATORS`` that ``pinning`` starts with is
    taken as the operator; the remainder is parsed with ``version.parse``.

    Raises ``VersionConflictError`` when no operator matches, when nothing
    follows the operator, or when the remainder is not a valid version.
    """
    pinning = pinning.strip()
    for operator in VALID_OPERATORS:
        if pinning.startswith(operator):
            version_part = pinning[len(operator) :].strip()
            if version_part:
                try:
                    return operator, version.parse(version_part)
                except version.InvalidVersion:
                    break
            else:
                break  # Empty version string

    msg = f"Invalid version pinning: '{pinning}', must start with one of {VALID_OPERATORS}"  # noqa: E501
    raise VersionConflictError(msg)


def _is_redundant(pinning: str, other_pinnings: list[str]) -> bool:
    """Determines if a version pinning is redundant given a list of other pinnings.

    An upper bound is redundant when another upper bound is at least as
    strict; likewise for lower bounds.  Entries of ``other_pinnings`` equal to
    ``pinning`` itself are ignored.  Only ``<``, ``<=``, ``>`` and ``>=`` can
    be redundant.
    """
    op, ver = _parse_pinning(pinning)
    for other in other_pinnings:
        if other == pinning:
            continue
        other_op, other_ver = _parse_pinning(other)
        if op == "<" and (
            (other_op == "<" and ver >= other_ver)
            or (other_op == "<=" and ver > other_ver)
        ):
            return True
        if op == "<=" and other_op in ["<", "<="] and ver >= other_ver:
            return True
        if op == ">" and (
            (other_op == ">" and ver <= other_ver)
            or (other_op == ">=" and ver < other_ver)
        ):
            return True
        if op == ">=" and other_op in [">", ">="] and ver <= other_ver:
            return True
    return False


def _is_valid_pinning(pinning: str) -> bool:
    """Checks if a version pinning string is valid.

    A pinning is valid when it contains one of ``VALID_OPERATORS`` and
    ``_parse_pinning`` accepts it.
    """
    if not any(op in pinning for op in VALID_OPERATORS):
        # Without any recognized operator the pinning cannot be valid
        return False
    try:
        _parse_pinning(pinning)
    except VersionConflictError:
        return False
    return True


def _deduplicate(pinnings: list[str]) -> list[str]:
    """Removes duplicate strings."""
    return list(dict.fromkeys(pinnings))  # preserve order


def _split_pinnings(pinnings: list[str]) -> list[str]:
    """Splits comma-separated pinnings into individual, stripped pinnings."""
    return [_pin.strip() for pin in pinnings for _pin in pin.split(",")]


def _satisfies_bound(
    candidate: version.Version,
    op: str,
    bound: version.Version,
) -> bool:
    """Return whether ``candidate`` fulfils the single constraint ``op bound``."""
    if op == "<":
        return candidate < bound
    if op == "<=":
        return candidate <= bound
    if op == ">":
        return candidate > bound
    if op == ">=":
        return candidate >= bound
    if op == "!=":
        return candidate != bound
    if op == "=":
        return candidate == bound
    return False


def _bounds_contradict(
    op1: str,
    ver1: version.Version,
    op2: str,
    ver2: version.Version,
) -> bool:
    """Return whether a lower and an upper bound exclude every version.

    The check is symmetric in its two constraints.  Pairs that are not one
    lower bound (``>``/``>=``) plus one upper bound (``<``/``<=``) never
    contradict each other here.
    """
    if op1 in ("<", "<=") and op2 in (">", ">="):
        # Normalize so that (op1, ver1) is the lower bound
        op1, ver1, op2, ver2 = op2, ver2, op1, ver1
    if op1 not in (">", ">=") or op2 not in ("<", "<="):
        return False
    if ver1 > ver2:
        return True
    # Equal versions are only satisfiable when both bounds are inclusive
    return ver1 == ver2 and not (op1 == ">=" and op2 == "<=")


def combine_version_pinnings(pinnings: list[str], *, name: str | None = None) -> str:
    """Combines a list of version pinnings into a single string.

    Empty strings are dropped, comma-separated pinnings are split, and
    duplicates are removed.  If exactly one pinning remains it is returned
    verbatim without validation, so that identical complex pinnings (VCS URLs,
    build strings, ...) pass through.  Otherwise every pinning must be valid:

    * with an exact pinning (``=``), that pinning is returned provided it
      satisfies all other pinnings;
    * otherwise redundant bounds are removed and the remaining pinnings are
      joined with commas.

    ``name`` is only used in error messages.

    Raises ``VersionConflictError`` for an invalid pinning, for more than one
    exact pinning, and for pinnings that contradict each other.
    """
    pinnings = [p for p in pinnings if p != ""]
    pinnings = _split_pinnings(pinnings)
    pinnings = _deduplicate(pinnings)
    if len(pinnings) == 1:
        return pinnings[0]

    for pin in pinnings:
        if not _is_valid_pinning(pin):
            ops = ", ".join(VALID_OPERATORS)
            url = f"{_REPO_URL}/blob/main/README.md#supported-version-pinnings"
            msg = (
                f"Invalid version pinning '{pin}' for '{name}'. "
                "UniDep supports only the following operators for combining pinnings: "
                f"{ops}. For complex pinnings (like VCS URLs, local paths, or build"
                " strings), ensure all pinnings are identical. Divergent complex"
                f" pinnings cannot be combined. See {url} for more information."
            )
            raise VersionConflictError(msg)

    # Deduplicate again: "> 1" and ">1" only become equal once spaces are gone
    valid_pinnings = _deduplicate([p.replace(" ", "") for p in pinnings])

    exact_pinnings = [p for p in valid_pinnings if p.startswith("=")]
    if len(exact_pinnings) > 1:
        pinnings_str = ", ".join(exact_pinnings)
        msg = f"Multiple exact version pinnings found: {pinnings_str} for `{name}`"
        raise VersionConflictError(msg)

    err_msg = f"Contradictory version pinnings found for `{name}`"

    if exact_pinnings:
        exact_pin = exact_pinnings[0]
        exact_version = version.parse(exact_pin[1:])
        for other_pin in valid_pinnings:
            if other_pin == exact_pin:
                continue
            op, ver = _parse_pinning(other_pin)
            if not _satisfies_bound(exact_version, op, ver):
                msg = f"{err_msg}: {exact_pin} and {other_pin}"
                raise VersionConflictError(msg)
        return exact_pin

    non_redundant_pinnings = [
        pin for pin in valid_pinnings if not _is_redundant(pin, valid_pinnings)
    ]

    # Parse once, then check every pair for bounds like >2 and <1 or >=2 and <=1
    parsed = [(pin, *_parse_pinning(pin)) for pin in non_redundant_pinnings]
    for i, (pin, op1, ver1) in enumerate(parsed):
        for other_pin, op2, ver2 in parsed[i + 1 :]:
            if _bounds_contradict(op1, ver1, op2, ver2):
                msg = f"{err_msg}: {pin} and {other_pin}"
                raise VersionConflictError(msg)

    return ",".join(non_redundant_pinnings)