Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loopy.statistics for CInstructions #647

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def __init__(self, knl, callables_table, kernel_rec,
arithmetic_count_granularity = CountGranularity.SUBGROUP

def combine(self, values):
return sum(values)
return sum(values, self.new_zero_poly_map())

def map_constant(self, expr):
return self.new_zero_poly_map()
Expand Down Expand Up @@ -1654,45 +1654,59 @@ def _get_insn_count(knl, callables_table, insn_id, subgroup_size,
# {{{ get_op_map

def _get_op_map_for_single_kernel(knl, callables_table,
count_redundant_work,
count_within_subscripts, subgroup_size):
count_redundant_work,
count_within_subscripts,
ignore_c_instruction_ops,
subgroup_size):

subgroup_size = _process_subgroup_size(knl, subgroup_size)

kernel_rec = partial(_get_op_map_for_single_kernel,
callables_table=callables_table,
count_redundant_work=count_redundant_work,
count_within_subscripts=count_within_subscripts,
ignore_c_instruction_ops=ignore_c_instruction_ops,
subgroup_size=subgroup_size)

op_counter = ExpressionOpCounter(knl, callables_table, kernel_rec,
count_within_subscripts)
op_map = op_counter.new_zero_poly_map()

from loopy.kernel.instruction import (
CallInstruction, CInstruction, Assignment,
MultiAssignmentBase, CInstruction,
NoOpInstruction, BarrierInstruction)

for insn in knl.instructions:
if isinstance(insn, (CallInstruction, CInstruction, Assignment)):
ops = op_counter(insn.assignees) + op_counter(insn.expression)
for key, val in ops.count_map.items():
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
op_map = op_map + ToCountMap({key: val}) * count
if isinstance(insn, MultiAssignmentBase):
exprs_in_insn = (insn.assignees, insn.expression,
tuple(insn.predicates))
elif isinstance(insn, CInstruction):
if ignore_c_instruction_ops:
exprs_in_insn = tuple(insn.predicates)
else:
raise LoopyError("Cannot count number of operations in CInstruction."
" To ignore the operations in CInstructions pass"
" `ignore_c_instruction_ops=True`.")

elif isinstance(insn, (NoOpInstruction, BarrierInstruction)):
pass
exprs_in_insn = tuple(insn.predicates)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we don't know how much computation goes on in a CInstruction, I think we at least need a flag to turn on this behavior, since it's easy to have misconceptions on what it does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a warning in #649

else:
raise NotImplementedError("unexpected instruction item type: '%s'"
% type(insn).__name__)

ops = op_counter(exprs_in_insn)
for key, val in ops.count_map.items():
count = _get_insn_count(knl, callables_table, insn.id,
subgroup_size, count_redundant_work,
key.count_granularity)
op_map = op_map + ToCountMap({key: val}) * count

return op_map


def get_op_map(program, count_redundant_work=False,
count_within_subscripts=True, subgroup_size=None,
ignore_c_instruction_ops=True,
entrypoint=None):

"""Count the number of operations in a loopy kernel.
Expand All @@ -1708,6 +1722,12 @@ def get_op_map(program, count_redundant_work=False,
:arg count_within_subscripts: A :class:`bool` specifying whether to
count operations inside array indices.

:arg ignore_c_instruction_ops: A instance of :class:`bool`. If *True*
ignores the operations performed in :attr:`loopy.CInstruction`. If
*False*, raises an error on encountering a :attr:`loopy.CInstruction`,
since :mod:`loopy` cannot parse the number of operations in plain
C-code.

:arg subgroup_size: (currently unused) An :class:`int`, :class:`str`
``"guess"``, or *None* that specifies the sub-group size. An OpenCL
sub-group is an implementation-dependent grouping of work-items within
Expand Down Expand Up @@ -1768,6 +1788,7 @@ def get_op_map(program, count_redundant_work=False,
program[entrypoint], program.callables_table,
count_redundant_work=count_redundant_work,
count_within_subscripts=count_within_subscripts,
ignore_c_instruction_ops=ignore_c_instruction_ops,
subgroup_size=subgroup_size)

# }}}
Expand Down
29 changes: 29 additions & 0 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import sys
import pytest
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl
as pytest_generate_tests)
Expand All @@ -29,6 +30,7 @@
import numpy as np
from pytools import div_ceil
from loopy.statistics import CountGranularity as CG
from loopy.diagnostic import LoopyError

from pymbolic.primitives import Variable

Expand Down Expand Up @@ -1531,6 +1533,33 @@ def test_no_loop_ops():
assert f64_mul == 1


def test_c_instructions_stats():
# loopy.git <= 04fb703 would fail this regression as CInstructions weren't
# supported in loopy.statistics
knl = lp.make_kernel(
"{ : }",
["""
a = 2.0f
b = 2*a
""",
lp.CInstruction((),
code='printf("Hello World\n");'),
"""
c = a + b
"""
])

op_map = lp.get_op_map(knl, subgroup_size=1)
f32_add = op_map.filter_by(name="add").eval_and_sum({})
f32_mul = op_map.filter_by(name="mul").eval_and_sum({})
assert f32_add == 1
assert f32_mul == 1

with pytest.raises(LoopyError):
op_map = lp.get_op_map(knl, subgroup_size=1,
ignore_c_instruction_ops=False)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down