Skip to content

Commit

Permalink
Make sure that check_struct is called where needed (#161)
Browse files Browse the repository at this point in the history
* Make sure that check_struct is called where needed

* consistency

* black is the new black

* Check for errors in code gen report

* implement check_struct

* nicer name, without typo

* Test that check_struct actually does something

* Implement check_struct and apply a few fixes
  • Loading branch information
almarklein authored Apr 28, 2021
1 parent 61c1334 commit 2a00311
Show file tree
Hide file tree
Showing 16 changed files with 170 additions and 66 deletions.
6 changes: 3 additions & 3 deletions codegen/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def prepare():


def update_api():
""" Update the public API and patch the public-facing API of the backends. """
"""Update the public API and patch the public-facing API of the backends."""

print("## Updating API")

Expand Down Expand Up @@ -50,7 +50,7 @@ def update_api():


def update_rs():
""" Update and check the rs backend. """
"""Update and check the rs backend."""

print("## Validating rs.py")

Expand All @@ -68,7 +68,7 @@ def update_rs():


def main():
""" Codegen entry point. """
"""Codegen entry point."""

with PrintToFile(os.path.join(lib_dir, "resources", "codegen_report.md")):
print("# Code generatation report")
Expand Down
77 changes: 71 additions & 6 deletions codegen/apipatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import os

from .utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
from .idlparser import get_idl_parser
from codegen.utils import print, lib_dir, blacken, to_snake_case, to_camel_case, Patcher
from codegen.idlparser import get_idl_parser


def patch_base_api(code):
Expand Down Expand Up @@ -42,7 +42,11 @@ def patch_backend_api(code):
base_api_code = f.read().decode()

# Patch!
for patcher in [CommentRemover(), BackendApiPatcher(base_api_code)]:
for patcher in [
CommentRemover(),
BackendApiPatcher(base_api_code),
StructValidationChecker(),
]:
patcher.apply(code)
code = patcher.dumps()
return code
Expand All @@ -53,7 +57,7 @@ class CommentRemover(Patcher):
to prevent accumulating comments.
"""

triggers = "# IDL:", "# FIXME: unknown api"
triggers = "# IDL:", "# FIXME: unknown api", "# FIXME: missing check_struct"

def apply(self, code):
self._init(code)
Expand Down Expand Up @@ -174,7 +178,7 @@ def patch_properties(self, classname, i1, i2):
self._apidiffs_from_lines(pre_lines, propname)
if self.prop_is_known(classname, propname):
if "@apidiff.add" in pre_lines:
print(f"Error: apidiff.add for known {classname}.{propname}")
print(f"ERROR: apidiff.add for known {classname}.{propname}")
elif "@apidiff.hide" in pre_lines:
pass # continue as normal
old_line = self.lines[j1]
Expand Down Expand Up @@ -207,7 +211,7 @@ def patch_methods(self, classname, i1, i2):
self._apidiffs_from_lines(pre_lines, methodname)
if self.method_is_known(classname, methodname):
if "@apidiff.add" in pre_lines:
print(f"Error: apidiff.add for known {classname}.{methodname}")
print(f"ERROR: apidiff.add for known {classname}.{methodname}")
elif "@apidiff.hide" in pre_lines:
pass # continue as normal
elif "@apidiff.change" in pre_lines:
Expand Down Expand Up @@ -443,3 +447,64 @@ def get_required_prop_names(self, classname):
def get_required_method_names(self, classname):
_, methods = self.classes[classname]
return list(name for name in methods.keys() if methods[name][1])


class StructValidationChecker(Patcher):
"""Checks that all structs are vaildated in the methods that have incoming structs."""

def apply(self, code):
self._init(code)

idl = get_idl_parser()
all_structs = set()
ignore_structs = {"Extent3D"}

for classname, i1, i2 in self.iter_classes():
if classname not in idl.classes:
continue

# For each method ...
for methodname, j1, j2 in self.iter_methods(i1 + 1):
code = "\n".join(self.lines[j1 : j2 + 1])
# Get signature and cut it up in words
sig_words = code.partition("(")[2].split("):")[0]
for c in "][(),\"'":
sig_words = sig_words.replace(c, " ")
# Collect incoming structs from signature
method_structs = set()
for word in sig_words.split():
if word.startswith("structs."):
structname = word.partition(".")[2]
method_structs.update(self._get_sub_structs(idl, structname))
all_structs.update(method_structs)
# Collect structs being checked
checked = set()
for line in code.splitlines():
line = line.lstrip()
if line.startswith("check_struct("):
name = line.split("(")[1].split(",")[0].strip('"')
checked.add(name)
# Test that a matching check is done
unchecked = method_structs.difference(checked)
unchecked = list(sorted(unchecked.difference(ignore_structs)))
if (
methodname.endswith("_async")
and f"return self.{methodname[:-7]}" in code
):
pass
elif unchecked:
msg = f"missing check_struct in {methodname}: {unchecked}"
self.insert_line(j1, f"# FIXME: {msg}")
print(f"ERROR: {msg}")

# Test that we did find structs. In case our detection fails for
# some reason, this would probably catch that.
assert len(all_structs) > 10

def _get_sub_structs(self, idl, structname):
structnames = {structname}
for structfield in idl.structs[structname].values():
structname2 = structfield.typename[3:] # remove "GPU"
if structname2 in idl.structs:
structnames.update(self._get_sub_structs(idl, structname2))
return structnames
2 changes: 1 addition & 1 deletion codegen/hparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def get_h_parser(*, allow_cache=True):
""" Get the global HParser object. """
"""Get the global HParser object."""

# Singleton pattern
global _parser
Expand Down
2 changes: 1 addition & 1 deletion codegen/idlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def get_idl_parser(*, allow_cache=True):
""" Get the global IdlParser object. """
"""Get the global IdlParser object."""

# Singleton pattern
global _parser
Expand Down
6 changes: 3 additions & 3 deletions codegen/rspatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def apply(self, code):
if name not in hp.functions:
msg = f"unknown C function {name}"
self.insert_line(i, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")
else:
detected.add(name)
anno = hp.functions[name].replace(name, "f").strip(";")
Expand Down Expand Up @@ -302,7 +302,7 @@ def _validate_struct(self, hp, i1, i2):
if struct_name not in hp.structs:
msg = f"unknown C struct {struct_name}"
self.insert_line(i1, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")
return
else:
struct = hp.structs[struct_name]
Expand All @@ -322,7 +322,7 @@ def _validate_struct(self, hp, i1, i2):
if key not in struct:
msg = f"unknown C struct field {struct_name}.{key}"
self.insert_line(i1 + j, f"{indent}# FIXME: {msg}")
print(f"Error: {msg}")
print(f"ERROR: {msg}")

# Insert comments for unused keys
more_lines = []
Expand Down
10 changes: 10 additions & 0 deletions codegen/tests/test_codegen_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,15 @@ def test_that_code_is_up_to_date():
print("Codegen check ok!")


def test_that_codegen_report_has_no_errors():
filename = os.path.join(lib_dir, "resources", "codegen_report.md")
with open(filename, "rb") as f:
text = f.read().decode()

# The codegen uses a prefix "ERROR:" for unacceptable things.
# All caps, some function names may contain the name "error".
assert "ERROR" not in text


if __name__ == "__main__":
test_that_code_is_up_to_date()
2 changes: 1 addition & 1 deletion codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def print(*args, **kwargs):


class PrintToFile:
""" Context manager to print to file. """
"""Context manager to print to file."""

def __init__(self, f):
if isinstance(f, str):
Expand Down
2 changes: 1 addition & 1 deletion examples/cube_glfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def draw_frame():


def simple_event_loop():
""" A real simple event loop, but it keeps the CPU busy. """
"""A real simple event loop, but it keeps the CPU busy."""
while update_glfw_canvasses():
glfw.poll_events()

Expand Down
4 changes: 2 additions & 2 deletions examples/triangle_glfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@


def simple_event_loop():
""" A real simple event loop, but it keeps the CPU busy. """
"""A real simple event loop, but it keeps the CPU busy."""
while update_glfw_canvasses():
glfw.poll_events()


def better_event_loop(max_fps=100):
""" A simple event loop that schedules draws. """
"""A simple event loop that schedules draws."""
td = 1 / max_fps
while update_glfw_canvasses():
# Determine next time to draw
Expand Down
9 changes: 8 additions & 1 deletion tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def compute_shader(
)
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)

# Create and run the pipeline, fail - test check_struct
with raises(ValueError):
compute_pipeline = device.create_compute_pipeline(
layout=pipeline_layout,
compute={"module": cshader, "entry_point": "main", "foo": 42},
)

# Create and run the pipeline
compute_pipeline = device.create_compute_pipeline(
layout=pipeline_layout,
Expand Down Expand Up @@ -259,7 +266,7 @@ def compute_shader(
compute_with_buffers({0: in1}, {0: c_int32 * 100}, compute_shader, n=-1)

with raises(TypeError): # invalid shader
compute_with_buffers({0: in1}, {0: c_int32 * 100}, "not a shader")
compute_with_buffers({0: in1}, {0: c_int32 * 100}, {"not", "a", "shader"})


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rs_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_shader_module_creation():
with raises(TypeError):
device.create_shader_module(code=code4)
with raises(TypeError):
device.create_shader_module(code="not a shader")
device.create_shader_module(code={"not", "a", "shader"})
with raises(ValueError):
device.create_shader_module(code=b"bytes but no SpirV magic number")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_rs_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,13 @@ def cb(renderpass):
format=wgpu.TextureFormat.depth24plus_stencil8,
depth_write_enabled=True,
depth_compare=wgpu.CompareFunction.less_equal,
front={
stencil_front={
"compare": wgpu.CompareFunction.equal,
"fail_op": wgpu.StencilOperation.keep,
"depth_fail_op": wgpu.StencilOperation.keep,
"pass_op": wgpu.StencilOperation.keep,
},
back={
stencil_back={
"compare": wgpu.CompareFunction.equal,
"fail_op": wgpu.StencilOperation.keep,
"depth_fail_op": wgpu.StencilOperation.keep,
Expand Down
Loading

0 comments on commit 2a00311

Please sign in to comment.