Skip to content

Commit

Permalink
C expression casting logic: refactor, add some types
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jun 27, 2024
1 parent 39e4955 commit 5be8abf
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""


from typing import Optional
import numpy as np

from pymbolic.mapper import RecursiveMapper, IdentityMapper
Expand Down Expand Up @@ -109,21 +110,23 @@ def find_array(self, expr):

return ary

def wrap_in_typecast(self, actual_type, needed_type, s):
def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):
if actual_type != needed_type:
registry = self.codegen_state.ast_builder.target.get_dtype_registry()
cast = var("(%s) " % registry.dtype_to_ctype(needed_type))
return cast(s)

return s

def rec(self, expr, type_context=None, needed_type=None):
if needed_type is None:
return RecursiveMapper.rec(self, expr, type_context)
def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None):
result = RecursiveMapper.rec(self, expr, type_context)

return self.wrap_in_typecast(
self.infer_type(expr), needed_type,
RecursiveMapper.rec(self, expr, type_context))
if needed_type is None:
return result
else:
return self.wrap_in_typecast(
self.infer_type(expr), needed_type,
result)

def __call__(self, expr, prec=None, type_context=None, needed_dtype=None):
if prec is None:
Expand Down

0 comments on commit 5be8abf

Please sign in to comment.