Skip to content

Commit

Permalink
Merge pull request #278 from python/feature/entry-points-by-group-and…
Browse files Browse the repository at this point in the history
…-name

Feature/entry points by group and name
  • Loading branch information
jaraco authored Feb 24, 2021
2 parents 61a265c + bdce7ef commit 73cf0a9
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 30 deletions.
36 changes: 36 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,39 @@
v3.6.0
======

* #284: Introduces new ``EntryPoints`` object, a tuple of
``EntryPoint`` objects but with convenience properties for
selecting and inspecting the results:

- ``.select()`` accepts ``group`` or ``name`` keyword
parameters and returns a new ``EntryPoints`` tuple
with only those that match the selection.
- ``.groups`` property presents all of the group names.
- ``.names`` property presents the names of the entry points.
- Item access (e.g. ``eps[name]``) retrieves a single
entry point by name.

``entry_points`` now accepts "selection parameters",
same as ``EntryPoint.select()``.

``entry_points()`` now provides a future-compatible
``SelectableGroups`` object that supplies the above interface
but remains a dict for compatibility.

In the future, ``entry_points()`` will return an
``EntryPoints`` object, but provide for backward
compatibility with a deprecated ``__getitem__``
accessor by group and a ``get()`` method.

If passing selection parameters to ``entry_points``, the
future behavior is invoked and an ``EntryPoints`` is the
result.

Construction of entry points using
``dict([EntryPoint, ...])`` is now deprecated and raises
an appropriate DeprecationWarning and will be removed in
a future version.

v3.5.0
======

Expand Down
12 changes: 7 additions & 5 deletions docs/using.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,20 @@ This package provides the following functionality via its public API.
Entry points
------------

The ``entry_points()`` function returns a dictionary of all entry points,
keyed by group. Entry points are represented by ``EntryPoint`` instances;
The ``entry_points()`` function returns a collection of entry points.
Entry points are represented by ``EntryPoint`` instances;
each ``EntryPoint`` has a ``.name``, ``.group``, and ``.value`` attributes and
a ``.load()`` method to resolve the value. There are also ``.module``,
``.attr``, and ``.extras`` attributes for getting the components of the
``.value`` attribute::

>>> eps = entry_points()
>>> list(eps)
>>> sorted(eps.groups)
['console_scripts', 'distutils.commands', 'distutils.setup_keywords', 'egg_info.writers', 'setuptools.installation']
>>> scripts = eps['console_scripts']
>>> wheel = [ep for ep in scripts if ep.name == 'wheel'][0]
>>> scripts = eps.select(group='console_scripts')
>>> 'wheel' in scripts.names
True
>>> wheel = scripts['wheel']
>>> wheel
EntryPoint(name='wheel', value='wheel.cli:main', group='console_scripts')
>>> wheel.module
Expand Down
151 changes: 136 additions & 15 deletions importlib_metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import sys
import zipp
import email
import inspect
import pathlib
import operator
import warnings
import functools
import itertools
import posixpath
Expand Down Expand Up @@ -130,22 +132,19 @@ def _from_text(cls, text):
config.read_string(text)
return cls._from_config(config)

@classmethod
def _from_text_for(cls, text, dist):
return (ep._for(dist) for ep in cls._from_text(text))

def _for(self, dist):
self.dist = dist
return self

def __iter__(self):
"""
Supply iter so one may construct dicts of EntryPoints by name.
>>> eps = [EntryPoint('a', 'b', 'c'), EntryPoint('d', 'e', 'f')]
>>> dict(eps)['a']
EntryPoint(name='a', value='b', group='c')
"""
msg = (
"Construction of dict of EntryPoints is deprecated in "
"favor of EntryPoints."
)
warnings.warn(msg, DeprecationWarning)
return iter((self.name, self))

def __reduce__(self):
Expand All @@ -154,6 +153,118 @@ def __reduce__(self):
(self.name, self.value, self.group),
)

def matches(self, **params):
attrs = (getattr(self, param) for param in params)
return all(map(operator.eq, params.values(), attrs))


class EntryPoints(tuple):
"""
An immutable collection of selectable EntryPoint objects.
"""

__slots__ = ()

def __getitem__(self, name): # -> EntryPoint:
try:
return next(iter(self.select(name=name)))
except StopIteration:
raise KeyError(name)

def select(self, **params):
return EntryPoints(ep for ep in self if ep.matches(**params))

@property
def names(self):
return set(ep.name for ep in self)

@property
def groups(self):
"""
For coverage while SelectableGroups is present.
>>> EntryPoints().groups
set()
"""
return set(ep.group for ep in self)

@classmethod
def _from_text_for(cls, text, dist):
return cls(ep._for(dist) for ep in EntryPoint._from_text(text))


class SelectableGroups(dict):
"""
A backward- and forward-compatible result from
entry_points that fully implements the dict interface.
"""

@classmethod
def load(cls, eps):
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return cls((group, EntryPoints(eps)) for group, eps in grouped)

@property
def _all(self):
return EntryPoints(itertools.chain.from_iterable(self.values()))

@property
def groups(self):
return self._all.groups

@property
def names(self):
"""
for coverage:
>>> SelectableGroups().names
set()
"""
return self._all.names

def select(self, **params):
if not params:
return self
return self._all.select(**params)


class LegacyGroupedEntryPoints(EntryPoints): # pragma: nocover
"""
Compatibility wrapper around EntryPoints to provide
much of the 'dict' interface previously returned by
entry_points.
"""

def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']:
"""
When accessed by name that matches a group, return the group.
"""
group = self.select(group=name)
if group:
msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select."
warnings.warn(msg, DeprecationWarning, stacklevel=2)
return group

return super().__getitem__(name)

def get(self, group, default=None):
"""
For backward compatibility, supply .get.
"""
is_flake8 = any('flake8' in str(frame) for frame in inspect.stack())
msg = "GroupedEntryPoints.get is deprecated. Use select."
is_flake8 or warnings.warn(msg, DeprecationWarning, stacklevel=2)
return self.select(group=group) or default

def select(self, **params):
"""
Prevent transform to EntryPoints during call to entry_points if
no selection parameters were passed.
"""
if not params:
return self
return super().select(**params)


class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""
Expand Down Expand Up @@ -310,7 +421,7 @@ def version(self):

@property
def entry_points(self):
return list(EntryPoint._from_text_for(self.read_text('entry_points.txt'), self))
return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self)

@property
def files(self):
Expand Down Expand Up @@ -643,19 +754,29 @@ def version(distribution_name):
return distribution(distribution_name).version


def entry_points():
def entry_points(**params) -> Union[EntryPoints, SelectableGroups]:
"""Return EntryPoint objects for all installed packages.
:return: EntryPoint objects for all installed packages.
Pass selection parameters (group or name) to filter the
result to entry points matching those properties (see
EntryPoints.select()).
For compatibility, returns ``SelectableGroups`` object unless
selection parameters are supplied. In the future, this function
will return ``LegacyGroupedEntryPoints`` instead of
``SelectableGroups`` and eventually will only return
``EntryPoints``.
For maximum future compatibility, pass selection parameters
or invoke ``.select`` with parameters on the result.
:return: EntryPoints or SelectableGroups for all installed packages.
"""
unique = functools.partial(unique_everseen, key=operator.attrgetter('name'))
eps = itertools.chain.from_iterable(
dist.entry_points for dist in unique(distributions())
)
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return {group: tuple(eps) for group, eps in grouped}
return SelectableGroups.load(eps).select(**params)


def files(distribution_name):
Expand Down
59 changes: 55 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import textwrap
import unittest
import warnings

from . import fixtures
from importlib_metadata import (
Expand Down Expand Up @@ -64,13 +65,16 @@ def test_read_text(self):
self.assertEqual(top_level.read_text(), 'mod\n')

def test_entry_points(self):
entries = dict(entry_points()['entries'])
eps = entry_points()
assert 'entries' in eps.groups
entries = eps.select(group='entries')
assert 'main' in entries.names
ep = entries['main']
self.assertEqual(ep.value, 'mod:main')
self.assertEqual(ep.extras, [])

def test_entry_points_distribution(self):
entries = dict(entry_points()['entries'])
entries = entry_points(group='entries')
for entry in ("main", "ns:sub"):
ep = entries[entry]
self.assertIn(ep.dist.name, ('distinfo-pkg', 'egginfo-pkg'))
Expand All @@ -96,14 +100,61 @@ def test_entry_points_unique_packages(self):
},
}
fixtures.build_files(alt_pkg, alt_site_dir)
entries = dict(entry_points()['entries'])
entries = entry_points(group='entries')
assert not any(
ep.dist.name == 'distinfo-pkg' and ep.dist.version == '1.0.0'
for ep in entries.values()
for ep in entries
)
# ns:sub doesn't exist in alt_pkg
assert 'ns:sub' not in entries

def test_entry_points_missing_name(self):
with self.assertRaises(KeyError):
entry_points(group='entries')['missing']

def test_entry_points_missing_group(self):
assert entry_points(group='missing') == ()

def test_entry_points_dict_construction(self):
"""
Prior versions of entry_points() returned simple lists and
allowed casting those lists into maps by name using ``dict()``.
Capture this now deprecated use-case.
"""
with warnings.catch_warnings(record=True) as caught:
eps = dict(entry_points(group='entries'))

assert 'main' in eps
assert eps['main'] == entry_points(group='entries')['main']

# check warning
expected = next(iter(caught))
assert expected.category is DeprecationWarning
assert "Construction of dict of EntryPoints is deprecated" in str(expected)

def test_entry_points_groups_getitem(self):
"""
Prior versions of entry_points() returned a dict. Ensure
that callers using '.__getitem__()' are supported but warned to
migrate.
"""
with warnings.catch_warnings(record=True):
entry_points()['entries'] == entry_points(group='entries')

with self.assertRaises(KeyError):
entry_points()['missing']

def test_entry_points_groups_get(self):
"""
Prior versions of entry_points() returned a dict. Ensure
that callers using '.get()' are supported but warned to
migrate.
"""
with warnings.catch_warnings(record=True):
entry_points().get('missing', 'default') == 'default'
entry_points().get('entries', 'default') == entry_points()['entries']
entry_points().get('missing', ()) == ()

def test_metadata_for_this_package(self):
md = metadata('egginfo-pkg')
assert md['author'] == 'Steven Ma'
Expand Down
10 changes: 5 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
import textwrap
import unittest
import warnings
import importlib
import importlib_metadata
import pyfakefs.fake_filesystem_unittest as ffs
Expand Down Expand Up @@ -57,13 +58,11 @@ def test_import_nonexistent_module(self):
importlib.import_module('does_not_exist')

def test_resolve(self):
entries = dict(entry_points()['entries'])
ep = entries['main']
ep = entry_points(group='entries')['main']
self.assertEqual(ep.load().__name__, "main")

def test_entrypoint_with_colon_in_name(self):
entries = dict(entry_points()['entries'])
ep = entries['ns:sub']
ep = entry_points(group='entries')['ns:sub']
self.assertEqual(ep.value, 'mod:main')

def test_resolve_without_attr(self):
Expand Down Expand Up @@ -249,7 +248,8 @@ def test_json_dump(self):
json should not expect to be able to dump an EntryPoint
"""
with self.assertRaises(Exception):
json.dumps(self.ep)
with warnings.catch_warnings(record=True):
json.dumps(self.ep)

def test_module(self):
assert self.ep.module == 'value'
Expand Down
Loading

0 comments on commit 73cf0a9

Please sign in to comment.