diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index a1444f932436..917aef2a20d7 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -32,7 +32,7 @@ def mangle_ty(ty): return f'{elt}S{shape}S' if ty.is_void(): return 'V' - assert False, "Unsupported type" + raise TypeError(f'Unsupported type {ty}') def mangle_fn(name, arg_tys, constants): @@ -121,10 +121,7 @@ def __init__(self, gscope): self.gscope = gscope def _visit_stmts(self, body) -> bool: - for s in body: - if self.visit(s): - return True - return False + return any(self.visit(s) for s in body) def _visit_function(self, fn) -> bool: # Currently we only support JITFunctions defined in the global scope @@ -160,7 +157,7 @@ def visit_Attribute(self, node: ast.Attribute) -> bool: return self.visit(node.value) def visit_Name(self, node: ast.Name) -> bool: - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return False if node.id in self.gscope: fn = self.gscope[node.id] @@ -226,7 +223,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype self.gscope = gscope - self.lscope = dict() + self.lscope = {} self.attributes = attributes self.constants = constants self.jit_fn = jit_fn @@ -281,19 +278,20 @@ def global_lookup(name: str, absent): # The high-level rule is that only constexpr globals are allowed. # But actually a bunch of other things, such as module imports, are # technically Python globals. We have to allow these too! - if (val is absent # - or name in self.builtin_namespace # - or type(val) == ModuleType # - or isinstance(val, JITFunction) # - or getattr(val, "__triton_builtin__", False) # - or getattr(val, "__module__", "").startswith("triton.language") # - or isinstance(val, language.dtype) # - or self._is_constexpr_global(name) # + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + self._is_constexpr_global(name), # # Allow accesses to globals while visiting an ast.arg # because you should be able to do # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... - or self.visiting_arg_default_value # - or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"): + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): return val raise NameError( textwrap.dedent(f"""\ @@ -418,7 +416,7 @@ def visit_FunctionDef(self, node): entry = self.fn.add_entry_block() arg_values = [] idx = 0 - for i, arg_name in enumerate(arg_names): + for i in range(len(arg_names)): if i in self.constants: cst = self.constants[i] if not _is_constexpr(cst): @@ -514,7 +512,7 @@ def visit_AugAssign(self, node): return self.dereference_name(name) def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id return self.dereference_name(node.id) @@ -770,9 +768,9 @@ def visit_Compare(self, node): rhs = self.visit(node.comparators[0]) lhs_value = _unwrap_if_constexpr(lhs) rhs_value = _unwrap_if_constexpr(rhs) - if type(node.ops[0]) == ast.Is: + if type(node.ops[0]) is ast.Is: return constexpr(lhs_value is rhs_value) - if type(node.ops[0]) == ast.IsNot: + if type(node.ops[0]) is ast.IsNot: return constexpr(lhs_value is not rhs_value) method_name = self._method_name_for_comp_op.get(type(node.ops[0])) if method_name is None: @@ -1048,7 +1046,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): args = [args[name] for name in fn.arg_names] args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] # generate function def - attributes = dict() + attributes = {} constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] constants = {i: args[i] for i in constexprs} # generate call @@ -1098,14 +1096,14 @@ def visit_Call(self, node): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - if fn is language.core.device_assert: # TODO: this should not be so hardcoded - if not self.debug: - return + # TODO: this should not be so hardcoded + if fn is language.core.device_assert and not self.debug: + return if isinstance(fn, JITFunction): _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): - extra_kwargs = dict(_builder=self.builder) + extra_kwargs = {"_builder": self.builder} sig = inspect.signature(fn) if '_generator' in sig.parameters: extra_kwargs['_generator'] = self @@ -1154,9 +1152,8 @@ def visit_Str(self, node): def visit_Attribute(self, node): lhs = self.visit(node.value) - if _is_triton_tensor(lhs): - if node.attr == "T": - return language.semantic.permute(lhs, (1, 0), builder=self.builder) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) return getattr(lhs, node.attr) def visit_Expr(self, node): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index eca402e5bb43..9e7a90cf5090 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -98,7 +98,7 @@ def _update_hash(self, func): self.hasher.update(func_key.encode("utf-8")) def visit_Name(self, node): - if type(node.ctx) == ast.Store: + if type(node.ctx) is ast.Store: return node.id if node.id in self.local_names: @@ -117,12 +117,11 @@ def visit_Name(self, node): and not self.visiting_arg_default_value # It would be pretty evil if someone did `import x` and then # `x = blah`. - and type(val) != ModuleType + and type(val) is not ModuleType # It would be pretty evil if we used function `foo` inside of # `bar` and then someone did `foo = baz`. and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # - and node.id not in self.supported_python_builtins # - ): + and node.id not in self.supported_python_builtins): self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) self._update_hash(val) @@ -650,7 +649,7 @@ def run(self, *args, grid, warmup, **kwargs): # Check that used global values have not changed. not_present = object() - for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items(): + for (name, _), (val, globals_dict) in self.used_global_vals.items(): if (newVal := globals_dict.get(name, not_present)) != val: raise RuntimeError( f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")