Skip to content

Commit

Permalink
feat: Add _operation variable
Browse files Browse the repository at this point in the history
  • Loading branch information
lkubb committed Nov 12, 2024
1 parent 2e7629e commit 2a80d01
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 4 deletions.
56 changes: 52 additions & 4 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import subprocess
import sys
from contextlib import suppress
from contextvars import ContextVar
from dataclasses import asdict, field, replace
from filecmp import dircmp
from functools import cached_property, partial
from functools import cached_property, partial, wraps
from itertools import chain
from pathlib import Path
from shutil import rmtree
Expand Down Expand Up @@ -60,13 +61,38 @@
MISSING,
AnyByStrDict,
JSONSerializable,
Operation,
ParamSpec,
RelativePath,
StrOrPath,
)
from .user_data import DEFAULT_DATA, AnswersMap, Question
from .vcs import get_git

_T = TypeVar("_T")
_P = ParamSpec("_P")

_operation: ContextVar[Operation] = ContextVar("_operation")


def as_operation(value: Operation) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Decorator to set the current operation context, if not defined already.
This value is used to template specific configuration options.
"""

def _decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
@wraps(func)
def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
token = _operation.set(_operation.get(value))
try:
return func(*args, **kwargs)
finally:
_operation.reset(token)

return _wrapper

return _decorator


@dataclass(config=ConfigDict(extra="forbid"))
Expand Down Expand Up @@ -243,7 +269,7 @@ def _cleanup(self) -> None:
for method in self._cleanup_hooks:
method()

def _check_unsafe(self, mode: Literal["copy", "update"]) -> None:
def _check_unsafe(self, mode: Operation) -> None:
"""Check whether a template uses unsafe features."""
if self.unsafe:
return
Expand Down Expand Up @@ -296,8 +322,10 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
Arguments:
tasks: The list of tasks to run.
"""
operation = _operation.get()
for i, task in enumerate(tasks):
extra_context = {f"_{k}": v for k, v in task.extra_vars.items()}
extra_context["_operation"] = operation

if not cast_to_bool(self._render_value(task.condition, extra_context)):
continue
Expand Down Expand Up @@ -327,7 +355,7 @@ def _execute_tasks(self, tasks: Sequence[Task]) -> None:
/ Path(self._render_string(str(task.working_directory), extra_context))
).absolute()

extra_env = {k.upper(): str(v) for k, v in task.extra_vars.items()}
extra_env = {k[1:].upper(): str(v) for k, v in extra_context.items()}
with local.cwd(working_directory), local.env(**extra_env):
subprocess.run(task_cmd, shell=use_shell, check=True, env=local.env)

Expand Down Expand Up @@ -588,7 +616,14 @@ def _pathjoin(
@cached_property
def match_exclude(self) -> Callable[[Path], bool]:
"""Get a callable to match paths against all exclusions."""
return self._path_matcher(self.all_exclusions)
# Include the current operation in the rendering context.
# Note: This method is a cached property, it needs to be regenerated
# when reusing an instance in different contexts.
extra_context = {"_operation": _operation.get()}
return self._path_matcher(
self._render_string(exclusion, extra_context=extra_context)
for exclusion in self.all_exclusions
)

@cached_property
def match_skip(self) -> Callable[[Path], bool]:
Expand Down Expand Up @@ -818,6 +853,7 @@ def template_copy_root(self) -> Path:
return self.template.local_abspath / subdir

# Main operations
@as_operation("copy")
def run_copy(self) -> None:
"""Generate a subproject from zero, ignoring what was in the folder.
Expand All @@ -828,6 +864,11 @@ def run_copy(self) -> None:
See [generating a project][generating-a-project].
"""
with suppress(AttributeError):
# We might have switched operation context, ensure the cached property
# is regenerated to re-render templates.
del self.match_exclude

self._check_unsafe("copy")
self._print_message(self.template.message_before_copy)
self._ask()
Expand All @@ -854,6 +895,7 @@ def run_copy(self) -> None:
# TODO Unify printing tools
print("") # padding space

@as_operation("copy")
def run_recopy(self) -> None:
"""Update a subproject, keeping answers but discarding evolution."""
if self.subproject.template is None:
Expand All @@ -864,6 +906,7 @@ def run_recopy(self) -> None:
with replace(self, src_path=self.subproject.template.url) as new_worker:
new_worker.run_copy()

@as_operation("update")
def run_update(self) -> None:
"""Update a subproject that was already generated.
Expand Down Expand Up @@ -911,6 +954,11 @@ def run_update(self) -> None:
print(
f"Updating to template version {self.template.version}", file=sys.stderr
)
with suppress(AttributeError):
# We might have switched operation context, ensure the cached property
# is regenerated to re-render templates.
del self.match_exclude

self._apply_update()
self._print_message(self.template.message_after_update)

Expand Down
7 changes: 7 additions & 0 deletions copier/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Complex types, annotations, validators."""

import sys
from pathlib import Path
from typing import (
Annotated,
Expand All @@ -16,6 +17,11 @@

from pydantic import AfterValidator

if sys.version_info >= (3, 10):
from typing import ParamSpec as ParamSpec
else:
from typing_extensions import ParamSpec as ParamSpec

Check warning on line 23 in copier/types.py

View check run for this annotation

Codecov / codecov/patch

copier/types.py#L23

Added line #L23 was not covered by tests

# simple types
StrOrPath = Union[str, Path]
AnyByStrDict = Dict[str, Any]
Expand All @@ -35,6 +41,7 @@
Env = Mapping[str, str]
MissingType = NewType("MissingType", object)
MISSING = MissingType(object())
Operation = Literal["copy", "update"]


# Validators
Expand Down
17 changes: 17 additions & 0 deletions docs/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,18 @@ to know available options.

The CLI option can be passed several times to add several patterns.

Each pattern can be templated using Jinja.

!!! example

Templating `exclude` patterns using `_operation` allows to have files
that are rendered once during `copy`, but are never updated:

```yaml
_exclude:
- "{% if _operation == 'update' -%}src/*_example.py{% endif %}"
```

!!! info

When you define this parameter in `copier.yml`, it will **replace** the default
Expand Down Expand Up @@ -1351,6 +1363,8 @@ configuring `secret: true` in the [advanced prompt format][advanced-prompt-forma
exist, but always be present. If they do not exist in a project during an `update`
operation, they will be recreated.

Each pattern can be templated using Jinja.

!!! example

For example, it can be used if your project generates a password the 1st time and
Expand Down Expand Up @@ -1501,6 +1515,9 @@ other items not present.
- [invoke, end-process, "--full-conf={{ _copier_conf|to_json }}"]
# Your script can be run by the same Python environment used to run Copier
- ["{{ _copier_python }}", task.py]
# Run a command during the initial copy operation only, excluding updates
- command: ["{{ _copier_python }}", task.py]
when: "{{ _operation == 'copy' }}"
# OS-specific task (supported values are "linux", "macos", "windows" and `None`)
- command: rm {{ name_of_the_project }}/README.md
when: "{{ _copier_conf.os in ['linux', 'macos'] }}"
Expand Down
10 changes: 10 additions & 0 deletions docs/creating.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ The absolute path of the Python interpreter running Copier.

The name of the project root directory.

## Variables (context-dependent)

Some variables are only available in select contexts:

### `_operation`

The current operation, either `"copy"` or `"update"`.

Availability: [`exclude`](configuring.md#exclude), [`tasks`](configuring.md#tasks)

## Variables (context-specific)

Some rendering contexts provide variables unique to them:
Expand Down
90 changes: 90 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import json
from pathlib import Path

import pytest
from plumbum import local

import copier

from .helpers import build_file_tree, git_save


def test_exclude_templating_with_operation(
tmp_path_factory: pytest.TempPathFactory,
) -> None:
"""
Ensure it's possible to create one-off boilerplate files that are not
managed during updates via `_exclude` using the `_operation` context variable.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))

template = "{% if _operation == 'update' %}copy-only{% endif %}"
with local.cwd(src):
build_file_tree(
{
"copier.yml": f'_exclude:\n - "{template}"',
"{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}",
"copy-only": "foo",
"copy-and-update": "foo",
}
)
git_save(tag="1.0.0")
build_file_tree(
{
"copy-only": "bar",
"copy-and-update": "bar",
}
)
git_save(tag="2.0.0")
copy_only = dst / "copy-only"
copy_and_update = dst / "copy-and-update"

copier.run_copy(str(src), dst, defaults=True, overwrite=True, vcs_ref="1.0.0")
for file in (copy_only, copy_and_update):
assert file.exists()
assert file.read_text() == "foo"

with local.cwd(dst):
git_save()

copier.run_update(str(dst), overwrite=True)
assert copy_only.read_text() == "foo"
assert copy_and_update.read_text() == "bar"


def test_task_templating_with_operation(
tmp_path_factory: pytest.TempPathFactory, tmp_path: Path
) -> None:
"""
Ensure that it is possible to define tasks that are only executed when copying.
"""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
# Use a file outside the Copier working directories to ensure accurate tracking
task_counter = tmp_path / "task_calls.txt"
with local.cwd(src):
build_file_tree(
{
"copier.yml": (
f"""\
_tasks:
- command: echo {{{{ _operation }}}} >> {json.dumps(str(task_counter))}
when: "{{{{ _operation == 'copy' }}}}"
"""
),
"{{ _copier_conf.answers_file }}.jinja": "{{ _copier_answers|to_yaml }}",
}
)
git_save(tag="1.0.0")

copier.run_copy(str(src), dst, defaults=True, overwrite=True, unsafe=True)
assert task_counter.exists()
assert len(task_counter.read_text().splitlines()) == 1

with local.cwd(dst):
git_save()

copier.run_recopy(dst, defaults=True, overwrite=True, unsafe=True)
assert len(task_counter.read_text().splitlines()) == 2

copier.run_update(dst, defaults=True, overwrite=True, unsafe=True)
assert len(task_counter.read_text().splitlines()) == 2

0 comments on commit 2a80d01

Please sign in to comment.