"""unidep - Unified Conda and Pip requirements management.
Version conflict detection and resolution.
"""
from __future__ import annotations
import sys
from collections import defaultdict
from typing import TYPE_CHECKING
from packaging import version
from unidep.platform_definitions import Platform, Spec
from unidep.utils import defaultdict_to_dict, warn
if sys.version_info >= (3, 8):
from typing import get_args
else: # pragma: no cover
from typing_extensions import get_args
if TYPE_CHECKING:
from unidep.platform_definitions import CondaPip
# Comparison operators accepted in a single version pinning. The order matters:
# `_parse_pinning` uses the first operator the pinning starts with, so the
# two-character operators "<=" and ">=" must come before "<" and ">".
VALID_OPERATORS = ["<=", ">=", "<", ">", "=", "!="]
# Project homepage; used to build the README link in pinning error messages.
_REPO_URL = "https://github.com/basnijholt/unidep"
def _prepare_specs_for_conflict_resolution(
    requirements: dict[str, list[Spec]],
) -> dict[str, dict[Platform | None, dict[CondaPip, list[Spec]]]]:
    """Group each package's specs first by platform, then by source.

    :param requirements: Mapping of package name to its list of ``Spec`` objects.
    :return: Mapping ``package -> platform -> source -> list[Spec]`` converted to
        plain dicts. A spec whose ``platforms()`` is ``None`` is filed under the
        ``None`` platform key; otherwise it is appended once for every platform
        it lists.
    """
    prepared_data = {}
    for package in requirements:
        by_platform: dict[Platform | None, dict[CondaPip, list[Spec]]] = defaultdict(
            lambda: defaultdict(list),
        )
        for spec in requirements[package]:
            spec_platforms = spec.platforms()
            # No platform restriction -> the platform-agnostic ``None`` bucket.
            targets = [None] if spec_platforms is None else spec_platforms
            for target in targets:
                by_platform[target][spec.which].append(spec)
        prepared_data[package] = by_platform
    return defaultdict_to_dict(prepared_data)
def _pop_unused_platforms_and_maybe_expand_none(
    platform_data: dict[Platform | None, dict[CondaPip, list[Spec]]],
    platforms: list[Platform] | None,
) -> None:
    """Restrict ``platform_data`` (in place) to the allowed platforms.

    The allowed platforms are ``platforms`` when it is non-empty, otherwise all
    values of the ``Platform`` literal. When a ``None`` (platform-agnostic) entry
    coexists with at least one other key, it is removed and its specs are
    appended to every allowed platform. Afterwards, every non-``None`` key that
    is not an allowed platform is dropped; a ``None`` key is always kept.
    """
    allowed_platforms = platforms if platforms else get_args(Platform)

    if None in platform_data and len(platform_data) > 1:
        platform_agnostic = platform_data.pop(None)
        for target in allowed_platforms:
            for which, specs in platform_agnostic.items():
                per_source = platform_data.setdefault(target, {})
                per_source.setdefault(which, []).extend(specs)

    disallowed = [
        key
        for key in platform_data
        if key is not None and key not in allowed_platforms
    ]
    for key in disallowed:
        del platform_data[key]
def _maybe_new_spec_with_combined_pinnings_and_origins(
specs: list[Spec],
) -> Spec:
pinned_specs = [m for m in specs if m.pin is not None]
combined_origin = tuple(sorted({p for s in specs for p in s.origin}))
if len(pinned_specs) == 1:
if len(combined_origin) == 1:
return pinned_specs[0]
# If there is only one pinned spec, but the origins are different,
# we need to create a new spec with the combined origin.
return pinned_specs[0]._replace(origin=combined_origin)
if len(pinned_specs) > 1:
first = pinned_specs[0]
pins = [m.pin for m in pinned_specs]
pin = combine_version_pinnings(pins, name=first.name) # type: ignore[arg-type]
return Spec(
name=first.name,
which=first.which,
pin=pin,
identifier=first.identifier, # should I create a new one?
origin=combined_origin,
)
# Flatten the list
assert len(pinned_specs) == 0
if len(combined_origin) > 1:
# If there are no pinned specs, but the origins are different,
# we need to create a new spec with the combined origin.
return specs[0]._replace(origin=combined_origin)
return specs[0]
def _combine_pinning_within_platform(
data: dict[Platform | None, dict[CondaPip, list[Spec]]],
) -> dict[Platform | None, dict[CondaPip, Spec]]:
reduced_data: dict[Platform | None, dict[CondaPip, Spec]] = {}
for _platform, packages in data.items():
reduced_data[_platform] = {}
for which, specs in packages.items():
spec = _maybe_new_spec_with_combined_pinnings_and_origins(specs)
reduced_data[_platform][which] = spec
return reduced_data
def _resolve_conda_pip_conflicts(sources: dict[CondaPip, Spec]) -> dict[CondaPip, Spec]:
conda_spec = sources.get("conda")
pip_spec = sources.get("pip")
if not conda_spec or not pip_spec: # If either is missing, there is no conflict
return sources
# Compare version pins to resolve conflicts
if conda_spec.pin and not pip_spec.pin:
return {"conda": conda_spec} # Prefer conda if it has a pin
if pip_spec.pin and not conda_spec.pin:
return {"pip": pip_spec} # Prefer pip if it has a pin
if conda_spec.pin == pip_spec.pin:
return {"conda": conda_spec, "pip": pip_spec} # Keep both if pins are identical
# Handle conflict where both conda and pip have different pins
warn(
"Version Pinning Conflict:\n"
f"Different version specifications for Conda ('{conda_spec.pin}') and Pip"
f" ('{pip_spec.pin}'). Both versions are retained.",
stacklevel=2,
)
return {"conda": conda_spec, "pip": pip_spec}
class VersionConflictError(ValueError):
    """Error for version pinnings (or requested platforms) that cannot be resolved.

    Subclasses ``ValueError``, so it can also be caught as a ``ValueError``.
    """
def _add_optional_dependencies(
requirements: dict[str, list[Spec]],
optional_dependencies: dict[str, dict[str, list[Spec]]] | None,
) -> None:
"""Add optional dependencies to the requirements dictionary."""
if optional_dependencies is None:
return
for dependencies in optional_dependencies.values():
for pkg, specs in dependencies.items():
requirements.setdefault(pkg, []).extend(specs)
def resolve_conflicts(
    requirements: dict[str, list[Spec]],
    platforms: list[Platform] | None = None,
    optional_dependencies: dict[str, dict[str, list[Spec]]] | None = None,
) -> dict[str, dict[Platform | None, dict[CondaPip, Spec]]]:
    """Resolve conflicts in a dictionary of requirements.

    Parameters
    ----------
    requirements
        Dictionary mapping package names to a list of Spec objects.
        Typically ``ParsedRequirements.requirements`` is passed here, which is
        returned by `parse_requirements`. This dictionary (and the lists in it)
        is not modified.
    platforms
        List of platforms to resolve conflicts for.
        Typically ``ParsedRequirements.platforms`` is passed here, which is
        returned by `parse_requirements`. Every entry must be one of the
        ``Platform`` literal values.
    optional_dependencies
        Dictionary whose values map package names to lists of Spec objects
        (the outer keys are not used here).
        Typically ``ParsedRequirements.optional_dependencies`` is passed here, which is
        returned by `parse_requirements`. If passing this argument, all optional
        dependencies are merged into a copy of ``requirements`` before resolving.
        Pass `None` to ignore optional dependencies.

    Returns
    -------
    Dictionary mapping package names to a dictionary of resolved metadata.
    The resolved metadata is a dictionary mapping platforms to a dictionary
    mapping sources to a single `Spec` object.

    Raises
    ------
    VersionConflictError
        If ``platforms`` contains an unknown platform, or if the pinnings of a
        package cannot be combined.
    """
    if platforms and not set(platforms).issubset(get_args(Platform)):
        msg = f"Invalid platform: {platforms}, must contain only {get_args(Platform)}"
        raise VersionConflictError(msg)

    if optional_dependencies is not None:
        # Merge into a copy (new dict *and* new lists) so the caller's
        # ``requirements`` is not permanently extended. Otherwise every
        # repeated call would append the optional specs again.
        requirements = {pkg: list(specs) for pkg, specs in requirements.items()}
        _add_optional_dependencies(requirements, optional_dependencies)

    prepared = _prepare_specs_for_conflict_resolution(requirements)
    for data in prepared.values():
        _pop_unused_platforms_and_maybe_expand_none(data, platforms)
    resolved = {
        pkg: _combine_pinning_within_platform(data) for pkg, data in prepared.items()
    }
    for per_platform in resolved.values():
        for _platform, sources in per_platform.items():
            # Only values are reassigned, so iterating items() here is safe.
            per_platform[_platform] = _resolve_conda_pip_conflicts(sources)
    return resolved
def _parse_pinning(pinning: str) -> tuple[str, version.Version]:
    """Split a single pinning such as ``">=1.2"`` into operator and version.

    Whitespace around the pinning, and between operator and version, is
    ignored. Only the first entry of ``VALID_OPERATORS`` that the pinning
    starts with is tried.

    :raises VersionConflictError: if no operator matches, the version part is
        empty, or ``packaging`` cannot parse the version part.
    """
    pinning = pinning.strip()
    matched = next((op for op in VALID_OPERATORS if pinning.startswith(op)), None)
    if matched is not None:
        remainder = pinning[len(matched) :].strip()
        if remainder:
            try:
                return matched, version.parse(remainder)
            except version.InvalidVersion:
                pass  # fall through to the shared error below
    msg = f"Invalid version pinning: '{pinning}', must start with one of {VALID_OPERATORS}"  # noqa: E501
    raise VersionConflictError(msg)
def _is_redundant(pinning: str, other_pinnings: list[str]) -> bool:
    """Return whether ``pinning`` is implied by a tighter bound in ``other_pinnings``.

    Only bounds of the same direction are compared. An upper bound (``<``/``<=``)
    is redundant if another upper bound is at least as tight; likewise for lower
    bounds (``>``/``>=``). ``=`` and ``!=`` pinnings are never reported as
    redundant.

    Pinnings equivalent to ``pinning`` (same operator and an equal parsed
    version, e.g. ``"<1"`` and ``"<1.0"``, or the same string) are skipped.
    Otherwise each of two equivalent pins would be redundant given the other,
    and a caller filtering with this function would drop both and lose the
    constraint entirely.

    :raises VersionConflictError: if any pinning cannot be parsed.
    """
    op, ver = _parse_pinning(pinning)
    for other in other_pinnings:
        other_op, other_ver = _parse_pinning(other)
        if other_op == op and other_ver == ver:
            continue
        if op == "<" and (
            (other_op == "<" and ver >= other_ver)
            or (other_op == "<=" and ver > other_ver)
        ):
            return True
        if op == "<=" and other_op in ["<", "<="] and ver >= other_ver:
            return True
        if op == ">" and (
            (other_op == ">" and ver <= other_ver)
            or (other_op == ">=" and ver < other_ver)
        ):
            return True
        if op == ">=" and other_op in [">", ">="] and ver <= other_ver:
            return True
    return False
def _is_valid_pinning(pinning: str) -> bool:
    """Report whether ``pinning`` is one supported operator plus a parseable version.

    Never raises ``VersionConflictError``; parse failures yield ``False``.
    """
    if not any(op in pinning for op in VALID_OPERATORS):
        return False  # no recognized operator anywhere in the string
    try:
        _parse_pinning(pinning)
    except VersionConflictError:
        return False
    return True
def _deduplicate(pinnings: list[str]) -> list[str]:
"""Removes duplicate strings."""
return list(dict.fromkeys(pinnings)) # preserve order
def _split_pinnings(pinnings: list[str]) -> list[str]:
"""Extracts version pinnings from a list of Spec objects."""
return [_pin.lstrip().rstrip() for pin in pinnings for _pin in pin.split(",")]
def combine_version_pinnings(pinnings: list[str], *, name: str | None = None) -> str:
    """Combine a list of version pinnings into a single comma-separated string.

    Each entry may itself hold several comma-separated pinnings. Empty entries
    are ignored and duplicates removed. If exactly one distinct pinning remains,
    it is returned unchanged. This is how complex pinnings (VCS URLs, build
    strings) are accepted, as long as they are identical.

    Otherwise every pinning must be ``<op><version>`` with an operator from
    ``VALID_OPERATORS``, and spaces are removed. If there is an exact ``=``
    pinning, it is returned when every other pinning admits that version.
    Otherwise the bounds not implied by a tighter bound are joined with commas.

    :param pinnings: Pinning strings to combine.
    :param name: Package name, used only in error messages.
    :return: The combined pinning (``""`` when there is nothing to combine).
    :raises VersionConflictError: for an unsupported pinning, more than one
        exact pinning, or pinnings that no single version can satisfy.
    """
    pinnings = [p for p in pinnings if p != ""]
    pinnings = _split_pinnings(pinnings)
    pinnings = _deduplicate(pinnings)
    if len(pinnings) == 1:
        return pinnings[0]

    for pin in pinnings:
        if not _is_valid_pinning(pin):
            ops = ", ".join(VALID_OPERATORS)
            url = f"{_REPO_URL}/blob/main/README.md#supported-version-pinnings"
            msg = (
                f"Invalid version pinning '{pin}' for '{name}'. "
                "UniDep supports only the following operators for combining pinnings: "
                f"{ops}. For complex pinnings (like VCS URLs, local paths, or build"
                " strings), ensure all pinnings are identical. Divergent complex"
                f" pinnings cannot be combined. See {url} for more information."
            )
            raise VersionConflictError(msg)

    # Deduplicate again after removing spaces: "= 1" and "=1" only become
    # identical here, and keeping both would be misreported as multiple exact
    # pinnings (or emitted twice in the result).
    valid_pinnings = _deduplicate([p.replace(" ", "") for p in pinnings])
    exact_pinnings = [p for p in valid_pinnings if p.startswith("=")]
    if len(exact_pinnings) > 1:
        pinnings_str = ", ".join(exact_pinnings)
        msg = f"Multiple exact version pinnings found: {pinnings_str} for `{name}`"
        raise VersionConflictError(msg)

    err_msg = f"Contradictory version pinnings found for `{name}`"
    if exact_pinnings:
        exact_pin = exact_pinnings[0]
        _, exact_version = _parse_pinning(exact_pin)
        for other_pin in valid_pinnings:
            if other_pin == exact_pin:
                continue
            op, ver = _parse_pinning(other_pin)
            satisfied = (
                (op == "<" and exact_version < ver)
                or (op == "<=" and exact_version <= ver)
                or (op == ">" and exact_version > ver)
                or (op == ">=" and exact_version >= ver)
                # "!=" only conflicts when it excludes the exact version itself.
                or (op == "!=" and exact_version != ver)
            )
            if not satisfied:
                msg = f"{err_msg}: {exact_pin} and {other_pin}"
                raise VersionConflictError(msg)
        return exact_pin

    non_redundant_pinnings = [
        pin for pin in valid_pinnings if not _is_redundant(pin, valid_pinnings)
    ]
    for i, pin in enumerate(non_redundant_pinnings):
        op1, ver1 = _parse_pinning(pin)
        for other_pin in non_redundant_pinnings[i + 1 :]:
            op2, ver2 = _parse_pinning(other_pin)
            # Orient the pair as (lower bound, upper bound) so the check does
            # not depend on list order. Two bounds in the same direction, or any
            # "!=" pin, cannot contradict each other here.
            if op1 in (">", ">=") and op2 in ("<", "<="):
                lo_op, lo_ver, hi_op, hi_ver = op1, ver1, op2, ver2
            elif op1 in ("<", "<=") and op2 in (">", ">="):
                lo_op, lo_ver, hi_op, hi_ver = op2, ver2, op1, ver1
            else:
                continue
            # ">=a,<=b" leaves no version only when a > b. If either bound is
            # strict, a == b already leaves no version.
            both_inclusive = lo_op == ">=" and hi_op == "<="
            if lo_ver > hi_ver or (lo_ver == hi_ver and not both_inclusive):
                msg = f"{err_msg}: {pin} and {other_pin}"
                raise VersionConflictError(msg)
    return ",".join(non_redundant_pinnings)