Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 49 additions & 49 deletions black.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
Priority = int
Index = int
LN = Union[Leaf, Node]
SplitFunc = Callable[["Line", bool], Iterator["Line"]]
SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
Timestamp = float
FileSize = int
CacheInfo = Tuple[Timestamp, FileSize]
Expand Down Expand Up @@ -133,31 +133,35 @@ class Feature(Enum):
UNICODE_LITERALS = 1
F_STRINGS = 2
NUMERIC_UNDERSCORES = 3
TRAILING_COMMA = 4
TRAILING_COMMA_IN_CALL = 4
TRAILING_COMMA_IN_DEF = 5


VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
TargetVersion.PY27: set(),
TargetVersion.PY33: {Feature.UNICODE_LITERALS},
TargetVersion.PY34: {Feature.UNICODE_LITERALS},
TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
TargetVersion.PY36: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
TargetVersion.PY37: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
TargetVersion.PY38: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
Feature.TRAILING_COMMA,
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
},
}

Expand Down Expand Up @@ -683,16 +687,19 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line()
after = 0
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
for current_line in lines.visit(src_node):
for _ in range(after):
dst_contents += str(empty_line)
before, after = elt.maybe_empty_lines(current_line)
for _ in range(before):
dst_contents += str(empty_line)
for line in split_line(
current_line,
line_length=mode.line_length,
supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
current_line, line_length=mode.line_length, features=split_line_features
):
dst_contents += str(line)
return dst_contents
Expand Down Expand Up @@ -2162,7 +2169,7 @@ def split_line(
line: Line,
line_length: int,
inner: bool = False,
supports_trailing_commas: bool = False,
features: Collection[Feature] = (),
) -> Iterator[Line]:
"""Split a `line` into potentially many lines.

Expand All @@ -2171,7 +2178,7 @@ def split_line(
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.

If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
`features` are syntactical features that may be used in the output.
"""
if line.is_comment:
yield line
Expand All @@ -2192,21 +2199,17 @@ def split_line(
split_funcs = [left_hand_split]
else:

def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length):
lines = list(
right_hand_split(
line, line_length, supports_trailing_commas, omit=omit
)
)
lines = list(right_hand_split(line, line_length, features, omit=omit))
if is_line_short_enough(lines[0], line_length=line_length):
yield from lines
return

# All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition
# reported as not fitting a single line.
yield from right_hand_split(line, line_length, supports_trailing_commas)
yield from right_hand_split(line, line_length, features=features)

if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs]
Expand All @@ -2218,16 +2221,13 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
# split altogether.
result: List[Line] = []
try:
for l in split_func(line, supports_trailing_commas):
for l in split_func(line, features):
if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result")

result.extend(
split_line(
l,
line_length=line_length,
inner=True,
supports_trailing_commas=supports_trailing_commas,
l, line_length=line_length, inner=True, features=features
)
)
except CannotSplit:
Expand All @@ -2241,9 +2241,7 @@ def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
yield line


def left_hand_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.

Note: this usually looks weird, only use this for function definitions.
Expand Down Expand Up @@ -2282,7 +2280,7 @@ def left_hand_split(
def right_hand_split(
line: Line,
line_length: int,
supports_trailing_commas: bool = False,
features: Collection[Feature] = (),
omit: Collection[LeafID] = (),
) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair.
Expand Down Expand Up @@ -2341,12 +2339,7 @@ def right_hand_split(
):
omit = {id(closing_bracket), *omit}
try:
yield from right_hand_split(
line,
line_length,
supports_trailing_commas=supports_trailing_commas,
omit=omit,
)
yield from right_hand_split(line, line_length, features=features, omit=omit)
return

except CannotSplit:
Expand Down Expand Up @@ -2435,24 +2428,20 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
"""

@wraps(split_func)
def split_wrapper(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
for l in split_func(line, supports_trailing_commas):
def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
for l in split_func(line, features):
normalize_prefix(l.leaves[0], inside_brackets=True)
yield l

return split_wrapper


@dont_increase_indentation
def delimiter_split(
line: Line, supports_trailing_commas: bool = False
) -> Iterator[Line]:
def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.

If `supports_trailing_commas` is True, the split will add trailing commas
also in function signatures that contain `*` and `**`.
If the appropriate Features are given, the split will add trailing commas
also in function signatures and calls that contain `*` and `**`.
"""
try:
last_leaf = line.leaves[-1]
Expand Down Expand Up @@ -2491,10 +2480,16 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:
yield from append_to_line(comment_after)

lowest_depth = min(lowest_depth, leaf.bracket_depth)
if leaf.bracket_depth == lowest_depth and is_vararg(
leaf, within=VARARGS_PARENTS
):
trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
if leaf.bracket_depth == lowest_depth:
if is_vararg(leaf, within={syms.typedargslist}):
trailing_comma_safe = (
trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
)
elif is_vararg(leaf, within={syms.arglist, syms.argument}):
trailing_comma_safe = (
trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
)

leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
Expand All @@ -2513,7 +2508,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]:

@dont_increase_indentation
def standalone_comment_split(
line: Line, supports_trailing_commas: bool = False
line: Line, features: Collection[Feature] = ()
) -> Iterator[Line]:
"""Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0):
Expand Down Expand Up @@ -3063,14 +3058,19 @@ def get_features_used(node: Node) -> Set[Feature]:
and n.children
and n.children[-1].type == token.COMMA
):
if n.type == syms.typedargslist:
feature = Feature.TRAILING_COMMA_IN_DEF
else:
feature = Feature.TRAILING_COMMA_IN_CALL

for ch in n.children:
if ch.type in STARS:
features.add(Feature.TRAILING_COMMA)
features.add(feature)

if ch.type == syms.argument:
for argch in ch.children:
if argch.type in STARS:
features.add(Feature.TRAILING_COMMA)
features.add(feature)

return features

Expand Down
23 changes: 14 additions & 9 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,11 @@ def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertEqual(black.get_features_used(node), set())
node = black.lib2to3_parse("def f(*, arg,): ...\n")
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
node = black.lib2to3_parse("f(*arg,)\n")
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
)
node = black.lib2to3_parse("def f(*, arg): f'string'\n")
self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
node = black.lib2to3_parse("123_456\n")
Expand All @@ -841,13 +845,14 @@ def test_get_features_used(self) -> None:
self.assertEqual(black.get_features_used(node), set())
source, expected = read_data("function")
node = black.lib2to3_parse(source)
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
expected_features = {
Feature.TRAILING_COMMA_IN_CALL,
Feature.TRAILING_COMMA_IN_DEF,
Feature.F_STRINGS,
}
self.assertEqual(black.get_features_used(node), expected_features)
node = black.lib2to3_parse(expected)
self.assertEqual(
black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
)
self.assertEqual(black.get_features_used(node), expected_features)
source, expected = read_data("expression")
node = black.lib2to3_parse(source)
self.assertEqual(black.get_features_used(node), set())
Expand Down Expand Up @@ -1499,8 +1504,8 @@ async def check(header_value: str, expected_status: int) -> None:

await check("3.6", 200)
await check("py3.6", 200)
await check("3.5,3.7", 200)
await check("3.5,py3.7", 200)
await check("3.6,3.7", 200)
await check("3.6,py3.7", 200)

await check("2", 204)
await check("2.7", 204)
Expand Down