diff --git a/base/essentials.jl b/base/essentials.jl index c5322273ba4a5..747d424b36fa3 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -34,9 +34,16 @@ convert(::Type{Tuple{Vararg{T}}}, x::Tuple) where {T} = cnvt_all(T, x...) cnvt_all(T) = () cnvt_all(T, x, rest...) = tuple(convert(T,x), cnvt_all(T, rest...)...) +function eventually_call(ex) + isa(ex, Expr) && (ex.head === :call || + ((ex.head === :where || ex.head === :(::)) && + eventually_call(ex.args[1]))) +end + macro generated(f) isa(f, Expr) || error("invalid syntax; @generated must be used with a function definition") - if f.head === :function || (isdefined(:length) && f.head === :(=) && length(f.args) == 2 && f.args[1].head == :call) + if f.head === :function || (isdefined(:length) && f.head === :(=) && length(f.args) == 2 && + eventually_call(f.args[1])) f.head = :stagedfunction return Expr(:escape, f) else diff --git a/src/julia-parser.scm b/src/julia-parser.scm index 7e05db3f7ff89..264a9ca6d3e3e 100644 --- a/src/julia-parser.scm +++ b/src/julia-parser.scm @@ -613,7 +613,7 @@ (define (eventually-call ex) (and (pair? ex) (or (eq? (car ex) 'call) - (and (eq? (car ex) 'where) + (and (or (eq? (car ex) 'where) (eq? (car ex) '|::|)) (eventually-call (cadr ex)))))) ;; insert line/file for short-form function defs, otherwise leave alone diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 7f9a5085f2a97..9b12db8a70c44 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -828,18 +828,18 @@ (pattern-replace (pattern-set ;; definitions without `where` - (pattern-lambda (function (call name . sig) body) + (pattern-lambda (function (-$ (call name . sig) (|::| (call name . sig) _t)) body) (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) - (pattern-lambda (stagedfunction (call name . sig) body) + (pattern-lambda (stagedfunction (-$ (call name . sig) (|::| (call name . sig) _t)) body) (ctor-def (car __) name Tname params bounds sig ctor-body body #f)) - (pattern-lambda (= (call name . sig) body) + (pattern-lambda (= (-$ (call name . sig) (|::| (call name . sig) _t)) body) (ctor-def 'function name Tname params bounds sig ctor-body body #f)) ;; definitions with `where` - (pattern-lambda (function (where (call name . sig) . wheres) body) + (pattern-lambda (function (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) - (pattern-lambda (stagedfunction (where (call name . sig) . wheres) body) + (pattern-lambda (stagedfunction (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) (ctor-def (car __) name Tname params bounds sig ctor-body body wheres)) - (pattern-lambda (= (where (call name . sig) . wheres) body) + (pattern-lambda (= (where (-$ (call name . sig) (|::| (call name . sig) _t)) . wheres) body) (ctor-def 'function name Tname params bounds sig ctor-body body wheres))) ;; flatten `where`s first @@ -1330,9 +1330,11 @@ (expand-forms (expand-decls (car e) (cdr e) #f)))) (define (assigned-name e) - (if (and (pair? e) (memq (car e) '(call curly))) - (assigned-name (cadr e)) - e)) + (cond ((atom? e) e) + ((or (memq (car e) '(call curly where)) + (and (eq? (car e) '|::|) (eventually-call e))) + (assigned-name (cadr e))) + (else e))) ;; local x, y=2, z => local x;local y;local z;y = 2 (define (expand-decls what binds const?) @@ -2428,7 +2430,8 @@ (else '()))) (define (all-decl-vars e) ;; map decl-var over every level of an assignment LHS - (cond ((decl? e) (decl-var e)) + (cond ((eventually-call e) e) + ((decl? e) (decl-var e)) ((and (pair? e) (eq? (car e) 'tuple)) (cons 'tuple (map all-decl-vars (cdr e)))) (else e))) diff --git a/src/macroexpand.scm b/src/macroexpand.scm index 62d6b611bf097..0e946996d2719 100644 --- a/src/macroexpand.scm +++ b/src/macroexpand.scm @@ -62,8 +62,11 @@ sparams))))) ;; function definition - (pattern-lambda (function (call name . argl) body) + (pattern-lambda (function (-$ (call name . argl) (|::| (call name . argl) _t)) body) (cons 'varlist (llist-vars (fix-arglist argl)))) + (pattern-lambda (function (where (-$ (call name . argl) (|::| (call name . argl) _t)) . wheres) body) + (cons 'varlist (append (llist-vars (fix-arglist argl)) + (map typevar-expr-name wheres)))) (pattern-lambda (function (tuple . args) body) `(-> (tuple ,@args) ,body)) @@ -71,8 +74,10 @@ ;; expression form function definition (pattern-lambda (= (call (curly name . sparams) . argl) body) `(function (call (curly ,name . ,sparams) . ,argl) ,body)) - (pattern-lambda (= (call name . argl) body) + (pattern-lambda (= (-$ (call name . argl) (|::| (call name . argl) _t)) body) `(function (call ,name ,@argl) ,body)) + (pattern-lambda (= (where (-$ (call name . argl) (|::| (call name . argl) _t)) . wheres) body) + (cons 'function (cdr __))) ;; anonymous function (pattern-lambda (-> a b) @@ -82,6 +87,10 @@ (list a)))) (cons 'varlist (llist-vars (fix-arglist a))))) + ;; where + (pattern-lambda (where ex . vars) + (cons 'varlist (map typevar-expr-name vars))) + ;; let (pattern-lambda (let ex . binds) (let loop ((binds binds) @@ -143,13 +152,17 @@ (pattern-lambda (function (call (curly name . sparams) . argl) body) (cons 'varlist (llist-keywords (fix-arglist argl)))) - (pattern-lambda (function (call name . argl) body) + (pattern-lambda (function (-$ (call name . argl) (|::| (call name . argl) _t)) body) + (cons 'varlist (llist-keywords (fix-arglist argl)))) + (pattern-lambda (function (where (-$ (call name . argl) (|::| (call name . argl) _t)) . wheres) body) (cons 'varlist (llist-keywords (fix-arglist argl)))) (pattern-lambda (= (call (curly name . sparams) . argl) body) `(function (call (curly ,name . ,sparams) . ,argl) ,body)) - (pattern-lambda (= (call name . argl) body) + (pattern-lambda (= (-$ (call name . argl) (|::| (call name . argl) _t)) body) `(function (call ,name ,@argl) ,body)) + (pattern-lambda (= (where (-$ (call name . argl) (|::| (call name . argl) _t)) . wheres) body) + (cons 'function (cdr __))) )) (define (pair-with-gensyms v) @@ -166,6 +179,20 @@ (define (typevar-expr-name e) (car (analyze-typevar e))) +;; resolve-expansion-vars-with-new-env, but turn on `inarg` once we get inside +;; the formal argument list. `e` in general might be e.g. `(f{T}(x)::T) where T`. +(define (resolve-in-function-lhs e env m inarg) + (define (recur x) (resolve-in-function-lhs x env m inarg)) + (define (other x) (resolve-expansion-vars-with-new-env x env m inarg)) + (case (car e) + ((where) `(where ,(recur (cadr e)) ,@(map other (cddr e)))) + ((|::|) `(|::| ,(recur (cadr e)) ,(other (caddr e)))) + ((call) `(call ,(other (cadr e)) + ,@(map (lambda (x) + (resolve-expansion-vars-with-new-env x env m #t)) + (cddr e)))) + (else (other e)))) + (define (new-expansion-env-for x env (outermost #f)) (let ((introduced (pattern-expand1 vars-introduced-by-patterns x))) (if (or (atom? x) @@ -252,12 +279,9 @@ (cdr e)))) ((= function) - (if (and (pair? (cadr e)) (eq? (caadr e) 'call)) + (if (and (pair? (cadr e)) (function-def? e)) ;; in (kw x 1) inside an arglist, the x isn't actually a kwarg - `(,(car e) (call ,(resolve-expansion-vars-with-new-env (cadadr e) env m inarg) - ,@(map (lambda (x) - (resolve-expansion-vars-with-new-env x env m #t)) - (cddr (cadr e)))) + `(,(car e) ,(resolve-in-function-lhs (cadr e) env m inarg) ,(resolve-expansion-vars-with-new-env (caddr e) env m inarg)) `(,(car e) ,@(map (lambda (x) (resolve-expansion-vars-with-new-env x env m inarg)) @@ -308,6 +332,8 @@ ((eq? (car e) 'call) (decl-var* (cadr e))) ((eq? (car e) '=) (decl-var* (cadr e))) ((eq? (car e) 'curly) (decl-var* (cadr e))) + ((eq? (car e) '|::|) (decl-var* (cadr e))) + ((eq? (car e) 'where) (decl-var* (cadr e))) (else (decl-var e)))) (define (decl-vars* e) @@ -318,7 +344,7 @@ (define (function-def? e) (and (pair? e) (or (eq? (car e) 'function) (eq? (car e) '->) (and (eq? (car e) '=) (length= e 3) - (pair? (cadr e)) (eq? (caadr e) 'call))))) + (eventually-call (cadr e)))))) (define (find-declared-vars-in-expansion e decl (outer #t)) (cond ((or (not (pair? e)) (quoted? e)) '()) @@ -335,11 +361,11 @@ ((eq? (car e) 'escape) '()) ((and (not outer) (function-def? e)) ;; pick up only function name - (let ((fname (cond ((eq? (car e) '=) (cadr (cadr e))) + (let ((fname (cond ((eq? (car e) '=) (decl-var* (cadr e))) ((eq? (car e) 'function) (cond ((atom? (cadr e)) (cadr e)) ((eq? (car (cadr e)) 'tuple) #f) - (else (cadr (cadr e))))) + (else (decl-var* (cadr e))))) (else #f)))) (if (symbol? fname) (list fname) diff --git a/test/core.jl b/test/core.jl index 37afe6c68fb26..bf0d22504ddd5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -5018,3 +5018,65 @@ for i in 1:10 @test ptr1 === ptr2 @test ptr1 % 16 == 0 end + +# issue #21581 +global function f21581()::Int + return 2.0 +end +@test f21581() === 2 +global g21581()::Int = 2.0 +@test g21581() === 2 +module M21581 +macro bar() + :(foo21581(x)::Int = x) +end +M21581.@bar +end +@test M21581.foo21581(1) === 1 + +module N21581 +macro foo(var) + quote + function f(x::T = 1) where T + ($(esc(var)), x) + end + f() + end +end +end +let x = 8 + @test @N21581.foo(x) === (8, 1) +end + +# issue #22122 +let + global @inline function f22122(x::T) where {T} + T + end +end +@test f22122(1) === Int + +# issue #22026 +module M22026 + +macro foo(TYP) + quote + global foofunction + foofunction(x::Type{T}) where {T<:Number} = x + end +end +struct Foo end +@foo Foo + +macro foo2() + quote + global foofunction2 + (foofunction2(x::T)::Float32) where {T<:Number} = 2x + end +end + +@foo2 + +end +@test M22026.foofunction(Int16) === Int16 +@test M22026.foofunction2(3) === 6.0f0 diff --git a/test/staged.jl b/test/staged.jl index 7d6a95ff8b145..ba62fc2d91816 100644 --- a/test/staged.jl +++ b/test/staged.jl @@ -224,3 +224,7 @@ g10178(x) = f10178(x) end g10178(x) = f10178(x) @test g10178(5) == 10 + +# issue #22135 +@generated f22135(x::T) where T = x +@test f22135(1) === Int