"""unidep - Unified Conda and Pip requirements management.
Conda environment file generation functions.
"""
from __future__ import annotations
import sys
from collections import defaultdict
from copy import deepcopy
from typing import TYPE_CHECKING, NamedTuple, cast
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap, CommentedSeq
from unidep._conflicts import (
VersionConflictError,
_maybe_new_spec_with_combined_pinnings_and_origins,
)
from unidep.platform_definitions import (
PLATFORM_SELECTOR_MAP,
CondaPip,
CondaPlatform,
Platform,
Spec,
)
from unidep.utils import (
add_comment_to_file,
build_pep508_environment_marker,
warn,
)
if TYPE_CHECKING:
from pathlib import Path
if sys.version_info >= (3, 8):
from typing import Literal, get_args
else: # pragma: no cover
from typing_extensions import Literal, get_args
class CondaEnvironmentSpec(NamedTuple):
    """A conda environment.

    Tuple-based record holding the pieces that `write_conda_environment_file`
    serializes: channels, target platforms, conda dependencies and pip
    dependencies.
    """

    # Written under the ``channels`` key of the environment file when non-empty.
    channels: list[str]
    # Written under the ``platforms`` key of the environment file when non-empty.
    platforms: list[Platform]
    # Conda dependency entries: either a plain string or a single-key mapping
    # such as ``{"sel(<conda platform>)": "<dependency>"}``.
    conda: list[str | dict[str, str]]  # actually a CommentedSeq[str | dict[str, str]]
    # Pip requirement strings; nested as ``{"pip": [...]}`` inside the
    # dependencies list when the file is written.
    pip: list[str]
def _conda_sel(sel: str) -> CondaPlatform:
    """Return the Conda selector name for a platform identifier.

    The selector name is everything in `sel` before the first ``-``
    (the whole string when there is no ``-``). It must be one of the
    values allowed by `CondaPlatform`.
    """
    prefix, _sep, _rest = sel.partition("-")
    assert prefix in get_args(CondaPlatform), f"Invalid platform: {prefix}"
    return cast(CondaPlatform, prefix)
def _extract_conda_pip_dependencies(
resolved: dict[str, dict[Platform | None, dict[CondaPip, Spec]]],
) -> tuple[
dict[str, dict[Platform | None, Spec]],
dict[str, dict[Platform | None, Spec]],
]:
"""Extract and separate conda and pip dependencies."""
conda: dict[str, dict[Platform | None, Spec]] = {}
pip: dict[str, dict[Platform | None, Spec]] = {}
for pkg, platform_data in resolved.items():
for _platform, sources in platform_data.items():
if "conda" in sources:
conda.setdefault(pkg, {})[_platform] = sources["conda"]
else:
pip.setdefault(pkg, {})[_platform] = sources["pip"]
return conda, pip
def _resolve_multiple_platform_conflicts(
    platform_to_spec: dict[Platform | None, Spec],
) -> None:
    """Fix conflicts for deps with platforms that map to a single Conda platform.

    In a Conda environment with dependencies across various platforms (like
    'linux-aarch64', 'linux-64'), this function ensures consistency in metadata
    for each Conda platform (e.g., 'sel(linux): ...'). It maps each platform to
    a Conda platform and resolves conflicts by retaining the first `Spec` object
    per Conda platform, discarding others. This approach guarantees uniform
    metadata across different but equivalent platforms.

    `platform_to_spec` is modified in place (entries are popped) and nothing
    is returned. Every key must be a concrete platform, i.e. not ``None``.
    """
    # Group the platforms first by Conda platform and then by `Spec`:
    # {conda_platform: {spec: [platform, ...]}}.
    valid: dict[
        CondaPlatform,
        dict[Spec, list[Platform | None]],
    ] = defaultdict(lambda: defaultdict(list))
    for _platform, spec in platform_to_spec.items():
        assert _platform is not None
        conda_platform = _conda_sel(_platform)
        valid[conda_platform][spec].append(_platform)

    for conda_platform, spec_to_platforms in valid.items():
        # We cannot distinguish between e.g., linux-64 and linux-aarch64
        # (which becomes linux). So of the list[Platform] we only need to keep
        # one Platform. We can pop the rest from `platform_to_spec`. This is
        # not a problem because they share the same `Spec` object.
        for platforms in spec_to_platforms.values():
            for j, _platform in enumerate(platforms):
                if j >= 1:
                    platform_to_spec.pop(_platform)

        # Now make sure that valid[conda_platform] has only one key.
        # That means that all `Spec`s for the different Platforms that map to a
        # CondaPlatform are identical. If len > 1, we have a conflict.
        if len(spec_to_platforms) > 1:
            # `zip(*items)` yields the tuple of specs and the tuple of their
            # platform lists, both in insertion order.
            specs, (first_platform, *_) = zip(*spec_to_platforms.items())
            first, *others = specs
            try:
                spec = _maybe_new_spec_with_combined_pinnings_and_origins(specs)  # type: ignore[arg-type]
            except VersionConflictError:
                # We have a conflict, select the first one.
                msg = (
                    f"Dependency Conflict on '{conda_platform}':\n"
                    f"Multiple versions detected. Retaining '{first.pprint()}' and"
                    f" discarding conflicts: {', '.join(o.pprint() for o in others)}."
                )
                warn(msg, stacklevel=2)
            else:
                # Means that we could combine the pinnings
                # NOTE(review): `first_platform` is the *list* of platforms of
                # the first spec, so a nested list is stored here. The combined
                # `spec` is only recorded in the local `spec_to_platforms` and
                # is never written into `platform_to_spec`, which keeps mapping
                # the surviving platform to `first` - confirm this is intended.
                spec_to_platforms.pop(first)
                spec_to_platforms[spec] = [first_platform]

            # Drop every platform that belongs to a non-first spec, so only the
            # platform of the first spec remains for this Conda platform.
            for other in others:
                platforms = spec_to_platforms[other]
                for _platform in platforms:
                    if _platform in platform_to_spec:  # might have been popped already
                        platform_to_spec.pop(_platform)
        # Now we have only one `Spec` left, so we can select it.
def _add_comment(commment_seq: CommentedSeq, platform: Platform) -> None:
    """Tag the last item of `commment_seq` with a ``# [selector]`` comment.

    The selector is the first entry of ``PLATFORM_SELECTOR_MAP[platform]``.
    """
    first_selector = PLATFORM_SELECTOR_MAP[platform][0]
    last_index = len(commment_seq) - 1
    commment_seq.yaml_add_eol_comment(f"# [{first_selector}]", last_index)
[docs]
def create_conda_env_specification(  # noqa: PLR0912
    resolved: dict[str, dict[Platform | None, dict[CondaPip, Spec]]],
    channels: list[str],
    platforms: list[Platform],
    selector: Literal["sel", "comment"] = "sel",
) -> CondaEnvironmentSpec:
    """Build a `CondaEnvironmentSpec` from resolved requirements.

    Conda dependencies take precedence: a pip dependency is skipped when its
    identifier was already emitted as a conda dependency. Platform-specific
    entries are marked either with ``sel(...)`` keys / PEP 508 markers
    (``selector="sel"``) or with end-of-line comments (``selector="comment"``).
    No platform marking is added when exactly one platform is requested.

    Raises:
        ValueError: If `selector` is neither ``"sel"`` nor ``"comment"``.
    """
    if selector not in ("sel", "comment"):  # pragma: no cover
        msg = f"Invalid selector: {selector}, must be one of ['sel', 'comment']"
        raise ValueError(msg)

    # Split in conda and pip dependencies and prefer conda over pip
    conda, pip = _extract_conda_pip_dependencies(resolved)

    conda_deps: list[str | dict[str, str]] = CommentedSeq()
    pip_deps: list[str] = CommentedSeq()
    seen_identifiers: set[str] = set()

    for per_platform in conda.values():
        if selector == "sel" and len(per_platform) > 1:
            # None has been expanded already if len>1
            _resolve_multiple_platform_conflicts(per_platform)
        for plat, spec in sorted(per_platform.items()):
            entry: str | dict[str, str] = spec.name_with_pin()
            platform_specific = len(platforms) != 1 and plat is not None
            if platform_specific and selector == "sel":
                entry = {f"sel({_conda_sel(plat)})": entry}  # type: ignore[dict-item]
            conda_deps.append(entry)
            if platform_specific and selector == "comment":
                # The comment is attached to the entry that was just appended.
                _add_comment(conda_deps, plat)
            assert isinstance(spec.identifier, str)
            seen_identifiers.add(spec.identifier)

    for per_platform in pip.values():
        # Group the platforms that share one and the same spec.
        grouped: dict[Spec, list[Platform | None]] = {}
        for plat, spec in per_platform.items():
            grouped.setdefault(spec, []).append(plat)

        for spec, spec_platforms in grouped.items():
            if spec.identifier in seen_identifiers:
                continue
            requirement = spec.name_with_pin(is_pip=True)
            if spec_platforms == [None] or len(platforms) == 1:
                pip_deps.append(requirement)
            elif selector == "sel":
                marker = build_pep508_environment_marker(spec_platforms)  # type: ignore[arg-type]
                pip_deps.append(f"{requirement}; {marker}")
            else:
                assert selector == "comment"
                # We can only add comments with a single platform because
                # `conda-lock` doesn't implement logic, e.g., [linux or win]
                # should be spread into two lines, one with [linux] and the
                # other with [win].
                for plat in spec_platforms:
                    pip_deps.append(requirement)
                    _add_comment(pip_deps, cast(Platform, plat))

    return CondaEnvironmentSpec(
        channels=channels,
        platforms=platforms,
        conda=conda_deps,
        pip=pip_deps,
    )
[docs]
def write_conda_environment_file(
    env_spec: CondaEnvironmentSpec,
    output_file: str | Path | None = "environment.yaml",
    name: str = "myenv",
    *,
    verbose: bool = False,
) -> None:
    """Write `env_spec` as a conda environment YAML document.

    The document goes to `output_file` when it is truthy, otherwise to
    stdout. Empty sections (channels, dependencies, platforms) are left out.
    With `verbose`, progress messages are printed around the file write.
    """
    # Work on a copy so that adding the pip section does not touch `env_spec`.
    dependencies = deepcopy(env_spec.conda)
    if env_spec.pip:
        dependencies.append({"pip": env_spec.pip})  # type: ignore[arg-type, dict-item]

    env_data = CommentedMap({"name": name})
    sections = (
        ("channels", env_spec.channels),
        ("dependencies", dependencies),
        ("platforms", env_spec.platforms),
    )
    for key, value in sections:
        if value:
            env_data[key] = value

    emitter = YAML(typ="rt")
    emitter.default_flow_style = False
    emitter.width = 4096
    emitter.indent(mapping=2, sequence=2, offset=2)

    if not output_file:
        emitter.dump(env_data, sys.stdout)
        return

    if verbose:
        print(f"📝 Generating environment file at `{output_file}`")
    with open(output_file, "w") as f:  # noqa: PTH123
        emitter.dump(env_data, f)
    if verbose:
        print("📝 Environment file generated successfully.")
    add_comment_to_file(output_file)