Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite SLVector as a subtype of StaticVector #19

Merged
merged 4 commits into from
Nov 7, 2018

Conversation

MSeeker1340
Copy link
Contributor

#18 (comment)

For the @SLVector macro, I modified @YingboMa's implementation to not return an anonymous constructor-like function but instead the type/constructor itself. This makes code much simpler (e.g. see the new slvectors.jl tests).

similar(x) and AbstractFloat{T}(x) still returns an unwrapped MArray. This is the behavior intended by StaticArrays (https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/abstractarray.jl#L94). I can probably make it better by defining a MLVector type, but for now this isn't really an issue.

@MSeeker1340 MSeeker1340 mentioned this pull request Nov 6, 2018
@ChrisRackauckas
Copy link
Member

Looks good to me. Get a PR ready for PuMaS with these changes to make sure it does what we need. This should get merged and tagged with a minor release.

@MSeeker1340
Copy link
Contributor Author

@ChrisRackauckas Need some help with metaprogramming.

In the PuMaS update, I tried to use the new @SLVector as

# odevars == (:Depot, :Central)
uType = @SLVector Float64 :($(odevars...,))

However this is what I got:

julia> odevars = (:Depot, :Central);

julia> @SLVector Float64 (:Depot, :Central)
SLVector{2,Float64,(:Depot, :Central)}

julia> @SLVector Float64 :($(odevars...,))
SLVector{1,Float64,(:Depot, :Central)}

The macro definition is

macro SLVector(E,syms)
    quote
        SLVector{$(length(syms.args)),$(esc(E)),$syms}
    end
end

I can of course just use plain constructors instead of @SLVector, but I'm curious as to why I got this behavior.

@ChrisRackauckas
Copy link
Member

Yeah I'm not sure why that happens, but I noticed it before...

@ChrisRackauckas ChrisRackauckas merged commit 6020826 into SciML:master Nov 7, 2018
@ChrisRackauckas
Copy link
Member

Before tagging, I want to see if we can get this working on arrays and not just vectors via whatever the new ind2sub is.

@YingboMa
Copy link
Member

YingboMa commented Nov 7, 2018

julia> struct SLVector{A,B,C} end

julia> macro SLVector(E,syms)
           n = syms isa Expr ? length(syms.args) : length(syms)
           quote
               SLVector{$n,$(esc(E)),$(esc(syms))}
           end
       end
@SLVector (macro with 1 method)

julia> odevars = (:Depot, :Central);

julia> @SLVector Float64 (:Depot, :Central)
SLVector{2,Float64,(:Depot, :Central)}

julia> @eval @SLVector Float64 $(odevars...,)
SLVector{2,Float64,(:Depot, :Central)}

@YingboMa
Copy link
Member

YingboMa commented Nov 7, 2018

Maybe we shouldn't use Val to do index at all. That is not a fast way to do it.

julia> using BenchmarkTools, LabelledArrays

julia> ABC = @SLVector Int (:a,:b,:c)
SLVector{3,Int64,(:a, :b, :c)}

julia> b = ABC(1,2,3)
3-element SLVector{3,Int64,(:a, :b, :c)}:
 1
 2
 3

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  3.138 μs (1 allocation: 32 bytes)
1

julia> function Base.getindex(x::SLVector,s::Symbol)
         idx = findfirst(isequal(s), LabelledArrays.symnames(typeof(x)))
         getfield(x, :__x)[idx]
       end

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  14.571 ns (0 allocations: 0 bytes)
2

julia> @btime b[i] setup=(i = rand(1:3))
  15.217 ns (0 allocations: 0 bytes)
2

The naive implementation is fast.

@ChrisRackauckas
Copy link
Member

what's the generated code like? When I tried something like that, constant prop didn't work through findfirst. The generated function makes sure it compiles away.

@YingboMa
Copy link
Member

YingboMa commented Nov 7, 2018

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  2.828 μs (1 allocation: 32 bytes)
1

julia> @btime b[:a]
  2.985 μs (1 allocation: 32 bytes)
1

julia> @btime b[Val(:a)]
  13.985 ns (0 allocations: 0 bytes)
1

@ChrisRackauckas
Copy link
Member

You might be timing something odd in the global scope there? Interpolate it in?

This is a good case for checking the generated code though. It's either running findfirst or it's just using the scalar index at runtime. If it's not just compiling down to a scalar indexing, that would be an issue for larger SArray operations.

@YingboMa
Copy link
Member

YingboMa commented Nov 7, 2018

I don't need to interpolate when I do @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)]). With "Val indexing", I got

julia> @code_typed b[:a]
CodeInfo(
22 1%1 = invoke LabelledArrays.Val(_3::Symbol)::Val{_1} where _1                                    │
   │   %2 = (Base.getindex)(x, %1)::Any                                                                │
   └──      return %2                                                                                  │
) => Any

julia> goo(b) = b[:a]
goo (generic function with 1 method)

julia> @code_typed goo(b)
CodeInfo(
1 1%1 = (LabelledArrays.getfield)(b, :__x)::SArray{Tuple{3},Int64,1,3}        │╻╷╷╷ getindex
  │   %2 = (Base.getfield)(%1, :data)::Tuple{Int64,Int64,Int64}                  ││╻    getindex
  │   %3 = (Base.getfield)(%2, 1, true)::Int64                                   │││╻    macro expansion
  └──      return %3                                                             │
) => Int64

With the naive implementation, I got

julia> @code_typed b[:a]
CodeInfo(
2 1 ── %1  = (Base.getfield)((:a, :b, :c), 1, true)::Symbol                                                                                │╻╷╷  findfirst
  └───       goto #12 if not true                                                                                                          ││
  2 ┄─ %3  = φ (#1 => 1, #11 => %22)::Int64                                                                                                ││%4  = φ (#1 => %1, #11 => %23)::Symbol                                                                                              ││%5  = φ (#1 => 1, #11 => %24)::Int64                                                                                                ││%6  = (%4 === s)::Bool                                                                                                              ││╻╷╷  Fix2
  └───       goto #4 if not %6                                                                                                             ││
  3 ──       goto #13                                                                                                                      ││
  4 ── %9  = (%5 === 3)::Bool                                                                                                              │││╻╷   iterate
  └───       goto #6 if not %9                                                                                                             ││││
  5 ──       goto #7                                                                                                                       ││││
  6 ── %12 = (Base.add_int)(%5, 1)::Int64                                                                                                  ││││╻    +
  └───       goto #7                                                                                                                       │││╻    iterate
  7 ┄─ %14 = φ (#5 => true, #6 => false)::Bool                                                                                             │││%15 = φ (#6 => %12)::Int64                                                                                                          │││%16 = φ (#6 => %12)::Int64                                                                                                          │││%17 = φ (#5 => true)::Bool                                                                                                          │││
  └───       goto #9 if not %14                                                                                                            │││
  8 ──       goto #10                                                                                                                      │││
  9 ── %20 = (Base.getfield)((:a, :b, :c), %15, true)::Symbol                                                                              │││╻    getindex
  └───       goto #10                                                                                                                      ││╻    iterate
  10%22 = φ (#9 => %15)::Int64                                                                                                          ││%23 = φ (#9 => %20)::Symbol                                                                                                         ││%24 = φ (#9 => %16)::Int64                                                                                                          ││%25 = φ (#8 => %17, #9 => false)::Bool                                                                                              ││%26 = (Base.not_int)(%25)::Bool                                                                                                     ││
  └───       goto #12 if not %26                                                                                                           ││
  11 ─       goto #2                                                                                                                       ││
  12%29 = Base.nothing::Const(nothing, false)                                                                                           ││
  └───       goto #13                                                                                                                      ││
  13%31 = φ (#3 => %3, #12 => %29)::Union{Nothing, Int64}                                                                               │
3%32 = (Main.getfield)(x, :__x)::SArray{Tuple{3},Int64,1,3}                                                                          │
  │    %33 = (isa)(%31, Int64)::Bool                                                                                                       │
  └───       goto #15 if not %33                                                                                                           │
  14%35 = π (%31, Int64)                                                                                                                │
  │    %36 = (Base.getfield)(%32, :data)::Tuple{Int64,Int64,Int64}                                                                         ││╻    getproperty
  │    %37 = (Base.getfield)(%36, %35, true)::Int64                                                                                        ││╻    getindex
  └───       goto #18                                                                                                                      │
  15%39 = (isa)(%31, Nothing)::Bool                                                                                                     │
  └───       goto #17 if not %39                                                                                                           │
  16%41 = π (%31, Nothing)                                                                                                              │
  │          invoke Base.to_index(%32::SArray{Tuple{3},Int64,1,3}, %41::Nothing)::Union{}                                                  ││╻╷   to_indices
  │          $(Expr(:unreachable))::Union{}                                                                                                │││┃    to_indices
  │          φ ()::Union{}                                                                                                                 │││
  │          $(Expr(:unreachable))::Union{}                                                                                                │││
  │          φ ()::Union{}                                                                                                                 ││
  │          $(Expr(:unreachable))::Union{}                                                                                                ││
  └───       $(Expr(:unreachable))::Union{}17 ┄       (Core.throw)(ErrorException("fatal error in type inference (type bound)"))::Union{}                                           │
  └───       $(Expr(:unreachable))::Union{}18return %37                                                                                                                    │
) => Int64

julia> goo(b) = b[:a]
goo (generic function with 1 method)

julia> @code_typed goo(b)
CodeInfo(
1 1%1 = (Main.getfield)(b, :__x)::SArray{Tuple{3},Int64,1,3}                                                                               │╻   getindex
  │   %2 = (Base.getfield)(%1, :data)::Tuple{Int64,Int64,Int64}                                                                               ││╻   getindex
  │   %3 = (Base.getfield)(%2, 1, true)::Int64                                                                                                │││╻   getindex
  └──      return %3                                                                                                                          │
) => Int64

@ChrisRackauckas
Copy link
Member

So the naive implementation still does constant prop, it's just that it doesn't overdo the compilation when used from the global scope?

@YingboMa
Copy link
Member

YingboMa commented Nov 7, 2018

The naive implementation still compiles quite well in the global scope, but with Val, if the compiler cannot do constant prop, the performance is going to deplete.

@ChrisRackauckas
Copy link
Member

Pick this up post #20

@MSeeker1340 MSeeker1340 deleted the xg/subclass branch November 7, 2018 21:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants