Skip to content

Commit

Permalink
more tropical element types
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu committed Sep 24, 2023
1 parent 71c1e01 commit b2e0b32
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 135 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TropicalGEMM"
uuid = "a4ad3063-64a7-4bad-8738-34ed09bc0236"
authors = ["GiggleLiu <[email protected]> and contributors"]
version = "0.1.10"
version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
34 changes: 17 additions & 17 deletions src/TropicalGEMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,26 @@ using VectorizationBase: OffsetPrecalc, StaticBool, Bit, static, NativeTypes, In
using VectorizationBase: contiguous_batch_size, contiguous_axis, val_stride_rank, bytestrides, offsets, memory_reference,
vmaximum, fmap, FloatingTypes, IntegerIndex, LazyMulAdd

export Tropical, TropicalF64, TropicalF32
export Tropical, TropicalF64, TropicalF32, TropicalMinPlus, TropicalMinPlusF64, TropicalMinPlusF32, TropicalMaxMul, TropicalMaxMulF64, TropicalMaxMulF32, TropicalMaxPlus, TropicalMaxPlusF64, TropicalMaxPlusF32, BlasSemiringTypes

include("fallbacks.jl")
include("gemm.jl")

import PrecompileTools
# Precompile workload: multiply small tropical matrices in every plain/transposed
# combination for a few element types.
PrecompileTools.@setup_workload begin
    # Keeping some work in `@setup_workload` rather than `@compile_workload` can shrink
    # the precompile file and potentially make loading faster.
    PrecompileTools.@compile_workload begin
        for T in (Float32, Float64, Int64)
            mat = Tropical.(rand(T, 10, 10))
            tmat = transpose(mat)
            # All four (plain | transposed) x (plain | transposed) products.
            for lhs in [mat, tmat], rhs in [mat, tmat]
                lhs * rhs
            end
        end
    end
end
# import PrecompileTools
# PrecompileTools.@setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# PrecompileTools.@compile_workload begin
# for T in (Float32, Float64, Int64)
# A = Tropical.(rand(T, 10, 10))
# TA = transpose(A)
# for x in [A, TA]
# for y in [A, TA]
# x * y
# end
# end
# end
# end
# end

end
10 changes: 6 additions & 4 deletions src/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ end

# For types not natively supported, go to fallback.
# Overwrite the `mul!` in LinearAlgebra (also changes the behavior of `*` in Base)!
function LinearAlgebra.mul!(o::MaybeAdjOrTransMat{TO}, a::MaybeAdjOrTransMat{<:Tropical}, b::MaybeAdjOrTransMat{<:Tropical}, α::Number, β::Number) where TO<:Tropical
α = _convert_to_static(TO, α)
β = _convert_to_static(TO, β)
naive_mul!(o, a, b, α, β)
# Generate one 5-argument `mul!` fallback method per tropical number family.
# Every entry is a `Symbol`, so `$TT` splices the type *name* into each signature
# uniformly (the original list mixed symbols with a bare `TropicalMaxMul` identifier).
for TT in [:Tropical, :TropicalMinPlus, :TropicalMaxMul]
    @eval function LinearAlgebra.mul!(
        o::MaybeAdjOrTransMat{TO},
        a::MaybeAdjOrTransMat{<:$TT},
        b::MaybeAdjOrTransMat{<:$TT},
        α::Number,
        β::Number,
    ) where {TO<:$TT}
        # Normalise both scaling factors against the output element type `TO`
        # before handing off to the generic kernel.
        # NOTE(review): `_convert_to_static` and `naive_mul!` are defined outside this
        # view; their exact semantics should be confirmed there.
        α = _convert_to_static(TO, α)
        β = _convert_to_static(TO, β)
        return naive_mul!(o, a, b, α, β)
    end
end

Base.:*(a::T, b::StaticInt{0}) where T<:TropicalTypes = zero(T)
Expand Down
Loading

0 comments on commit b2e0b32

Please sign in to comment.