Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ext/TensorKitCUDAExt/cutensormap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
fill!(t, $felt(T))
return t
end
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA <: CuArray}
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
fill!(t, $felt(eltype(TorA)))
return t
end
Comment thread
kshyatt marked this conversation as resolved.
Outdated
end
end

Expand Down
2 changes: 1 addition & 1 deletion ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function ChainRulesCore.rrule(
end

function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A))
tr_pullback(Δtr) = NoTangent(), Δtr * id(storagetype(A), domain(A))
Comment thread
kshyatt marked this conversation as resolved.
Outdated
return tr(A), tr_pullback
end

Expand Down
17 changes: 11 additions & 6 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,19 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
return Base.$fname(codomain ← domain)
end
function Base.$fname(
::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
) where {T, S <: IndexSpace}
return Base.$fname(T, codomain ← domain)
::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
) where {TorA, S <: IndexSpace}
return Base.$fname(TorA, codomain ← domain)
end
function Base.$fname(
::Type{T}, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
Comment thread
kshyatt marked this conversation as resolved.
Outdated
) where {T, TorA, S <: IndexSpace}
return Base.$fname(TorA, codomain ← domain)
end
Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V)
function Base.$fname(::Type{T}, V::TensorMapSpace) where {T}
t = TensorMap{T}(undef, V)
fill!(t, $felt(T))
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}
t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V)
fill!(t, $felt(TorA))
Comment thread
kshyatt marked this conversation as resolved.
Outdated
return t
end
end
Expand Down
Loading