diff --git a/README.md b/README.md
index 73551459..6bb89e2f 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,15 @@
-py-tree-sitter
-==================
+# py-tree-sitter
[![Build Status](https://github.com/tree-sitter/py-tree-sitter/actions/workflows/ci.yml/badge.svg)](https://github.com/tree-sitter/py-tree-sitter/actions/workflows/ci.yml)
[![Build status](https://ci.appveyor.com/api/projects/status/mde790v0v9gux85w/branch/master?svg=true)](https://ci.appveyor.com/project/maxbrunsfeld/py-tree-sitter/branch/master)
-This module provides Python bindings to the [tree-sitter](https://github.com/tree-sitter/tree-sitter) parsing library.
+This module provides Python bindings to the [tree-sitter](https://github.com/tree-sitter/tree-sitter)
+parsing library.
## Installation
-This package currently only works with Python 3. There are no library dependencies, but you do need to have a C compiler installed.
+This package currently only works with Python 3. There are no library dependencies,
+but you do need to have a C compiler installed.
```sh
pip3 install tree_sitter
@@ -16,9 +17,11 @@ pip3 install tree_sitter
## Usage
-#### Setup
+### Setup
-First you'll need a Tree-sitter language implementation for each language that you want to parse. You can clone some of the [existing language repos](https://github.com/tree-sitter) or [create your own](http://tree-sitter.github.io/tree-sitter/creating-parsers):
+First you'll need a Tree-sitter language implementation for each language that you
+want to parse. You can clone some of the [existing language repos](https://github.com/tree-sitter)
+or [create your own](http://tree-sitter.github.io/tree-sitter/creating-parsers):
```sh
git clone https://github.com/tree-sitter/tree-sitter-go
@@ -26,30 +29,27 @@ git clone https://github.com/tree-sitter/tree-sitter-javascript
git clone https://github.com/tree-sitter/tree-sitter-python
```
-Use the `Language.build_library` method to compile these into a library that's usable from Python. This function will return immediately if the library has already been compiled since the last time its source code was modified:
+Use the `Language.build_library` method to compile these into a library that's
+usable from Python. This function will return immediately if the library has
+already been compiled since the last time its source code was modified:
```python
-from tree_sitter import Language, Parser
+from tree_sitter import Language
Language.build_library(
- # Store the library in the `build` directory
- 'build/my-languages.so',
-
- # Include one or more languages
- [
- 'vendor/tree-sitter-go',
- 'vendor/tree-sitter-javascript',
- 'vendor/tree-sitter-python'
- ]
+ # Store the library in the `build` directory
+ "build/my-languages.so",
+ # Include one or more languages
+ ["vendor/tree-sitter-go", "vendor/tree-sitter-javascript", "vendor/tree-sitter-python"],
)
```
Load the languages into your app as `Language` objects:
```python
-GO_LANGUAGE = Language('build/my-languages.so', 'go')
-JS_LANGUAGE = Language('build/my-languages.so', 'javascript')
-PY_LANGUAGE = Language('build/my-languages.so', 'python')
+GO_LANGUAGE = Language("build/my-languages.so", "go")
+JS_LANGUAGE = Language("build/my-languages.so", "javascript")
+PY_LANGUAGE = Language("build/my-languages.so", "python")
```
#### Basic Parsing
@@ -64,11 +64,16 @@ parser.set_language(PY_LANGUAGE)
Parse some source code:
```python
-tree = parser.parse(bytes("""
+tree = parser.parse(
+ bytes(
+ """
def foo():
if bar:
baz()
-""", "utf8"))
+""",
+ "utf8",
+ )
+)
```
If you have your source code in some data structure other than a bytes object,
@@ -81,14 +86,19 @@ terminates parsing for that line. The bytes must encode the source as UTF-8.
For example, to use the byte offset:
```python
-src = bytes("""
+src = bytes(
+ """
def foo():
if bar:
baz()
-""", "utf8")
+""",
+ "utf8",
+)
+
def read_callable(byte_offset, point):
- return src[byte_offset:byte_offset+1]
+ return src[byte_offset : byte_offset + 1]
+
tree = parser.parse(read_callable)
```
@@ -98,11 +108,13 @@ And to use the point:
```python
src_lines = ["def foo():\n", " if bar:\n", " baz()"]
+
def read_callable(byte_offset, point):
row, column = point
if row >= len(src_lines) or column >= len(src_lines[row]):
return None
- return src_lines[row][column:].encode('utf8')
+ return src_lines[row][column:].encode("utf8")
+
tree = parser.parse(read_callable)
```
@@ -145,30 +157,31 @@ a `TreeCursor`:
```python
cursor = tree.walk()
-assert cursor.node.type == 'module'
+assert cursor.node.type == "module"
assert cursor.goto_first_child()
-assert cursor.node.type == 'function_definition'
+assert cursor.node.type == "function_definition"
assert cursor.goto_first_child()
-assert cursor.node.type == 'def'
+assert cursor.node.type == "def"
# Returns `False` because the `def` node has no children
assert not cursor.goto_first_child()
assert cursor.goto_next_sibling()
-assert cursor.node.type == 'identifier'
+assert cursor.node.type == "identifier"
assert cursor.goto_next_sibling()
-assert cursor.node.type == 'parameters'
+assert cursor.node.type == "parameters"
assert cursor.goto_parent()
-assert cursor.node.type == 'function_definition'
+assert cursor.node.type == "function_definition"
```
#### Editing
-When a source file is edited, you can edit the syntax tree to keep it in sync with the source:
+When a source file is edited, you can edit the syntax tree to keep it in sync with
+the source:
```python
tree.edit(
@@ -190,30 +203,32 @@ new_tree = parser.parse(new_source, tree)
This will run much faster than if you were parsing from scratch.
-The `Tree.get_changed_ranges` method can be called on the *old* tree to return
+The `Tree.get_changed_ranges` method can be called on the _old_ tree to return
the list of ranges whose syntactic structure has been changed:
```python
for changed_range in tree.get_changed_ranges(new_tree):
- print('Changed range:')
- print(f' Start point {changed_range.start_point}')
- print(f' Start byte {changed_range.start_byte}')
- print(f' End point {changed_range.end_point}')
- print(f' End byte {changed_range.end_byte}')
+ print("Changed range:")
+ print(f" Start point {changed_range.start_point}")
+ print(f" Start byte {changed_range.start_byte}")
+ print(f" End point {changed_range.end_point}")
+ print(f" End byte {changed_range.end_byte}")
```
#### Pattern-matching
-You can search for patterns in a syntax tree using a *tree query*:
+You can search for patterns in a syntax tree using a _tree query_:
```python
-query = PY_LANGUAGE.query("""
+query = PY_LANGUAGE.query(
+ """
(function_definition
name: (identifier) @function.def)
(call
function: (identifier) @function.call)
-""")
+"""
+)
captures = query.captures(tree.root_node)
assert len(captures) == 2
diff --git a/pyproject.toml b/pyproject.toml
index 42bb7bf1..b8d7afb0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,9 @@
[build-system]
-requires = [
- "setuptools>=43.0.0",
- "wheel>=0.36.2",
-]
+requires = ["setuptools>=43.0.0", "wheel>=0.36.2"]
build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+line-length = 100
+
+[tool.ruff.pycodestyle]
+max-line-length = 102
diff --git a/script/lint b/script/lint
index cdca731b..83be6c3a 100755
--- a/script/lint
+++ b/script/lint
@@ -1,4 +1,4 @@
#!/bin/bash
sources=$(git ls-files | grep '.py$')
-flake8 --max-line-length 100 $sources
+flake8 --max-line-length 102 --ignore=E203,W503 $sources
diff --git a/setup.py b/setup.py
index 81e045cf..ad5222de 100644
--- a/setup.py
+++ b/setup.py
@@ -29,9 +29,7 @@
"Topic :: Software Development :: Compilers",
"Topic :: Text Processing :: Linguistic",
],
- install_requires=[
- "setuptools>=60.0.0; python_version>='3.12'"
- ],
+ install_requires=["setuptools>=60.0.0; python_version>='3.12'"],
packages=["tree_sitter"],
package_data={"tree_sitter": ["py.typed", "*.pyi"]},
ext_modules=[
diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py
index 451e0fde..d98fe3fc 100644
--- a/tests/test_tree_sitter.py
+++ b/tests/test_tree_sitter.py
@@ -116,7 +116,7 @@ def test_multibyte_characters(self):
self.assertEqual(binary_node.type, "binary_expression")
self.assertEqual(snake_node.type, "string")
self.assertEqual(
- source_code[snake_node.start_byte:snake_node.end_byte].decode("utf8"),
+ source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"),
"'🐍'",
)
@@ -133,7 +133,7 @@ def test_multibyte_characters_via_read_callback(self):
source_code = bytes("'😎' && '🐍'", "utf8")
def read(byte_position, _):
- return source_code[byte_position:byte_position + 1]
+ return source_code[byte_position : byte_position + 1]
tree = parser.parse(read)
root_node = tree.root_node
@@ -144,7 +144,7 @@ def read(byte_position, _):
self.assertEqual(binary_node.type, "binary_expression")
self.assertEqual(snake_node.type, "string")
self.assertEqual(
- source_code[snake_node.start_byte:snake_node.end_byte].decode("utf8"),
+ source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"),
"'🐍'",
)
@@ -164,9 +164,9 @@ def test_parsing_with_one_included_range(self):
self.assertEqual(
js_tree.root_node.sexp(),
- "(program (expression_statement (call_expression " +
- "function: (member_expression object: (identifier) property: (property_identifier)) " +
- "arguments: (arguments (string (string_fragment))))))"
+ "(program (expression_statement (call_expression "
+ + "function: (member_expression object: (identifier) property: (property_identifier)) "
+ + "arguments: (arguments (string (string_fragment))))))",
)
self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console")))
self.assertEqual(js_tree.included_ranges, [script_content_node.range])
@@ -177,9 +177,9 @@ def test_parsing_with_multiple_included_ranges(self):
parser = Parser()
parser.set_language(JAVASCRIPT)
js_tree = parser.parse(source_code)
- template_string_node = js_tree \
- .root_node \
- .descendant_for_byte_range(source_code.index(b"<div>"), source_code.index(b"Hello"))
+ template_string_node = js_tree.root_node.descendant_for_byte_range(
+ source_code.index(b"<div>"), source_code.index(b"Hello")
+ )
if template_string_node is None:
self.fail("template_string_node is None")
@@ -203,19 +203,19 @@ def test_parsing_with_multiple_included_ranges(self):
start_byte=open_quote_node.end_byte,
start_point=open_quote_node.end_point,
end_byte=interpolation_node1.start_byte,
- end_point=interpolation_node1.start_point
+ end_point=interpolation_node1.start_point,
),
Range(
start_byte=interpolation_node1.end_byte,
start_point=interpolation_node1.end_point,
end_byte=interpolation_node2.start_byte,
- end_point=interpolation_node2.start_point
+ end_point=interpolation_node2.start_point,
),
Range(
start_byte=interpolation_node2.end_byte,
start_point=interpolation_node2.end_point,
end_byte=close_quote_node.start_byte,
- end_point=close_quote_node.start_point
+ end_point=close_quote_node.start_point,
),
]
parser.set_included_ranges(html_ranges)
@@ -224,12 +224,12 @@ def test_parsing_with_multiple_included_ranges(self):
self.assertEqual(
html_tree.root_node.sexp(),
- "(fragment (element" +
- " (start_tag (tag_name))" +
- " (text)" +
- " (element (start_tag (tag_name)) (end_tag (tag_name)))" +
- " (text)" +
- " (end_tag (tag_name))))"
+ "(fragment (element"
+ + " (start_tag (tag_name))"
+ + " (text)"
+ + " (element (start_tag (tag_name)) (end_tag (tag_name)))"
+ + " (text)"
+ + " (end_tag (tag_name))))",
)
self.assertEqual(html_tree.included_ranges, html_ranges)
@@ -282,42 +282,46 @@ def test_parsing_with_included_range_containing_mismatched_positions(self):
self.assertEqual(
html_tree.root_node.sexp(),
- "(fragment (element (start_tag (tag_name)) (text) (end_tag (tag_name))))"
+ "(fragment (element (start_tag (tag_name)) (text) (end_tag (tag_name))))",
)
def test_parsing_error_in_invalid_included_ranges(self):
parser = Parser()
with self.assertRaises(Exception):
- parser.set_included_ranges([
- Range(
- start_byte=23,
- end_byte=29,
- start_point=(0, 23),
- end_point=(0, 29),
- ),
- Range(
- start_byte=0,
- end_byte=5,
- start_point=(0, 0),
- end_point=(0, 5),
- ),
- Range(
- start_byte=50,
- end_byte=60,
- start_point=(0, 50),
- end_point=(0, 60),
- ),
- ])
+ parser.set_included_ranges(
+ [
+ Range(
+ start_byte=23,
+ end_byte=29,
+ start_point=(0, 23),
+ end_point=(0, 29),
+ ),
+ Range(
+ start_byte=0,
+ end_byte=5,
+ start_point=(0, 0),
+ end_point=(0, 5),
+ ),
+ Range(
+ start_byte=50,
+ end_byte=60,
+ start_point=(0, 50),
+ end_point=(0, 60),
+ ),
+ ]
+ )
with self.assertRaises(Exception):
- parser.set_included_ranges([
- Range(
- start_byte=10,
- end_byte=5,
- start_point=(0, 10),
- end_point=(0, 5),
- )
- ])
+ parser.set_included_ranges(
+ [
+ Range(
+ start_byte=10,
+ end_byte=5,
+ start_point=(0, 10),
+ end_point=(0, 5),
+ )
+ ]
+ )
def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self):
source_code = b"a <%= b() %> c <% d() %>"
@@ -328,20 +332,22 @@ def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self)
parser = Parser()
parser.set_language(JAVASCRIPT)
- parser.set_included_ranges([
- Range(
- start_byte=range1_start_byte,
- end_byte=range1_end_byte,
- start_point=(0, range1_start_byte),
- end_point=(0, range1_end_byte),
- ),
- Range(
- start_byte=range2_start_byte,
- end_byte=range2_end_byte,
- start_point=(0, range2_start_byte),
- end_point=(0, range2_end_byte),
- ),
- ])
+ parser.set_included_ranges(
+ [
+ Range(
+ start_byte=range1_start_byte,
+ end_byte=range1_end_byte,
+ start_point=(0, range1_start_byte),
+ end_point=(0, range1_end_byte),
+ ),
+ Range(
+ start_byte=range2_start_byte,
+ end_byte=range2_end_byte,
+ start_point=(0, range2_start_byte),
+ end_point=(0, range2_end_byte),
+ ),
+ ]
+ )
tree = parser.parse(source_code)
root = tree.root_node
@@ -355,11 +361,11 @@ def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self)
self.assertEqual(
root.sexp(),
"(program"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + ")",
)
self.assertEqual(statement1.start_byte, source_code.index(b"b()"))
@@ -391,29 +397,31 @@ def test_parsing_with_a_newly_excluded_range(self):
directive_start = source_code.index(b"<%=")
directive_end = source_code.index(b"</span>")
source_code_end = len(source_code)
- parser.set_included_ranges([
- Range(
- start_byte=0,
- end_byte=directive_start,
- start_point=(0, 0),
- end_point=(0, directive_start),
- ),
- Range(
- start_byte=directive_end,
- end_byte=source_code_end,
- start_point=(0, directive_end),
- end_point=(0, source_code_end),
- ),
- ])
+ parser.set_included_ranges(
+ [
+ Range(
+ start_byte=0,
+ end_byte=directive_start,
+ start_point=(0, 0),
+ end_point=(0, directive_start),
+ ),
+ Range(
+ start_byte=directive_end,
+ end_byte=source_code_end,
+ start_point=(0, directive_end),
+ end_point=(0, source_code_end),
+ ),
+ ]
+ )
tree = parser.parse(source_code, first_tree)
self.assertEqual(
tree.root_node.sexp(),
- "(fragment (text) (element" +
- " (start_tag (tag_name))" +
- " (element (start_tag (tag_name)) (end_tag (tag_name)))" +
- " (end_tag (tag_name))))"
+ "(fragment (text) (element"
+ + " (start_tag (tag_name))"
+ + " (element (start_tag (tag_name)) (end_tag (tag_name)))"
+ + " (end_tag (tag_name))))",
)
self.assertEqual(
@@ -435,7 +443,7 @@ def test_parsing_with_a_newly_excluded_range(self):
start_point=(0, directive_start),
end_point=(0, directive_end),
),
- ]
+ ],
)
def test_parsing_with_a_newly_included_range(self):
@@ -463,51 +471,55 @@ def simple_range(start: int, end: int) -> Range:
self.assertEqual(
tree.root_node.sexp(),
"(program"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + ")",
)
# Parse both the first and third code directives as JavaScript, using the old tree as a
# reference.
- parser.set_included_ranges([
- simple_range(range1_start, range1_end),
- simple_range(range3_start, range3_end),
- ])
+ parser.set_included_ranges(
+ [
+ simple_range(range1_start, range1_end),
+ simple_range(range3_start, range3_end),
+ ]
+ )
tree2 = parser.parse(source_code)
self.assertEqual(
tree2.root_node.sexp(),
"(program"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + ")",
)
self.assertEqual(tree2.changed_ranges(tree), [simple_range(range1_end, range3_end)])
# Parse all three code directives as JavaScript, using the old tree as a
# reference.
- parser.set_included_ranges([
- simple_range(range1_start, range1_end),
- simple_range(range2_start, range2_end),
- simple_range(range3_start, range3_end),
- ])
+ parser.set_included_ranges(
+ [
+ simple_range(range1_start, range1_end),
+ simple_range(range2_start, range2_end),
+ simple_range(range3_start, range3_end),
+ ]
+ )
tree3 = parser.parse(source_code)
self.assertEqual(
tree3.root_node.sexp(),
"(program"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " " +
- "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + " "
+ + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
+ + ")",
)
self.assertEqual(
tree3.changed_ranges(tree2),
- [simple_range(range2_start + 1, range2_end - 1)]
+ [simple_range(range2_start + 1, range2_end - 1)],
)
@@ -552,9 +564,7 @@ def test_children_by_field_id(self):
self.fail("attribute_field is not an int")
attributes = jsx_node.children_by_field_id(attribute_field)
- self.assertEqual(
- [a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]
- )
+ self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
def test_children_by_field_name(self):
parser = Parser()
@@ -563,9 +573,7 @@ def test_children_by_field_name(self):
jsx_node = tree.root_node.children[0].children[0]
attributes = jsx_node.children_by_field_name("attribute")
- self.assertEqual(
- [a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]
- )
+ self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
def test_node_child_by_field_name_with_extra_hidden_children(self):
parser = Parser()
@@ -576,7 +584,7 @@ def test_node_child_by_field_name_with_extra_hidden_children(self):
if while_node is None:
self.fail("while_node is None")
self.assertEqual(while_node.type, "while_statement")
- self.assertEqual(while_node.child_by_field_name('body'), while_node.child(3))
+ self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3))
def test_node_descendant_count(self):
parser = Parser()
@@ -632,9 +640,7 @@ def test_descendant_for_byte_range(self):
colon_index = JSON_EXAMPLE.index(b":")
# Leaf node exactly matches the given bounds - byte query
- colon_node = array_node.descendant_for_byte_range(
- colon_index, colon_index + 1
- )
+ colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1)
if colon_node is None:
self.fail("colon_node is None")
self.assertEqual(colon_node.type, ":")
@@ -654,9 +660,7 @@ def test_descendant_for_byte_range(self):
self.assertEqual(colon_node.end_point, (6, 8))
# The given point is between two adjacent leaf nodes - byte query
- colon_node = array_node.descendant_for_byte_range(
- colon_index, colon_index
- )
+ colon_node = array_node.descendant_for_byte_range(colon_index, colon_index)
if colon_node is None:
self.fail("colon_node is None")
self.assertEqual(colon_node.type, ":")
@@ -677,9 +681,7 @@ def test_descendant_for_byte_range(self):
# Leaf node starts at the lower bound, ends after the upper bound - byte query
string_index = JSON_EXAMPLE.index(b'"x"')
- string_node = array_node.descendant_for_byte_range(
- string_index, string_index + 2
- )
+ string_node = array_node.descendant_for_byte_range(string_index, string_index + 2)
if string_node is None:
self.fail("string_node is None")
self.assertEqual(string_node.type, "string")
@@ -700,9 +702,7 @@ def test_descendant_for_byte_range(self):
# Leaf node starts before the lower bound, ends at the upper bound - byte query
null_index = JSON_EXAMPLE.index(b"null")
- null_node = array_node.descendant_for_byte_range(
- null_index + 1, null_index + 4
- )
+ null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4)
if null_node is None:
self.fail("null_node is None")
self.assertEqual(null_node.type, "null")
@@ -722,9 +722,7 @@ def test_descendant_for_byte_range(self):
self.assertEqual(null_node.end_point, (6, 13))
# The bounds span multiple leaf nodes - return the smallest node that does span it.
- pair_node = array_node.descendant_for_byte_range(
- string_index + 2, string_index + 4
- )
+ pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4)
if pair_node is None:
self.fail("pair_node is None")
self.assertEqual(pair_node.type, "pair")
@@ -1021,17 +1019,21 @@ def test_node_numeric_symbols_respect_simple_aliases(self):
# `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`.
tree = parser.parse(b"(a((*b)))")
root_node = tree.root_node
- self.assertEqual(root_node.sexp(), "(module (expression_statement (parenthesized_expression (call function: (identifier) arguments: (argument_list (parenthesized_expression (list_splat (identifier))))))))") # noqa: E501
+ self.assertEqual(
+ root_node.sexp(),
+ "(module (expression_statement (parenthesized_expression (call "
+ + "function: (identifier) arguments: (argument_list (parenthesized_expression "
+ + "(list_splat (identifier))))))))",
+ )
outer_expr_node = root_node.child(0).child(0)
if outer_expr_node is None:
self.fail("outer_expr_node is None")
self.assertEqual(outer_expr_node.type, "parenthesized_expression")
- inner_expr_node = outer_expr_node \
- .named_child(0) \
- .child_by_field_name("arguments") \
- .named_child(0)
+ inner_expr_node = (
+ outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0)
+ )
if inner_expr_node is None:
self.fail("inner_expr_node is None")
@@ -1174,12 +1176,14 @@ def test_tree_cursor(self):
parser = Parser()
parser.set_language(RUST)
- tree = parser.parse(b"""
+ tree = parser.parse(
+ b"""
struct Stuff {
a: A,
b: Option<B>,
}
- """)
+ """
+ )
cursor = tree.walk()
self.assertEqual(cursor.node.type, "source_file")
@@ -1665,8 +1669,7 @@ def test_lookahead_iterator(self):
self.assertNotEqual(next_state, 0)
self.assertEqual(
- next_state,
- RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id)
+ next_state, RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id)
)
self.assertLess(next_state, RUST.parse_state_count)
self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier
diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py
index a4b10e1f..287a1019 100644
--- a/tree_sitter/__init__.py
+++ b/tree_sitter/__init__.py
@@ -7,15 +7,40 @@
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional
-from tree_sitter.binding import (LookaheadIterator, Node, Parser, # noqa: F401
- Tree, TreeCursor, _language_field_count,
- _language_field_id_for_name,
- _language_field_name_for_id, _language_query,
- _language_state_count, _language_symbol_count,
- _language_symbol_for_name,
- _language_symbol_name, _language_symbol_type,
- _language_version, _lookahead_iterator,
- _next_state)
+from tree_sitter.binding import (
+ LookaheadIterator,
+ LookaheadNamesIterator,
+ Node,
+ Parser,
+ Query,
+ Range,
+ Tree,
+ TreeCursor,
+ _language_field_count,
+ _language_field_id_for_name,
+ _language_field_name_for_id,
+ _language_query,
+ _language_state_count,
+ _language_symbol_count,
+ _language_symbol_for_name,
+ _language_symbol_name,
+ _language_symbol_type,
+ _language_version,
+ _lookahead_iterator,
+ _next_state,
+)
+
+__all__ = [
+ "Language",
+ "Node",
+ "Parser",
+ "Query",
+ "Range",
+ "Tree",
+ "TreeCursor",
+ "LookaheadIterator",
+ "LookaheadNamesIterator",
+]
class SymbolType(enum.IntEnum):
@@ -59,9 +84,7 @@ def build_library(output_path: str, repo_paths: List[str]):
source_paths.append(path.join(src_path, "scanner.cc"))
elif path.exists(path.join(src_path, "scanner.c")):
source_paths.append(path.join(src_path, "scanner.c"))
- source_mtimes = [path.getmtime(__file__)] + [
- path.getmtime(path_) for path_ in source_paths
- ]
+ source_mtimes = [path.getmtime(__file__)] + [path.getmtime(path_) for path_ in source_paths]
if max(source_mtimes) <= output_mtime:
return False
diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c
index 6e2e93c1..c8011beb 100644
--- a/tree_sitter/binding.c
+++ b/tree_sitter/binding.c
@@ -847,14 +847,10 @@ static PyGetSetDef node_accessors[] = {
};
static PyType_Slot node_type_slots[] = {
- {Py_tp_doc, "A syntax node"},
- {Py_tp_dealloc, node_dealloc},
- {Py_tp_repr, node_repr},
- {Py_tp_richcompare, node_compare},
- {Py_tp_hash, node_hash},
- {Py_tp_methods, node_methods},
- {Py_tp_getset, node_accessors},
- {0, NULL},
+ {Py_tp_doc, "A syntax node"}, {Py_tp_dealloc, node_dealloc},
+ {Py_tp_repr, node_repr}, {Py_tp_richcompare, node_compare},
+ {Py_tp_hash, node_hash}, {Py_tp_methods, node_methods},
+ {Py_tp_getset, node_accessors}, {0, NULL},
};
static PyType_Spec node_type_spec = {
diff --git a/tree_sitter/binding.pyi b/tree_sitter/binding.pyi
index 26c4edf2..70fa7e57 100644
--- a/tree_sitter/binding.pyi
+++ b/tree_sitter/binding.pyi
@@ -45,14 +45,10 @@ class Node:
def field_name_for_child(self, child_index: int) -> Optional[str]:
"""Get the field name of a child node by the index of child."""
...
- def descendant_for_byte_range(
- self, start_byte: int, end_byte: int
- ) -> Optional[Node]:
+ def descendant_for_byte_range(self, start_byte: int, end_byte: int) -> Optional[Node]:
"""Get the smallest node within the given byte range."""
...
- def named_descendant_for_byte_range(
- self, start_byte: int, end_byte: int
- ) -> Optional[Node]:
+ def named_descendant_for_byte_range(self, start_byte: int, end_byte: int) -> Optional[Node]:
"""Get the smallest named node within the given byte range."""
...
def descendant_for_point_range(
@@ -189,7 +185,9 @@ class Node:
class Tree:
"""A Syntax Tree"""
- def root_node_with_offset(self, offset_bytes: int, offset_extent: Tuple[int, int]) -> Optional[Node]:
+ def root_node_with_offset(
+ self, offset_bytes: int, offset_extent: Tuple[int, int]
+ ) -> Optional[Node]:
"""Get the root node of the syntax tree, but with its position shifted forward by the given offset."""
...
def walk(self) -> TreeCursor:
@@ -321,8 +319,8 @@ class Parser:
def parse(
self,
- source_code: bytes|Callable[[int, Tuple[int, int]], Optional[bytes]],
- old_tree: Optional[Tree]= None,
+ source_code: bytes | Callable[[int, Tuple[int, int]], Optional[bytes]],
+ old_tree: Optional[Tree] = None,
keep_text: Optional[bool] = True,
) -> Tree:
"""Parse source code, creating a syntax tree.