diff --git a/spec/compiler/macro/macro_methods_spec.cr b/spec/compiler/macro/macro_methods_spec.cr index 7e884e887f22..2bce2f8b5b16 100644 --- a/spec/compiler/macro/macro_methods_spec.cr +++ b/spec/compiler/macro/macro_methods_spec.cr @@ -2746,4 +2746,54 @@ module Crystal end end end + + describe "error reporting" do + it "reports wrong number of arguments" do + expect_raises(Crystal::TypeException, "wrong number of arguments for macro 'ArrayLiteral#push' (given 0, expected 1)") do + assert_macro "", %({{[1, 2, 3].push}}), [] of ASTNode, "" + end + end + + it "reports wrong number of arguments, with optional parameters" do + expect_raises(Crystal::TypeException, "wrong number of arguments for macro 'NumberLiteral#+' (given 2, expected 0..1)") do + assert_macro "", %({{1.+(2, 3)}}), [] of ASTNode, "" + end + + expect_raises(Crystal::TypeException, "wrong number of arguments for macro 'ArrayLiteral#[]' (given 0, expected 1..2)") do + assert_macro "", %({{[1][]}}), [] of ASTNode, "" + end + end + + it "reports unexpected block" do + expect_raises(Crystal::TypeException, "macro 'ArrayLiteral#shuffle' is not expected to be invoked with a block, but a block was given") do + assert_macro "", %({{[1, 2, 3].shuffle { |x| }}}), [] of ASTNode, "" + end + end + + it "reports missing block" do + expect_raises(Crystal::TypeException, "macro 'ArrayLiteral#reduce' is expected to be invoked with a block, but no block was given") do + assert_macro "", %({{[1, 2, 3].reduce}}), [] of ASTNode, "" + end + end + + it "reports unexpected named argument" do + expect_raises(Crystal::TypeException, "named arguments are not allowed here") do + assert_macro "", %({{"".starts_with?(other: "")}}), [] of ASTNode, "" + end + end + + it "reports unexpected named argument (2)" do + expect_raises(Crystal::TypeException, "no named parameter 'foo'") do + assert_macro "", %({{"".camelcase(foo: "")}}), [] of ASTNode, "" + end + end + + # there are no macro methods with required named parameters + + it "uses correct name for top-level macro methods" do + expect_raises(Crystal::TypeException, "wrong number of arguments for top-level macro 'flag?' (given 0, expected 1)") do + assert_macro "", %({{flag?}}), [] of ASTNode, "" + end + end + end end diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index aaa66b0f184c..6bcf0fb0902a 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -79,31 +79,27 @@ module Crystal end def interpret_compare_versions(node) - unless node.args.size == 2 - node.wrong_number_of_arguments "macro call 'compare_versions'", node.args.size, 2 - end + interpret_check_args_toplevel do |first_arg, second_arg| + first = accept first_arg + first_string = first.to_string("first argument to 'compare_versions'") - first_arg = node.args[0] - first = accept first_arg - first_string = first.to_string("first argument to 'compare_versions'") + second = accept second_arg + second_string = second.to_string("second argument to 'compare_versions'") - second_arg = node.args[1] - second = accept second_arg - second_string = second.to_string("second argument to 'compare_versions'") + first_version = begin + SemanticVersion.parse(first_string) + rescue ex + first_arg.raise ex.message + end - first_version = begin - SemanticVersion.parse(first_string) - rescue ex - first_arg.raise ex.message - end + second_version = begin + SemanticVersion.parse(second_string) + rescue ex + second_arg.raise ex.message + end - second_version = begin - SemanticVersion.parse(second_string) - rescue ex - second_arg.raise ex.message + @last = NumberLiteral.new(first_version <=> second_version) end - - @last = NumberLiteral.new(first_version <=> second_version) end def interpret_debug(node) @@ -134,19 +130,17 @@ module Crystal end def interpret_env(node) - if node.args.size == 1 - node.args[0].accept self + interpret_check_args_toplevel do |arg| + arg.accept self cmd = @last.to_macro_id env_value = ENV[cmd]? @last = env_value ? StringLiteral.new(env_value) : NilLiteral.new - else - node.wrong_number_of_arguments "macro call 'env'", node.args.size, 1 end end def interpret_flag?(node) - if node.args.size == 1 - node.args[0].accept self + interpret_check_args_toplevel do |arg| + arg.accept self flag_name = @last.to_macro_id flags = case node.name when "flag?" @@ -157,8 +151,6 @@ module Crystal raise "Bug: unexpected macro method #{node.name}" end @last = BoolLiteral.new(flags.includes?(flag_name)) - else - node.wrong_number_of_arguments "macro call '#{node.name}'", node.args.size, 1 end end @@ -228,24 +220,22 @@ module Crystal end def interpret_read_file(node, nilable = false) - unless node.args.size == 1 - node.wrong_number_of_arguments "macro call '#{node.name}'", node.args.size, 1 - end - - node.args[0].accept self - filename = @last.to_macro_id + interpret_check_args_toplevel do |arg| + arg.accept self + filename = @last.to_macro_id - begin - @last = StringLiteral.new(File.read(filename)) - rescue ex - node.raise ex.to_s unless nilable - @last = NilLiteral.new + begin + @last = StringLiteral.new(File.read(filename)) + rescue ex + node.raise ex.to_s unless nilable + @last = NilLiteral.new + end end end def interpret_run(node) if node.args.size == 0 - node.wrong_number_of_arguments "macro call 'run'", 0, "1+" + node.wrong_number_of_arguments "top-level macro 'run'", 0, "1+" end node.args.first.accept self @@ -338,80 +328,57 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "id" - interpret_argless_method("id", args) { MacroId.new(to_macro_id) } + interpret_check_args { MacroId.new(to_macro_id) } when "stringify" - interpret_argless_method("stringify", args) { stringify } + interpret_check_args { stringify } when "symbolize" - interpret_argless_method("symbolize", args) { symbolize } + interpret_check_args { symbolize } when "class_name" - interpret_argless_method("class_name", args) { class_name } + interpret_check_args { class_name } when "raise" macro_raise self, args, interpreter when "filename" - interpret_argless_method("filename", args) do + interpret_check_args do filename = location.try &.original_filename filename ? StringLiteral.new(filename) : NilLiteral.new end when "line_number" - interpret_argless_method("line_number", args) do + interpret_check_args do line_number = location.try &.expanded_location.try &.line_number line_number ? NumberLiteral.new(line_number) : NilLiteral.new end when "column_number" - interpret_argless_method("column_number", args) do + interpret_check_args do column_number = location.try &.expanded_location.try &.column_number column_number ? NumberLiteral.new(column_number) : NilLiteral.new end when "end_line_number" - interpret_argless_method("end_line_number", args) do + interpret_check_args do line_number = end_location.try &.expanded_location.try &.line_number line_number ? NumberLiteral.new(line_number) : NilLiteral.new end when "end_column_number" - interpret_argless_method("end_column_number", args) do + interpret_check_args do column_number = end_location.try &.expanded_location.try &.column_number column_number ? NumberLiteral.new(column_number) : NilLiteral.new end when "==" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| BoolLiteral.new(self == arg) end when "!=" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| BoolLiteral.new(self != arg) end when "!" - BoolLiteral.new(!truthy?) + interpret_check_args { BoolLiteral.new(!truthy?) } when "nil?" - interpret_argless_method("nil?", args) do - BoolLiteral.new(is_a?(NilLiteral) || is_a?(Nop)) - end + interpret_check_args { BoolLiteral.new(is_a?(NilLiteral) || is_a?(Nop)) } else raise "undefined macro method '#{class_desc}##{method}'", exception_type: Crystal::UndefinedMacroMethodError end end - def interpret_argless_method(method, args) - interpret_check_args_size method, args, 0 - yield - end - - def interpret_one_arg_method(method, args) - interpret_check_args_size method, args, 1 - yield args.first - end - - def interpret_two_args_method(method, args) - interpret_check_args_size method, args, 2 - yield args[0], args[1] - end - - def interpret_check_args_size(method, args, size) - unless args.size == size - wrong_number_of_arguments method, args.size, size - end - end - def interpret_compare(other) raise "can't compare #{self} to #{other}" end @@ -445,69 +412,67 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when ">" - bool_bin_op(method, args) { |me, other| me > other } + bool_bin_op(method, args, named_args, block) { |me, other| me > other } when ">=" - bool_bin_op(method, args) { |me, other| me >= other } + bool_bin_op(method, args, named_args, block) { |me, other| me >= other } when "<" - bool_bin_op(method, args) { |me, other| me < other } + bool_bin_op(method, args, named_args, block) { |me, other| me < other } when "<=" - bool_bin_op(method, args) { |me, other| me <= other } + bool_bin_op(method, args, named_args, block) { |me, other| me <= other } when "<=>" - num_bin_op(method, args) do |me, other| + num_bin_op(method, args, named_args, block) do |me, other| (me <=> other) || (return NilLiteral.new) end when "+" - if args.empty? - self - else - num_bin_op(method, args) { |me, other| me + other } + interpret_check_args(min_count: 0) do |other| + if other + raise "can't #{method} with #{other}" unless other.is_a?(NumberLiteral) + NumberLiteral.new(to_number + other.to_number) + else + self + end end when "-" - if args.empty? - num = to_number - if num.is_a?(Int::Unsigned) - raise "undefined method '-' for unsigned integer literal: #{self}" + interpret_check_args(min_count: 0) do |other| + if other + raise "can't #{method} with #{other}" unless other.is_a?(NumberLiteral) + NumberLiteral.new(to_number - other.to_number) else + num = to_number + raise "undefined method '-' for unsigned integer literal: #{self}" if num.is_a?(Int::Unsigned) NumberLiteral.new(-num) end - else - num_bin_op(method, args) { |me, other| me - other } end when "*" - num_bin_op(method, args) { |me, other| me * other } + num_bin_op(method, args, named_args, block) { |me, other| me * other } when "/" - num_bin_op(method, args) { |me, other| me / other } + num_bin_op(method, args, named_args, block) { |me, other| me / other } when "//" - num_bin_op(method, args) { |me, other| me // other } + num_bin_op(method, args, named_args, block) { |me, other| me // other } when "**" - num_bin_op(method, args) { |me, other| me ** other } + num_bin_op(method, args, named_args, block) { |me, other| me ** other } when "%" - int_bin_op(method, args) { |me, other| me % other } + int_bin_op(method, args, named_args, block) { |me, other| me % other } when "&" - int_bin_op(method, args) { |me, other| me & other } + int_bin_op(method, args, named_args, block) { |me, other| me & other } when "|" - int_bin_op(method, args) { |me, other| me | other } + int_bin_op(method, args, named_args, block) { |me, other| me | other } when "^" - int_bin_op(method, args) { |me, other| me ^ other } + int_bin_op(method, args, named_args, block) { |me, other| me ^ other } when "<<" - int_bin_op(method, args) { |me, other| me << other } + int_bin_op(method, args, named_args, block) { |me, other| me << other } when ">>" - int_bin_op(method, args) { |me, other| me >> other } + int_bin_op(method, args, named_args, block) { |me, other| me >> other } when "~" - if args.empty? + interpret_check_args do num = to_number - if num.is_a?(Int) - NumberLiteral.new(~num) - else - raise "undefined method '~' for float literal: #{self}" - end - else - wrong_number_of_arguments "NumberLiteral#~", args.size, 0 + raise "undefined method '~' for float literal: #{self}" unless num.is_a?(Int) + NumberLiteral.new(~num) end when "kind" - SymbolLiteral.new(kind.to_s) + interpret_check_args { SymbolLiteral.new(kind.to_s) } when "to_number" - MacroId.new(to_number.to_s) + interpret_check_args { MacroId.new(to_number.to_s) } else super end @@ -517,39 +482,35 @@ module Crystal to_number <=> other.to_number end - def bool_bin_op(op, args) - BoolLiteral.new(bin_op(op, args) { |me, other| yield me, other }) - end - - def num_bin_op(op, args) - NumberLiteral.new(bin_op(op, args) { |me, other| yield me, other }) + def bool_bin_op(method, args, named_args, block) + interpret_check_args do |other| + raise "can't #{method} with #{other}" unless other.is_a?(NumberLiteral) + BoolLiteral.new(yield to_number, other.to_number) + end end - def int_bin_op(op, args) - result = bin_op(op, args) do |me, other| - if me.is_a?(Int) && other.is_a?(Int) - yield me, other - elsif me.is_a?(Float) - raise "undefined method '#{op}' for float literal: #{self}" - else - raise "argument to NumberLiteral##{op} can't be float literal: #{self}" - end + def num_bin_op(method, args, named_args, block) + interpret_check_args do |other| + raise "can't #{method} with #{other}" unless other.is_a?(NumberLiteral) + NumberLiteral.new(yield to_number, other.to_number) end - - NumberLiteral.new result end - def bin_op(op, args) - if args.size != 1 - wrong_number_of_arguments "NumberLiteral##{op}", args.size, 1 - end + def int_bin_op(method, args, named_args, block) + interpret_check_args do |other| + raise "can't #{method} with #{other}" unless other.is_a?(NumberLiteral) + me = to_number + other = other.to_number - other = args.first - unless other.is_a?(NumberLiteral) - raise "can't #{op} with #{other}" + case {me, other} + when {Int, Int} + NumberLiteral.new(yield me, other) + when {Float, _} + raise "undefined method '#{method}' for float literal: #{self}" + else + raise "argument to NumberLiteral##{method} can't be float literal: #{self}" + end end - - yield(to_number, other.to_number) end def to_number @@ -580,18 +541,20 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "==", "!=" - case arg = args.first? - when MacroId - if method == "==" - return BoolLiteral.new(@value == arg.value) + interpret_check_args do |arg| + case arg + when MacroId + if method == "==" + BoolLiteral.new(@value == arg.value) + else + BoolLiteral.new(@value != arg.value) + end else - return BoolLiteral.new(@value != arg.value) + super end - else - return super end when "[]" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when RangeLiteral from, to = arg.from, arg.to @@ -614,7 +577,7 @@ module Crystal end end when "=~" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when RegexLiteral arg_value = arg.value @@ -629,7 +592,7 @@ module Crystal end end when ">" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when StringLiteral, MacroId return BoolLiteral.new(interpret_compare(arg) > 0) @@ -638,7 +601,7 @@ module Crystal end end when "<" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when StringLiteral, MacroId return BoolLiteral.new(interpret_compare(arg) < 0) @@ -647,7 +610,7 @@ module Crystal end end when "+" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when CharLiteral piece = arg.value @@ -659,7 +622,7 @@ module Crystal StringLiteral.new(@value + piece) end when "camelcase" - interpret_argless_method(method, args) do + interpret_check_args(named_params: ["lower"]) do lower = if named_args && (lower_arg = named_args["lower"]?) lower_arg else @@ -671,17 +634,17 @@ module Crystal StringLiteral.new(@value.camelcase(lower: lower.value)) end when "capitalize" - interpret_argless_method(method, args) { StringLiteral.new(@value.capitalize) } + interpret_check_args { StringLiteral.new(@value.capitalize) } when "chars" - interpret_argless_method(method, args) { ArrayLiteral.map(@value.chars, Path.global("Char")) { |value| CharLiteral.new(value) } } + interpret_check_args { ArrayLiteral.map(@value.chars, Path.global("Char")) { |value| CharLiteral.new(value) } } when "chomp" - interpret_argless_method(method, args) { StringLiteral.new(@value.chomp) } + interpret_check_args { StringLiteral.new(@value.chomp) } when "downcase" - interpret_argless_method(method, args) { StringLiteral.new(@value.downcase) } + interpret_check_args { StringLiteral.new(@value.downcase) } when "empty?" - interpret_argless_method(method, args) { BoolLiteral.new(@value.empty?) } + interpret_check_args { BoolLiteral.new(@value.empty?) } when "ends_with?" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when CharLiteral piece = arg.value @@ -693,7 +656,7 @@ module Crystal BoolLiteral.new(@value.ends_with?(piece)) end when "gsub" - interpret_two_args_method(method, args) do |first, second| + interpret_check_args do |first, second| raise "first argument to StringLiteral#gsub must be a regex, not #{first.class_desc}" unless first.is_a?(RegexLiteral) raise "second argument to StringLiteral#gsub must be a string, not #{second.class_desc}" unless second.is_a?(StringLiteral) @@ -707,9 +670,9 @@ module Crystal StringLiteral.new(value.gsub(regex, second.value)) end when "identify" - interpret_argless_method(method, args) { StringLiteral.new(@value.tr(":", "_")) } + interpret_check_args { StringLiteral.new(@value.tr(":", "_")) } when "includes?" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when CharLiteral piece = arg.value @@ -721,30 +684,28 @@ module Crystal BoolLiteral.new(@value.includes?(piece)) end when "size" - interpret_argless_method(method, args) { NumberLiteral.new(@value.size) } + interpret_check_args { NumberLiteral.new(@value.size) } when "lines" - interpret_argless_method(method, args) { ArrayLiteral.map(@value.lines, Path.global("String")) { |value| StringLiteral.new(value) } } + interpret_check_args { ArrayLiteral.map(@value.lines, Path.global("String")) { |value| StringLiteral.new(value) } } when "split" - case args.size - when 0 - ArrayLiteral.map(@value.split, Path.global("String")) { |value| StringLiteral.new(value) } - when 1 - first_arg = args.first - case first_arg - when CharLiteral - splitter = first_arg.value - when StringLiteral - splitter = first_arg.value + interpret_check_args(min_count: 0) do |arg| + if arg + case arg + when CharLiteral + splitter = arg.value + when StringLiteral + splitter = arg.value + else + splitter = arg.to_s + end + + ArrayLiteral.map(@value.split(splitter), Path.global("String")) { |value| StringLiteral.new(value) } else - splitter = first_arg.to_s + ArrayLiteral.map(@value.split, Path.global("String")) { |value| StringLiteral.new(value) } end - - ArrayLiteral.map(@value.split(splitter), Path.global("String")) { |value| StringLiteral.new(value) } - else - wrong_number_of_arguments "StringLiteral#split", args.size, "0..1" end when "count" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when CharLiteral chr = arg.value @@ -754,7 +715,7 @@ module Crystal NumberLiteral.new(@value.count(chr)) end when "starts_with?" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when CharLiteral piece = arg.value @@ -766,20 +727,17 @@ module Crystal BoolLiteral.new(@value.starts_with?(piece)) end when "strip" - interpret_argless_method(method, args) { StringLiteral.new(@value.strip) } + interpret_check_args { StringLiteral.new(@value.strip) } when "titleize" - interpret_argless_method(method, args) { StringLiteral.new(@value.titleize) } + interpret_check_args { StringLiteral.new(@value.titleize) } when "to_i" - case args.size - when 0 - value = @value.to_i64? - when 1 - arg = args.first - raise "argument to StringLiteral#to_i must be a number, not #{arg.class_desc}" unless arg.is_a?(NumberLiteral) - - value = @value.to_i64?(arg.to_number.to_i) - else - wrong_number_of_arguments "StringLiteral#to_i", args.size, "0..1" + value = interpret_check_args(min_count: 0) do |base| + if base + raise "argument to StringLiteral#to_i must be a number, not #{base.class_desc}" unless base.is_a?(NumberLiteral) + @value.to_i64?(base.to_number.to_i) + else + @value.to_i64? + end end if value @@ -788,16 +746,16 @@ module Crystal raise "StringLiteral#to_i: #{@value} is not an integer" end when "tr" - interpret_two_args_method(method, args) do |first, second| + interpret_check_args do |first, second| raise "first argument to StringLiteral#tr must be a string, not #{first.class_desc}" unless first.is_a?(StringLiteral) raise "second argument to StringLiteral#tr must be a string, not #{second.class_desc}" unless second.is_a?(StringLiteral) StringLiteral.new(@value.tr(first.value, second.value)) end when "underscore" - interpret_argless_method(method, args) { StringLiteral.new(@value.underscore) } + interpret_check_args { StringLiteral.new(@value.underscore) } when "upcase" - interpret_argless_method(method, args) { StringLiteral.new(@value.upcase) } + interpret_check_args { StringLiteral.new(@value.upcase) } else super end @@ -816,7 +774,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "expressions" - interpret_argless_method(method, args) { ArrayLiteral.new(expressions) } + interpret_check_args { ArrayLiteral.new(expressions) } else super end @@ -827,16 +785,16 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "of" - interpret_argless_method(method, args) { @of || Nop.new } + interpret_check_args { @of || Nop.new } when "type" - interpret_argless_method(method, args) { @name || Nop.new } + interpret_check_args { @name || Nop.new } when "clear" - interpret_argless_method(method, args) do + interpret_check_args do elements.clear self end else - value = interpret_array_or_tuple_method(self, ArrayLiteral, method, args, block, interpreter) + value = interpret_array_or_tuple_method(self, ArrayLiteral, method, args, named_args, block, interpreter) value || super end end @@ -846,21 +804,19 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "empty?" - interpret_argless_method(method, args) { BoolLiteral.new(entries.empty?) } + interpret_check_args { BoolLiteral.new(entries.empty?) } when "keys" - interpret_argless_method(method, args) { ArrayLiteral.map entries, &.key } + interpret_check_args { ArrayLiteral.map entries, &.key } when "size" - interpret_argless_method(method, args) { NumberLiteral.new(entries.size) } + interpret_check_args { NumberLiteral.new(entries.size) } when "to_a" - interpret_argless_method(method, args) do + interpret_check_args do ArrayLiteral.map(entries) { |entry| TupleLiteral.new([entry.key, entry.value] of ASTNode) } end when "values" - interpret_argless_method(method, args) { ArrayLiteral.map entries, &.value } + interpret_check_args { ArrayLiteral.map entries, &.value } when "each" - interpret_argless_method(method, args) do - raise "each expects a block" unless block - + interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? @@ -873,9 +829,7 @@ module Crystal NilLiteral.new end when "map" - interpret_argless_method(method, args) do - raise "map expects a block" unless block - + interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? @@ -886,37 +840,28 @@ module Crystal end end when "double_splat" - case args.size - when 0 - to_double_splat - when 1 - interpret_one_arg_method(method, args) do |arg| + interpret_check_args(min_count: 0) do |arg| + if arg + unless arg.is_a?(Crystal::StringLiteral) + arg.raise "argument to double_splat must be a StringLiteral, not #{arg.class_desc}" + end + if entries.empty? to_double_splat else - unless arg.is_a?(Crystal::StringLiteral) - arg.raise "argument to double_splat must be a StringLiteral, not #{arg.class_desc}" - end to_double_splat(arg.value) end + else + to_double_splat end - else - wrong_number_of_arguments "double_splat", args.size, 0..1 end when "[]" - case args.size - when 1 - key = args.first + interpret_check_args do |key| entry = entries.find &.key.==(key) entry.try(&.value) || NilLiteral.new - else - wrong_number_of_arguments "HashLiteral#[]", args.size, 1 end when "[]=" - case args.size - when 2 - key, value = args - + interpret_check_args do |key, value| index = entries.index &.key.==(key) if index entries[index] = HashLiteral::Entry.new(key, value) @@ -925,17 +870,15 @@ module Crystal end value - else - wrong_number_of_arguments "HashLiteral#[]=", args.size, 2 end when "of_key" - interpret_argless_method(method, args) { @of.try(&.key) || Nop.new } + interpret_check_args { @of.try(&.key) || Nop.new } when "of_value" - interpret_argless_method(method, args) { @of.try(&.value) || Nop.new } + interpret_check_args { @of.try(&.value) || Nop.new } when "type" - interpret_argless_method(method, args) { @name || Nop.new } + interpret_check_args { @name || Nop.new } when "clear" - interpret_argless_method(method, args) do + interpret_check_args do entries.clear self end @@ -955,21 +898,19 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "empty?" - interpret_argless_method(method, args) { BoolLiteral.new(entries.empty?) } + interpret_check_args { BoolLiteral.new(entries.empty?) } when "keys" - interpret_argless_method(method, args) { ArrayLiteral.map(entries) { |entry| MacroId.new(entry.key) } } + interpret_check_args { ArrayLiteral.map(entries) { |entry| MacroId.new(entry.key) } } when "size" - interpret_argless_method(method, args) { NumberLiteral.new(entries.size) } + interpret_check_args { NumberLiteral.new(entries.size) } when "to_a" - interpret_argless_method(method, args) do + interpret_check_args do ArrayLiteral.map(entries) { |entry| TupleLiteral.new([MacroId.new(entry.key), entry.value] of ASTNode) } end when "values" - interpret_argless_method(method, args) { ArrayLiteral.map entries, &.value } + interpret_check_args { ArrayLiteral.map entries, &.value } when "each" - interpret_argless_method(method, args) do - raise "each expects a block" unless block - + interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? @@ -982,9 +923,7 @@ module Crystal NilLiteral.new end when "map" - interpret_argless_method(method, args) do - raise "map expects a block" unless block - + interpret_check_args(uses_block: true) do block_arg_key = block.args[0]? block_arg_value = block.args[1]? @@ -995,28 +934,23 @@ module Crystal end end when "double_splat" - case args.size - when 0 - to_double_splat - when 1 - interpret_one_arg_method(method, args) do |arg| + interpret_check_args(min_count: 0) do |arg| + if arg + unless arg.is_a?(Crystal::StringLiteral) + arg.raise "argument to double_splat must be a StringLiteral, not #{arg.class_desc}" + end + if entries.empty? to_double_splat else - unless arg.is_a?(Crystal::StringLiteral) - arg.raise "argument to double_splat must be a StringLiteral, not #{arg.class_desc}" - end to_double_splat(arg.value) end + else + to_double_splat end - else - wrong_number_of_arguments "double_splat", args.size, 0..1 end when "[]" - case args.size - when 1 - key = args.first - + interpret_check_args do |key| case key when SymbolLiteral key = key.value @@ -1030,14 +964,9 @@ module Crystal entry = entries.find &.key.==(key) entry.try(&.value) || NilLiteral.new - else - wrong_number_of_arguments "NamedTupleLiteral#[]", args.size, 1 end when "[]=" - case args.size - when 2 - key, value = args - + interpret_check_args do |key, value| case key when SymbolLiteral key = key.value @@ -1057,8 +986,6 @@ module Crystal end value - else - wrong_number_of_arguments "NamedTupleLiteral#[]=", args.size, 2 end else super @@ -1078,7 +1005,7 @@ module Crystal class TupleLiteral def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) - value = interpret_array_or_tuple_method(self, TupleLiteral, method, args, block, interpreter) + value = interpret_array_or_tuple_method(self, TupleLiteral, method, args, named_args, block, interpreter) value || super end end @@ -1087,45 +1014,45 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "begin" - interpret_argless_method(method, args) { self.from } + interpret_check_args { self.from } when "end" - interpret_argless_method(method, args) { self.to } + interpret_check_args { self.to } when "excludes_end?" - interpret_argless_method(method, args) { BoolLiteral.new(self.exclusive?) } + interpret_check_args { BoolLiteral.new(self.exclusive?) } when "each" - raise "each expects a block" unless block + interpret_check_args(uses_block: true) do + block_arg = block.args.first? - block_arg = block.args.first? + interpret_to_range(interpreter).each do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpreter.accept block.body + end - interpret_to_range(interpreter).each do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + NilLiteral.new end - - NilLiteral.new when "map" - raise "map expects a block" unless block - - block_arg = block.args.first? + interpret_check_args(uses_block: true) do + block_arg = block.args.first? - interpret_map(method, args, interpreter) do |num| - interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg - interpreter.accept block.body + interpret_map(interpreter) do |num| + interpreter.define_var(block_arg.name, NumberLiteral.new(num)) if block_arg + interpreter.accept block.body + end end when "to_a" - interpret_map(method, args, interpreter) do |num| - NumberLiteral.new(num) + interpret_check_args do + interpret_map(interpreter) do |num| + NumberLiteral.new(num) + end end else super end end - def interpret_map(method, args, interpreter) - interpret_argless_method(method, args) do - ArrayLiteral.map(interpret_to_range(interpreter)) do |num| - yield num - end + def interpret_map(interpreter) + ArrayLiteral.map(interpret_to_range(interpreter)) do |num| + yield num end end @@ -1155,9 +1082,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "source" - interpret_argless_method(method, args) { @value } + interpret_check_args { @value } when "options" - interpret_argless_method(method, args) do + interpret_check_args do options = [] of Symbol options << :i if @options.ignore_case? options << :m if @options.multiline? @@ -1178,9 +1105,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } when "type" - interpret_argless_method(method, args) do + interpret_check_args do if type = @type TypeNode.new(type) else @@ -1188,19 +1115,19 @@ module Crystal end end when "default_value" - interpret_argless_method(method, args) do + interpret_check_args do default_value || NilLiteral.new end when "has_default_value?" - interpret_argless_method(method, args) do + interpret_check_args do BoolLiteral.new(!!default_value) end when "annotation" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| self.var.annotation(type) end when "annotations" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| annotations = self.var.annotations(type) return ArrayLiteral.new if annotations.nil? ArrayLiteral.map(annotations, &.itself) @@ -1215,13 +1142,13 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "body" - interpret_argless_method(method, args) { @body } + interpret_check_args { @body } when "args" - interpret_argless_method(method, args) do + interpret_check_args do ArrayLiteral.map(@args) { |arg| MacroId.new(arg.name) } end when "splat_index" - interpret_argless_method(method, args) do + interpret_check_args do @splat_index ? NumberLiteral.new(@splat_index.not_nil!) : NilLiteral.new end else @@ -1234,9 +1161,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "inputs" - interpret_argless_method(method, args) { ArrayLiteral.new(@inputs || [] of ASTNode) } + interpret_check_args { ArrayLiteral.new(@inputs || [] of ASTNode) } when "output" - interpret_argless_method(method, args) { @output || NilLiteral.new } + interpret_check_args { @output || NilLiteral.new } else super end @@ -1258,11 +1185,11 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "obj" - interpret_argless_method(method, args) { @obj || NilLiteral.new } + interpret_check_args { @obj || NilLiteral.new } when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } when "args" - interpret_argless_method(method, args) { ArrayLiteral.new(@args) } + interpret_check_args { ArrayLiteral.new(@args) } else super end @@ -1273,7 +1200,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "expressions" - interpret_argless_method(method, args) do + interpret_check_args do ArrayLiteral.map(@expressions) { |expression| expression } end else @@ -1286,9 +1213,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "left" - interpret_argless_method(method, args) { @left } + interpret_check_args { @left } when "right" - interpret_argless_method(method, args) { @right } + interpret_check_args { @right } else super end @@ -1299,15 +1226,15 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "var" - interpret_argless_method(method, args) do + interpret_check_args do var = @var var = MacroId.new(var.name) if var.is_a?(Var) var end when "type" - interpret_argless_method(method, args) { @declared_type } + interpret_check_args { @declared_type } when "value" - interpret_argless_method(method, args) { @value || Nop.new } + interpret_check_args { @value || Nop.new } else super end @@ -1318,13 +1245,13 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "var" - interpret_argless_method(method, args) do + interpret_check_args do var = @var var = MacroId.new(var.name) if var.is_a?(Var) var end when "type" - interpret_argless_method(method, args) { @declared_type } + interpret_check_args { @declared_type } else super end @@ -1335,11 +1262,11 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "resolve" - interpret_argless_method(method, args) { interpreter.resolve(self) } + interpret_check_args { interpreter.resolve(self) } when "resolve?" - interpret_argless_method(method, args) { interpreter.resolve?(self) || NilLiteral.new } + interpret_check_args { interpreter.resolve?(self) || NilLiteral.new } when "types" - interpret_argless_method(method, args) { ArrayLiteral.new(@types) } + interpret_check_args { ArrayLiteral.new(@types) } else super end @@ -1350,13 +1277,13 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(external_name) } + interpret_check_args { MacroId.new(external_name) } when "internal_name" - interpret_argless_method(method, args) { MacroId.new(name) } + interpret_check_args { MacroId.new(name) } when "default_value" - interpret_argless_method(method, args) { default_value || Nop.new } + interpret_check_args { default_value || Nop.new } when "restriction" - interpret_argless_method(method, args) { restriction || Nop.new } + interpret_check_args { restriction || Nop.new } else super end @@ -1367,35 +1294,35 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } when "args" - interpret_argless_method(method, args) { ArrayLiteral.map @args, &.itself } + interpret_check_args { ArrayLiteral.map @args, &.itself } when "splat_index" - interpret_argless_method(method, args) do + interpret_check_args do @splat_index ? NumberLiteral.new(@splat_index.not_nil!) : NilLiteral.new end when "double_splat" - interpret_argless_method(method, args) { @double_splat || Nop.new } + interpret_check_args { @double_splat || Nop.new } when "block_arg" - interpret_argless_method(method, args) { @block_arg || Nop.new } + interpret_check_args { @block_arg || Nop.new } when "accepts_block?" - interpret_argless_method(method, args) { BoolLiteral.new(@yields != nil) } + interpret_check_args { BoolLiteral.new(@yields != nil) } when "return_type" - interpret_argless_method(method, args) { @return_type || Nop.new } + interpret_check_args { @return_type || Nop.new } when "body" - interpret_argless_method(method, args) { @body } + interpret_check_args { @body } when "receiver" - interpret_argless_method(method, args) { @receiver || Nop.new } + interpret_check_args { @receiver || Nop.new } when "visibility" - interpret_argless_method(method, args) do + interpret_check_args do visibility_to_symbol(@visibility) end when "annotation" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| self.annotation(type) end when "annotations" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| annotations = self.annotations(type) return ArrayLiteral.new if annotations.nil? ArrayLiteral.map(annotations, &.itself) @@ -1410,21 +1337,21 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } when "args" - interpret_argless_method(method, args) { ArrayLiteral.map @args, &.itself } + interpret_check_args { ArrayLiteral.map @args, &.itself } when "splat_index" - interpret_argless_method(method, args) do + interpret_check_args do @splat_index ? NumberLiteral.new(@splat_index.not_nil!) : NilLiteral.new end when "double_splat" - interpret_argless_method(method, args) { @double_splat || Nop.new } + interpret_check_args { @double_splat || Nop.new } when "block_arg" - interpret_argless_method(method, args) { @block_arg || Nop.new } + interpret_check_args { @block_arg || Nop.new } when "body" - interpret_argless_method(method, args) { @body } + interpret_check_args { @body } when "visibility" - interpret_argless_method(method, args) do + interpret_check_args do visibility_to_symbol(@visibility) end else @@ -1437,7 +1364,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "exp" - interpret_argless_method(method, args) { @exp } + interpret_check_args { @exp } else super end @@ -1448,9 +1375,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "type" - interpret_argless_method(method, args) { @offsetof_type } + interpret_check_args { @offsetof_type } when "offset" - interpret_argless_method(method, args) { @offset } + interpret_check_args { @offset } else super end @@ -1461,9 +1388,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "exp" - interpret_argless_method(method, args) { @exp } + interpret_check_args { @exp } when "visibility" - interpret_argless_method(method, args) do + interpret_check_args do visibility_to_symbol(@modifier) end else @@ -1476,9 +1403,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "receiver" - interpret_argless_method(method, args) { @obj } + interpret_check_args { @obj } when "arg" - interpret_argless_method(method, args) { @const } + interpret_check_args { @const } else super end @@ -1489,9 +1416,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "receiver" - interpret_argless_method(method, args) { @obj } + interpret_check_args { @obj } when "name" - interpret_argless_method(method, args) { StringLiteral.new(@name) } + interpret_check_args { StringLiteral.new(@name) } else super end @@ -1502,7 +1429,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "path" - interpret_argless_method(method, args) { StringLiteral.new(@string) } + interpret_check_args { StringLiteral.new(@string) } else super end @@ -1513,15 +1440,17 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "==", "!=" - case arg = args.first? - when StringLiteral, SymbolLiteral - if method == "==" - BoolLiteral.new(@value == arg.value) + interpret_check_args do |arg| + case arg + when StringLiteral, SymbolLiteral + if method == "==" + BoolLiteral.new(@value == arg.value) + else + BoolLiteral.new(@value != arg.value) + end else - BoolLiteral.new(@value != arg.value) + super end - else - super end when "stringify", "class_name", "symbolize" super @@ -1543,15 +1472,17 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "==", "!=" - case arg = args.first? - when MacroId - if method == "==" - BoolLiteral.new(@value == arg.value) + interpret_check_args do |arg| + case arg + when MacroId + if method == "==" + BoolLiteral.new(@value == arg.value) + else + BoolLiteral.new(@value != arg.value) + end else - BoolLiteral.new(@value != arg.value) + super end - else - super end when "stringify", "class_name", "symbolize" super @@ -1569,21 +1500,21 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "abstract?" - interpret_argless_method(method, args) { BoolLiteral.new(type.abstract?) } + interpret_check_args { BoolLiteral.new(type.abstract?) } when "union?" - interpret_argless_method(method, args) { BoolLiteral.new(type.is_a?(UnionType)) } + interpret_check_args { BoolLiteral.new(type.is_a?(UnionType)) } when "module?" - interpret_argless_method(method, args) { BoolLiteral.new(type.module?) } + interpret_check_args { BoolLiteral.new(type.module?) } when "class?" - interpret_argless_method(method, args) { BoolLiteral.new(type.class? && !type.struct?) } + interpret_check_args { BoolLiteral.new(type.class? && !type.struct?) } when "struct?" - interpret_argless_method(method, args) { BoolLiteral.new(type.class? && type.struct?) } + interpret_check_args { BoolLiteral.new(type.class? && type.struct?) } when "nilable?" - interpret_argless_method(method, args) { BoolLiteral.new(type.nilable?) } + interpret_check_args { BoolLiteral.new(type.nilable?) } when "union_types" - interpret_argless_method(method, args) { TypeNode.union_types(self) } + interpret_check_args { TypeNode.union_types(self) } when "name" - interpret_argless_method(method, args) do + interpret_check_args(named_params: ["generic_args"]) do generic_args = if named_args && (generic_arg = named_args["generic_args"]?) generic_arg else @@ -1595,52 +1526,52 @@ module Crystal MacroId.new(type.devirtualize.to_s(generic_args: generic_args.value)) end when "type_vars" - interpret_argless_method(method, args) { TypeNode.type_vars(type) } + interpret_check_args { TypeNode.type_vars(type) } when "instance_vars" - interpret_argless_method(method, args) { TypeNode.instance_vars(type) } + interpret_check_args { TypeNode.instance_vars(type) } when "class_vars" - interpret_argless_method(method, args) { TypeNode.class_vars(type) } + interpret_check_args { TypeNode.class_vars(type) } when "ancestors" - interpret_argless_method(method, args) { TypeNode.ancestors(type) } + interpret_check_args { TypeNode.ancestors(type) } when "superclass" - interpret_argless_method(method, args) { TypeNode.superclass(type) } + interpret_check_args { TypeNode.superclass(type) } when "subclasses" - interpret_argless_method(method, args) { TypeNode.subclasses(type) } + interpret_check_args { TypeNode.subclasses(type) } when "all_subclasses" - interpret_argless_method(method, args) { TypeNode.all_subclasses(type) } + interpret_check_args { TypeNode.all_subclasses(type) } when "includers" - interpret_argless_method(method, args) { TypeNode.includers(type) } + interpret_check_args { TypeNode.includers(type) } when "constants" - interpret_argless_method(method, args) { TypeNode.constants(type) } + interpret_check_args { TypeNode.constants(type) } when "constant" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| value = arg.to_string("argument to 'TypeNode#constant'") TypeNode.constant(type, value) end when "has_constant?" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| value = arg.to_string("argument to 'TypeNode#has_constant?'") TypeNode.has_constant?(type, value) end when "methods" - interpret_argless_method(method, args) { TypeNode.methods(type) } + interpret_check_args { TypeNode.methods(type) } when "has_method?" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| value = arg.to_string("argument to 'TypeNode#has_method?'") TypeNode.has_method?(type, value) end when "annotation" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| self.type.annotation(type) end when "annotations" - fetch_annotation(self, method, args) do |type| + fetch_annotation(self, method, args, named_args, block) do |type| annotations = self.type.annotations(type) return ArrayLiteral.new if annotations.nil? ArrayLiteral.map(annotations, &.itself) end when "size" - interpret_argless_method(method, args) do + interpret_check_args do type = self.type.instance_type case type when TupleInstanceType @@ -1652,7 +1583,7 @@ module Crystal end end when "keys" - interpret_argless_method(method, args) do + interpret_check_args do type = self.type.instance_type if type.is_a?(NamedTupleInstanceType) ArrayLiteral.map(type.entries) { |entry| MacroId.new(entry.name) } @@ -1661,7 +1592,7 @@ module Crystal end end when "[]" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| type = self.type.instance_type case type when NamedTupleInstanceType @@ -1695,11 +1626,11 @@ module Crystal end end when "class" - interpret_argless_method(method, args) { TypeNode.new(type.metaclass) } + interpret_check_args { TypeNode.new(type.metaclass) } when "instance" - interpret_argless_method(method, args) { TypeNode.new(type.instance_type) } + interpret_check_args { TypeNode.new(type.instance_type) } when "==", "!=" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| return super unless arg.is_a?(TypeNode) self_type = self.type.devirtualize @@ -1713,7 +1644,7 @@ module Crystal end end when "<", "<=", ">", ">=" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| unless arg.is_a?(TypeNode) raise "TypeNode##{method} expects TypeNode, not #{arg.class_desc}" end @@ -1734,7 +1665,7 @@ module Crystal BoolLiteral.new(!!value) end when "overrides?" - interpret_two_args_method(method, args) do |arg1, arg2| + interpret_check_args do |arg1, arg2| unless arg1.is_a?(TypeNode) raise "TypeNode##{method} expects TypeNode as a first argument, not #{arg1.class_desc}" end @@ -1743,9 +1674,9 @@ module Crystal TypeNode.overrides?(type, arg1.type, value) end when "resolve" - interpret_argless_method(method, args) { self } + interpret_check_args { self } when "resolve?" - interpret_argless_method(method, args) { self } + interpret_check_args { self } else super end @@ -1934,13 +1865,13 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(name) } + interpret_check_args { MacroId.new(name) } when "receiver" - interpret_argless_method(method, args) { obj || Nop.new } + interpret_check_args { obj || Nop.new } when "args" - interpret_argless_method(method, args) { ArrayLiteral.map self.args, &.itself } + interpret_check_args { ArrayLiteral.map self.args, &.itself } when "named_args" - interpret_argless_method(method, args) do + interpret_check_args do if named_args = self.named_args ArrayLiteral.map(named_args) { |arg| arg } else @@ -1948,9 +1879,9 @@ module Crystal end end when "block" - interpret_argless_method(method, args) { self.block || Nop.new } + interpret_check_args { self.block || Nop.new } when "block_arg" - interpret_argless_method(method, args) { self.block_arg || Nop.new } + interpret_check_args { self.block_arg || Nop.new } else super end @@ -1969,9 +1900,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(name) } + interpret_check_args { MacroId.new(name) } when "value" - interpret_argless_method(method, args) { value } + interpret_check_args { value } else super end @@ -1982,11 +1913,11 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "cond" - interpret_argless_method(method, args) { @cond } + interpret_check_args { @cond } when "then" - interpret_argless_method(method, args) { @then } + interpret_check_args { @then } when "else" - interpret_argless_method(method, args) { @else } + interpret_check_args { @else } else super end @@ -1997,11 +1928,11 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "cond" - interpret_argless_method(method, args) { cond || Nop.new } + interpret_check_args { cond || Nop.new } when "whens" - interpret_argless_method(method, args) { ArrayLiteral.map whens, &.itself } + interpret_check_args { ArrayLiteral.map whens, &.itself } when "else" - interpret_argless_method(method, args) { self.else || Nop.new } + interpret_check_args { self.else || Nop.new } else super end @@ -2012,9 +1943,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "conds" - interpret_argless_method(method, args) { ArrayLiteral.new(conds) } + interpret_check_args { ArrayLiteral.new(conds) } when "body" - interpret_argless_method(method, args) { body } + interpret_check_args { body } else super end @@ -2025,9 +1956,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "target" - interpret_argless_method(method, args) { target } + interpret_check_args { target } when "value" - interpret_argless_method(method, args) { value } + interpret_check_args { value } else super end @@ -2038,9 +1969,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "targets" - interpret_argless_method(method, args) { ArrayLiteral.new(targets) } + interpret_check_args { ArrayLiteral.new(targets) } when "values" - interpret_argless_method(method, args) { ArrayLiteral.new(values) } + interpret_check_args { ArrayLiteral.new(values) } else super end @@ -2055,7 +1986,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } else super end @@ -2066,9 +1997,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "obj" - interpret_argless_method(method, args) { @obj } + interpret_check_args { @obj } when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } else super end @@ -2083,7 +2014,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } else super end @@ -2098,7 +2029,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { MacroId.new(@name) } + interpret_check_args { MacroId.new(@name) } else super end @@ -2109,20 +2040,20 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "names" - interpret_argless_method(method, args) do + interpret_check_args do ArrayLiteral.map(@names) { |name| MacroId.new(name) } end when "global" interpreter.report_warning_at(name_loc, "Deprecated Path#global. Use `#global?` instead") - interpret_argless_method(method, args) { BoolLiteral.new(@global) } + interpret_check_args { BoolLiteral.new(@global) } when "global?" - interpret_argless_method(method, args) { BoolLiteral.new(@global) } + interpret_check_args { BoolLiteral.new(@global) } when "resolve" - interpret_argless_method(method, args) { interpreter.resolve(self) } + interpret_check_args { interpreter.resolve(self) } when "resolve?" - interpret_argless_method(method, args) { interpreter.resolve?(self) || NilLiteral.new } + interpret_check_args { interpreter.resolve?(self) || NilLiteral.new } when "types" - interpret_argless_method(method, args) { ArrayLiteral.new([self] of ASTNode) } + interpret_check_args { ArrayLiteral.new([self] of ASTNode) } else super end @@ -2137,9 +2068,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "cond" - interpret_argless_method(method, args) { @cond } + interpret_check_args { @cond } when "body" - interpret_argless_method(method, args) { @body } + interpret_check_args { @body } else super end @@ -2150,9 +2081,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "obj" - interpret_argless_method(method, args) { obj } + interpret_check_args { obj } when "to" - interpret_argless_method(method, args) { to } + interpret_check_args { to } else super end @@ -2163,9 +2094,9 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "obj" - interpret_argless_method(method, args) { obj } + interpret_check_args { obj } when "to" - interpret_argless_method(method, args) { to } + interpret_check_args { to } else super end @@ -2176,11 +2107,11 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "name" - interpret_argless_method(method, args) { name } + interpret_check_args { name } when "type_vars" - interpret_argless_method(method, args) { ArrayLiteral.new(type_vars) } + interpret_check_args { ArrayLiteral.new(type_vars) } when "named_args" - interpret_argless_method(method, args) do + interpret_check_args do if named_args = @named_args NamedTupleLiteral.new(named_args.map { |arg| NamedTupleLiteral::Entry.new(arg.name, arg.value) }) else @@ -2188,11 +2119,11 @@ module Crystal end end when "resolve" - interpret_argless_method(method, args) { interpreter.resolve(self) } + interpret_check_args { interpreter.resolve(self) } when "resolve?" - interpret_argless_method(method, args) { interpreter.resolve?(self) || NilLiteral.new } + interpret_check_args { interpreter.resolve?(self) || NilLiteral.new } when "types" - interpret_argless_method(method, args) { ArrayLiteral.new([self] of ASTNode) } + interpret_check_args { ArrayLiteral.new([self] of ASTNode) } else super end @@ -2203,7 +2134,7 @@ module Crystal def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method when "[]" - interpret_one_arg_method(method, args) do |arg| + interpret_check_args do |arg| case arg when NumberLiteral index = arg.to_number.to_i @@ -2221,11 +2152,11 @@ module Crystal named_arg.try(&.value) || NilLiteral.new end when "args" - interpret_argless_method(method, args) do + interpret_check_args do TupleLiteral.new self.args end when "named_args" - interpret_argless_method(method, args) do + interpret_check_args do get_named_annotation_args self end else @@ -2243,12 +2174,10 @@ private def get_named_annotation_args(object) end end -private def interpret_array_or_tuple_method(object, klass, method, args, block, interpreter) +private def interpret_array_or_tuple_method(object, klass, method, args, named_args, block, interpreter) case method when "any?" - object.interpret_argless_method(method, args) do - raise "any? expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? Crystal::BoolLiteral.new(object.elements.any? do |elem| @@ -2257,9 +2186,7 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, end) end when "all?" - object.interpret_argless_method(method, args) do - raise "all? expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? Crystal::BoolLiteral.new(object.elements.all? do |elem| @@ -2268,29 +2195,25 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, end) end when "splat" - case args.size - when 0 - Crystal::MacroId.new(object.elements.join ", ") - when 1 - object.interpret_one_arg_method(method, args) do |arg| + interpret_check_args(node: object, min_count: 0) do |arg| + if arg + unless arg.is_a?(Crystal::StringLiteral) + arg.raise "argument to splat must be a StringLiteral, not #{arg.class_desc}" + end + if object.elements.empty? Crystal::MacroId.new("") else - unless arg.is_a?(Crystal::StringLiteral) - arg.raise "argument to splat must be a StringLiteral, not #{arg.class_desc}" - end Crystal::MacroId.new((object.elements.join ", ") + arg.value) end + else + Crystal::MacroId.new(object.elements.join ", ") end - else - object.wrong_number_of_arguments "#{klass}#splat", args.size, "0..1" end when "empty?" - object.interpret_argless_method(method, args) { Crystal::BoolLiteral.new(object.elements.empty?) } + interpret_check_args(node: object) { Crystal::BoolLiteral.new(object.elements.empty?) } when "find" - object.interpret_argless_method(method, args) do - raise "find expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? found = object.elements.find do |elem| @@ -2300,23 +2223,21 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, found ? found : Crystal::NilLiteral.new end when "first" - object.interpret_argless_method(method, args) { object.elements.first? || Crystal::NilLiteral.new } + interpret_check_args(node: object) { object.elements.first? || Crystal::NilLiteral.new } when "includes?" - object.interpret_one_arg_method(method, args) do |arg| + interpret_check_args(node: object) do |arg| Crystal::BoolLiteral.new(object.elements.includes?(arg)) end when "join" - object.interpret_one_arg_method(method, args) do |arg| + interpret_check_args(node: object) do |arg| Crystal::StringLiteral.new(object.elements.map(&.to_macro_id).join arg.to_macro_id) end when "last" - object.interpret_argless_method(method, args) { object.elements.last? || Crystal::NilLiteral.new } + interpret_check_args(node: object) { object.elements.last? || Crystal::NilLiteral.new } when "size" - object.interpret_argless_method(method, args) { Crystal::NumberLiteral.new(object.elements.size) } + interpret_check_args(node: object) { Crystal::NumberLiteral.new(object.elements.size) } when "each" - object.interpret_argless_method(method, args) do - raise "each expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? object.elements.each do |elem| @@ -2327,9 +2248,7 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, Crystal::NilLiteral.new end when "each_with_index" - object.interpret_argless_method(method, args) do - raise "each_with_index expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args[0]? index_arg = block.args[1]? @@ -2342,9 +2261,7 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, Crystal::NilLiteral.new end when "map" - object.interpret_argless_method(method, args) do - raise "map expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args.first? klass.map(object.elements) do |elem| @@ -2353,9 +2270,7 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, end end when "map_with_index" - object.interpret_argless_method(method, args) do - raise "map_with_index expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do block_arg = block.args[0]? index_arg = block.args[1]? @@ -2366,96 +2281,83 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, end end when "select" - object.interpret_argless_method(method, args) do - raise "select expects a block" unless block + interpret_check_args(node: object, uses_block: true) do filter(object, klass, block, interpreter) end when "reject" - object.interpret_argless_method(method, args) do - raise "reject expects a block" unless block + interpret_check_args(node: object, uses_block: true) do filter(object, klass, block, interpreter, keep: false) end when "reduce" - raise "reduce expects a block" unless block - accumulate_arg = block.args.first? - value_arg = block.args[1]? - case args.size - when 0 - object.interpret_argless_method(method, args) do - object.elements.reduce do |accumulate, elem| + interpret_check_args(node: object, min_count: 0, uses_block: true) do |memo| + accumulate_arg = block.args.first? + value_arg = block.args[1]? + + if memo + object.elements.reduce(memo) do |accumulate, elem| interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg interpreter.define_var(value_arg.name, elem) if value_arg interpreter.accept block.body end - end - when 1 - object.interpret_one_arg_method(method, args) do |arg| - object.elements.reduce(arg) do |accumulate, elem| + else + object.elements.reduce do |accumulate, elem| interpreter.define_var(accumulate_arg.name, accumulate) if accumulate_arg interpreter.define_var(value_arg.name, elem) if value_arg interpreter.accept block.body end end - else - raise "only 0 or 1 args expected for reduce, got #{args.size}" end when "shuffle" - klass.new(object.elements.shuffle) + interpret_check_args(node: object) { klass.new(object.elements.shuffle) } when "sort" - klass.new(object.elements.sort { |x, y| x.interpret_compare(y) }) + interpret_check_args(node: object) { klass.new(object.elements.sort { |x, y| x.interpret_compare(y) }) } when "sort_by" - object.interpret_argless_method(method, args) do - raise "sort_by expects a block" unless block - + interpret_check_args(node: object, uses_block: true) do sort_by(object, klass, block, interpreter) end when "uniq" - klass.new(object.elements.uniq) + interpret_check_args(node: object) { klass.new(object.elements.uniq) } when "[]" - case args.size - when 1 - arg = args.first - case arg - when Crystal::NumberLiteral - index = arg.to_number.to_i - value = object.elements[index]? || Crystal::NilLiteral.new - when Crystal::RangeLiteral - range = arg.interpret_to_range(interpreter) + interpret_check_args(node: object, min_count: 1) do |from, to| + if to + from = interpreter.accept(from) + to = interpreter.accept(to) + + unless from.is_a?(Crystal::NumberLiteral) + from.raise "expected first argument to RangeLiteral#[] to be a number, not #{from.class_desc}" + end + + unless to.is_a?(Crystal::NumberLiteral) + to.raise "expected second argument to RangeLiteral#[] to be a number, not #{from.class_desc}" + end + + from = from.to_number.to_i + to = to.to_number.to_i + begin - klass.new(object.elements[range]) + klass.new(object.elements[from, to]) rescue ex object.raise ex.message end else - arg.raise "argument to [] must be a number or range, not #{arg.class_desc}:\n\n#{arg}" - end - when 2 - from, to = args - - from = interpreter.accept(from) - to = interpreter.accept(to) - - unless from.is_a?(Crystal::NumberLiteral) - from.raise "expected first argument to RangeLiteral#[] to be a number, not #{from.class_desc}" - end - - unless to.is_a?(Crystal::NumberLiteral) - to.raise "expected second argument to RangeLiteral#[] to be a number, not #{from.class_desc}" - end - - from = from.to_number.to_i - to = to.to_number.to_i - - begin - klass.new(object.elements[from, to]) - rescue ex - object.raise ex.message + case arg = from + when Crystal::NumberLiteral + index = arg.to_number.to_i + value = object.elements[index]? || Crystal::NilLiteral.new + when Crystal::RangeLiteral + range = arg.interpret_to_range(interpreter) + begin + klass.new(object.elements[range]) + rescue ex + object.raise ex.message + end + else + arg.raise "argument to [] must be a number or range, not #{arg.class_desc}:\n\n#{arg}" + end end - else - object.wrong_number_of_arguments "#{klass}#[]", args.size, 1 end when "[]=" - object.interpret_two_args_method(method, args) do |index_node, value| + interpret_check_args(node: object) do |index_node, value| unless index_node.is_a?(Crystal::NumberLiteral) index_node.raise "expected index argument to ArrayLiteral#[]= to be a number, not #{index_node.class_desc}" end @@ -2471,23 +2373,17 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, value end when "unshift" - case args.size - when 1 - object.elements.unshift(args.first) + interpret_check_args(node: object) do |arg| + object.elements.unshift(arg) object - else - object.wrong_number_of_arguments "#{klass}#unshift", args.size, 1 end when "push", "<<" - case args.size - when 1 - object.elements << args.first + interpret_check_args(node: object) do |arg| + object.elements << arg object - else - object.wrong_number_of_arguments "#{klass}##{method}", args.size, 1 end when "+" - object.interpret_one_arg_method(method, args) do |arg| + interpret_check_args(node: object) do |arg| case arg when Crystal::TupleLiteral other_elements = arg.elements @@ -2503,6 +2399,88 @@ private def interpret_array_or_tuple_method(object, klass, method, args, block, end end +# Checks the following in an invocation of a macro `foo`: +# +# * The number of macro arguments to `foo` matches the number of block +# parameters to this macro. If `min_count` is given then only that many macro +# parameters are required, others are optional and this macro's corresponding +# block parameter will receive `nil` instead. +# * If `named_params` is true, any named arguments to `foo` are allowed. If it +# is falsey (the default), no named arguments are allowed. Otherwise, only +# named arguments included by `named_params` are allowed. The block parameters +# of this macro are unaffected by named arguments. +# * There is a block supplied to `foo` if and only if `uses_block` is true. +# +# `top_level` affects how error messages are formatted. +# +# Accesses the `method`, `args`, `named_args`, and `block` variables in the +# current scope. +private macro interpret_check_args(*, node = self, min_count = nil, named_params = nil, uses_block = false, top_level = false, &block) + {% if uses_block %} + unless block + %full_name = full_macro_name({{ node }}, method, {{ top_level }}) + {{ node }}.raise "#{%full_name} is expected to be invoked with a block, but no block was given" + end + {% else %} + if block + %full_name = full_macro_name({{ node }}, method, {{ top_level }}) + {{ node }}.raise "#{%full_name} is not expected to be invoked with a block, but a block was given" + end + {% end %} + + {% if !named_params %} + if named_args && !named_args.empty? + %full_name = full_macro_name({{ node }}, method, {{ top_level }}) + {{ node }}.raise "named arguments are not allowed here" + end + {% elsif named_params != true %} + if named_args + allowed_keys = {{ named_params }} + named_args.each_key do |name| + {{ node }}.raise "no named parameter '#{name}'" unless allowed_keys.includes?(name) + end + end + {% end %} + + {% if min_count %} + unless {{ min_count }} <= args.size <= {{ block.args.size }} + %full_name = full_macro_name({{ node }}, method, {{ top_level }}) + {{ node }}.wrong_number_of_arguments %full_name, args.size, {{ min_count }}..{{ block.args.size }} + end + + {% for var, i in block.args %} + {{ var }} = args[{{ i }}]{% if i >= min_count %}?{% end %} + {% end %} + {% else %} + unless args.size == {{ block.args.size }} + %full_name = full_macro_name({{ node }}, method, {{ top_level }}) + {{ node }}.wrong_number_of_arguments %full_name, args.size, {{ block.args.size }} + end + + {% for var, i in block.args %} + {{ var }} = args[{{ i }}] + {% end %} + {% end %} + + {{ block.body }} +end + +private macro interpret_check_args_toplevel(*, min_count = nil, uses_block = false, &block) + method = node.name + args = node.args + named_args = node.named_args + block = node.block + interpret_check_args(node: node, min_count: {{ min_count }}, uses_block: {{ uses_block }}, top_level: true) {{ block }} +end + +private def full_macro_name(node, method, top_level) + if top_level + "top-level macro '#{method}'" + else + "macro '#{node.class_desc}##{method}'" + end +end + private def visibility_to_symbol(visibility) visibility_name = case visibility @@ -2540,8 +2518,8 @@ private def filter(object, klass, block, interpreter, keep = true) end) end -private def fetch_annotation(node, method, args) - node.interpret_one_arg_method(method, args) do |arg| +private def fetch_annotation(node, method, args, named_args, block) + interpret_check_args(node: node) do |arg| unless arg.is_a?(Crystal::TypeNode) args[0].raise "argument to '#{node.class_desc}#annotation' must be a TypeNode, not #{arg.class_desc}" end