add topk #353
base: master
Conversation
```
@@ -177,3 +177,33 @@ for pool in [:maxpool, :meanpool]
        return Ω, $pullback
    end
end
function topk(x::AbstractArray{T,N}, k; rev=false, dims=nothing) where {T,N}
```
It seems a little odd that "top k" gives the smallest values:
```julia
julia> r = [3,11,5,13,7];

julia> topk(r, 2)
([3, 5], CartesianIndex{1}[CartesianIndex(1,), CartesianIndex(3,)])

julia> sort(r)
5-element Vector{Int64}:
  3
  5
  7
 11
 13
```
I understand it's following what `sort` does, but perhaps it needs a better name, or a different default? Is the typical use to select the largest elements?
PyTorch has a `largest` param to control this: https://pytorch.org/docs/stable/generated/torch.topk.html#torch.topk. As you noted, the default is to return the largest k values. I think passing `rev=!rev` to `partialsortperm` would do the trick.
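Concretely, that flip is a one-liner. This is only a sketch with a hypothetical name, not code from the PR:

```julia
# Hypothetical helper: rev=false means "largest first", as in PyTorch;
# the flip happens by passing rev=!rev through to partialsortperm.
topk_largest(x::AbstractVector, k::Integer; rev=false) = x[partialsortperm(x, 1:k; rev=!rev)]

topk_largest([3, 11, 5, 13, 7], 2)  # == [13, 11]
```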
I see, thanks. So they default to `largest=True`, i.e. `rev=false` here.

They allow `sorted=False`, but with `partialsortperm` they will be sorted by default. I presume there are no downsides to getting a sorted result; it's just that they (perhaps) do something cheaper when this isn't required?

One nice thing about following `sort` is that you can pass all of its keywords along, `by=..., lt=...`, in case people want these... and you can just point at Base's documentation. It wouldn't be crazy to do that and just note that `rev=true` is the default here.
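For example (my illustration, not from the thread), since the keywords go straight through to `partialsortperm`, Base's `by` works unchanged:

```julia
julia> x = [-9, 2, -4, 7];

julia> x[partialsortperm(x, 1:2; rev=true, by=abs)]  # two largest in magnitude
2-element Vector{Int64}:
 -9
  7
```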
I think the obvious way to make this AD-able would be to just rely on the gradient for `getindex`, which will store the indices for the backwards pass. Maybe it's as simple as this:
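(The snippet that followed appears to have been lost in extraction. A minimal sketch of what it describes — per-column `partialsortperm`, then plain `getindex` — my reconstruction, not the original code:)

```julia
# Reconstruction sketch, not the original snippet: select the k largest
# entries of each column, relying on getindex's gradient for AD.
function topk(x::AbstractMatrix, k::Integer; rev=true, kw...)
    inds = mapreduce(hcat, eachcol(x)) do col
        partialsortperm(col, 1:k; rev=rev, kw...)    # positions within each column
    end
    cart = CartesianIndex.(inds, reshape(axes(x, 2), 1, :))  # lift to full-array indices
    return x[cart]  # getindex stores `cart` for the backwards pass
end
```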
Unlike the PR this doesn't return the permutation, but do you need it for anything else? It also defaults to the first dimension, and won't accept ...

But to make this GPU-friendly... There should be no aliasing issues with the gradient. There is ...
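For the vector case, the getindex-gradient claim is easy to check (my check, assuming Zygote as the AD):

```julia
using Zygote  # assumption: Zygote, not named in the thread

x = [3.0, 11.0, 5.0, 13.0, 7.0]
inds = partialsortperm(x, 1:2; rev=true)      # indices of the two largest: [4, 2]
val, back = Zygote.pullback(v -> v[inds], x)  # only getindex is differentiated
back([1.0, 1.0])                              # ([0.0, 1.0, 0.0, 1.0, 0.0],)
```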
For GPU compat, TF's XLA implementation should be using all out-of-place ops: https://github.com/tensorflow/tensorflow/blob/8d72537c6abf5a44103b57b9c2e22c14f5f49698/tensorflow/core/tpu/kernels/topk_ops.cc. Of course we can't rely on any of the optimizations in XLA, so a more ideal implementation would probably look like the PyTorch one here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/TensorTopK.cu.
What I hoped might work, re-using what CUDA.jl already has, doesn't seem to -- can this easily be fixed?

```julia
julia> sort(tuple.(cu(rand(10)), 1:10), by=first)
ERROR: InvalidIRError: compiling kernel qsort_kernel(CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
[1] bitonic_median
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:217
[2] qsort_kernel
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:405
[3] qsort_kernel
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Reason: unsupported dynamic function invocation (call to zero)
Stacktrace:
[1] bubble_sort
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:274
[2] qsort_kernel
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:401
[3] qsort_kernel
@ ~/.julia/packages/CUDA/9T5Sq/src/sorting.jl:378
Stacktrace:
[1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(CUDA.Quicksort.qsort_kernel), Tuple{CuDeviceVector{Tuple{Float32, Int64}, 1}, Int64, Int64, Bool, Val{true}, Int64, Nothing, typeof(isless), typeof(first), Val{1}}}}, args::LLVM.Module)
@ GPUCompiler ~/.julia/packages/GPUCompiler/fG3xK/src/validation.jl:111
[2] macro expansion
   @ ~/.julia/packages/GPUCompiler/fG3xK/src/driver.jl:319 [inlined]
```
OK, this seems to work:

```julia
function topkperm(x::CuArray, k::Integer; dims::Integer=1, rev=true, lt=isless, by=identity)
tups = tuple.(x, reshape(axes(x,dims), fill(1, dims-1)..., :))
CUDA.quicksort!(tups; lt=(rev ? !lt : lt), by=by∘first, dims=dims, partial_k=1:k)
tv = view(tups, ntuple(d -> d==dims ? (1:k) : (:), ndims(x))...)
broadcast(tv, CartesianIndices(ntuple(d -> d==dims ? Base.OneTo(1) : axes(x,d), ndims(x)))) do (_,i), J
CartesianIndex(ntuple(d -> d==dims ? i : J[d], ndims(x)))
end
end
# piracy to make e.g. this work: sort(tuple.(cu(rand(10)), 1:10), by=first)
@inline Base.zero(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> zero(T.parameters[i]), N)
@inline Base.one(::Type{T}) where {T<:Tuple{Vararg{Any,N}}} where {N} = ntuple(i -> one(T.parameters[i]), N)
```
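Usage would then look like this (my example, assuming a working CUDA device):

```julia
using CUDA

x = cu(rand(3, 5))
p = topkperm(x, 2)  # 2×5 CuArray of CartesianIndex, largest-first per column
x[p]                # the two largest entries of each column
```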
My CPU version above is not correct, since it uses indices along one slice as linear indices in the whole array. It would need to also make CartesianIndices everywhere. After which it's not just one line... is there a tidier way?
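One way to make CartesianIndices everywhere (my sketch, reusing the `topkperm` name from above; not code from the PR):

```julia
# Sketch: a CPU topkperm returning CartesianIndex into the full array,
# fixing the slice-index vs. linear-index mixup described above.
function topkperm(x::AbstractArray, k::Integer; dims::Integer=1, rev=true)
    mapslices(CartesianIndices(x); dims=dims) do ci
        # `ci` holds full-array CartesianIndex for one slice along `dims`;
        # keep the k whose values in x sort first (largest, when rev=true)
        partialsort(vec(ci), 1:k; rev=rev, by=i -> x[i])
    end
end

x = rand(3, 5)
p = topkperm(x, 2)  # 2×5 Matrix of CartesianIndex
x[p]                # the two largest entries of each column
```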
@mcabbott thanks for all these suggestions, feel free to push on this branch. I like the ...

Representation of the (partial) permutation as an array of CartesianIndex is redundant, since it would be enough to ...
We could return a custom type for ...
Yes, I don't like the redundancy of returning CartesianIndices, but I do like the simplicity of getindex. Maybe returning an array of linear indices instead would be nicer?
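With the `topkperm` sketch above, collapsing to linear indices is one line (my example):

```julia
x = rand(3, 5)
cart = topkperm(x, 2)          # array of CartesianIndex from the sketch above
lin  = LinearIndices(x)[cart]  # one Int per entry, no redundancy
@assert x[lin] == x[cart]      # same selection; getindex (and its gradient) still work
```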
Initial steps to fix #352
TODO