Skip to content

Commit

Permalink
fix incorrect function dispatch on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Jun 30, 2023
1 parent 83f48c9 commit 9983f50
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TropicalGEMM"
uuid = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
authors = ["GiggleLiu <[email protected]> and contributors"]
version = "0.1.9"
version = "0.1.10"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ julia> a = Tropical.(randn(1000, 1000))
julia> @benchmark Octavian.matmul_serial($a, $a)
```

**Warning:** using TropicalGEMM will overload the `mul!` function for Tropical numbers.

## Benchmarks

Matrix size `n x n`, CPU Intel(R) Core(TM) i5-10400 CPU @ 2.90GHz.
Expand Down
5 changes: 5 additions & 0 deletions src/TropicalGEMM.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TropicalGEMM

using LinearAlgebra, TropicalNumbers, VectorizationBase, LoopVectorization
using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, Index, gep_quote, VectorIndex,
AbstractMask, NativeTypesExceptBit, AbstractSIMDVector, IndexNoUnroll, AbstractStridedPointer, AbstractSIMD
using VectorizationBase: contiguous_batch_size, contiguous_axis, val_stride_rank, bytestrides, offsets, memory_reference,
vmaximum, fmap, FloatingTypes, IntegerIndex, LazyMulAdd
using LinearAlgebra: StridedMaybeAdjOrTransMat

export Tropical, TropicalF64, TropicalF32, TropicalF16

Expand Down
13 changes: 5 additions & 8 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@ function naive_mul!(o::AbstractMatrix{T0}, a::AbstractMatrix{T1}, b::AbstractMat
return o
end

# For types that are not natively supported, fall back to the naive implementation.
# Overwrite the `mul!` in LinearAlgebra (also changes the behavior of `*` in Base)!
for TA in [:(AbstractMatrix{T} where T<:TropicalTypes), :(Transpose{T,S} where {T<:TropicalTypes,S<:AbstractVecOrMat{T}})]
for TB in [:(AbstractMatrix{T} where T<:TropicalTypes), :(Transpose{T,S} where {T<:TropicalTypes,S<:AbstractVecOrMat{T}})]
@eval @inline function LinearAlgebra.mul!(o::AbstractMatrix{TO}, a::$TA, b::$TB, α::Number, β::Number) where TO
α = _convert_to_static(TO, α)
β = _convert_to_static(TO, β)
naive_mul!(o, a, b, α, β)
end
end
# Fallback `mul!` for Tropical element types: forwards to `naive_mul!`.
#
# The element types of `a` and `b` are restricted to `TropicalTypes` so that this method
# only takes over `LinearAlgebra.mul!` for tropical matrices. Without that restriction
# the method would apply to every strided matrix (e.g. `Matrix{Int}`), which is type
# piracy and silently replaces the stock LinearAlgebra kernels with `naive_mul!`.
# The containers stay restricted to `StridedMaybeAdjOrTransMat`, so non-strided array
# types are not captured by this method.
#
# Arguments:
# - `o`: output matrix with element type `TO`; passed to `naive_mul!` as the destination.
# - `a`, `b`: strided (optionally adjoint/transposed) matrices of tropical numbers.
# - `α`, `β`: scaling factors; each is passed through `_convert_to_static(TO, ·)` first.
#
# Returns whatever `naive_mul!` returns.
function LinearAlgebra.mul!(
    o::StridedMaybeAdjOrTransMat{TO},
    a::StridedMaybeAdjOrTransMat{<:TropicalTypes},
    b::StridedMaybeAdjOrTransMat{<:TropicalTypes},
    α::Number,
    β::Number,
) where {TO}
    α = _convert_to_static(TO, α)
    β = _convert_to_static(TO, β)
    return naive_mul!(o, a, b, α, β)
end

Base.:*(a::T, b::StaticInt{0}) where T<:TropicalTypes = zero(T)
Expand Down
18 changes: 4 additions & 14 deletions src/gemm.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, Index, gep_quote, VectorIndex,
AbstractMask, NativeTypesExceptBit, AbstractSIMDVector, IndexNoUnroll, AbstractStridedPointer, AbstractSIMD
using VectorizationBase: contiguous_batch_size, contiguous_axis, val_stride_rank, bytestrides, offsets, memory_reference,
vmaximum, fmap, FloatingTypes, IntegerIndex, LazyMulAdd

LoopVectorization.check_args(::Type{T}, ::Type{T}) where T<:Tropical = true
LoopVectorization.check_type(::Type{Tropical{T}}) where {T} = LoopVectorization.check_type(T)

Expand Down Expand Up @@ -148,15 +143,10 @@ end

# Overwrite the `mul!` in LinearAlgebra (also changes the behavior of `*` in Base)!
using Octavian
const XTranspose{T} = Transpose{T, <:AbstractVecOrMat{T}}
for TA in [:AbstractMatrix, :XTranspose]
for TB in [:AbstractMatrix, :XTranspose]
@eval function LinearAlgebra.mul!(o::AbstractMatrix{T}, a::$TA{T}, b::$TB{T}, α::Number, β::Number) where {T<:Tropical{<:NativeTypes}}
α = _convert_to_static(T, α)
β = _convert_to_static(T, β)
Octavian.matmul!(o, a, b, α, β)
end
end
# `mul!` for strided (optionally adjoint/transposed) matrices whose element type is the
# same `Tropical{<:NativeTypes}` for the output and both operands.
# `α` and `β` are each passed through `_convert_to_static(T, ·)` and the product is
# delegated to `Octavian.matmul!`, whose result is returned.
function LinearAlgebra.mul!(
    o::StridedMaybeAdjOrTransMat{T},
    a::StridedMaybeAdjOrTransMat{T},
    b::StridedMaybeAdjOrTransMat{T},
    α::Number,
    β::Number,
) where {T<:Tropical{<:NativeTypes}}
    return Octavian.matmul!(
        o, a, b, _convert_to_static(T, α), _convert_to_static(T, β)
    )
end
# NOTE: benchmarks show that the compiler optimizes away the type instability here,
# so there is no need to worry about the overhead.
Expand Down

0 comments on commit 9983f50

Please sign in to comment.