Skip to content

Commit

Permalink
Add built-in support to write compound data types (#1013)
Browse files Browse the repository at this point in the history
* Add write_compound function and tests

* Remove HDF5 namespace

* Fix code formatting

* Add more type annotations

* Julia formatter

* Rename write_compound

* Rename write_compound to write_compound_dataset

Co-authored-by: Mark Kittisopikul <[email protected]>

* Implement proper datatype discovery for compound

* Dispatch on struct-type itself

* Readd AbstractArray to datatype dispatch

* JuliaFormatter

* Cleanup

* JuliaFormatter

* Update src/typeconversions.jl

Co-authored-by: Mark Kittisopikul <[email protected]>

* Remove remark

* JuliaFormatter

* Cleanup

* Update src/HDF5.jl

Co-authored-by: Mark Kittisopikul <[email protected]>

* Clean up

* Add tests for mutable structs

* Convert non-bitstypes to a NamedTuple when writing datasets

* Make attributes compound type compatible

* Garbage collect compound types

* Formatting

* Apply suggestions from code review

Co-authored-by: Mustafa M <[email protected]>

Co-authored-by: Mark Kittisopikul <[email protected]>
Co-authored-by: Mark Kittisopikul <[email protected]>
Co-authored-by: Mustafa M <[email protected]>
  • Loading branch information
4 people authored Dec 9, 2022
1 parent e24fade commit 205518d
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 6 deletions.
68 changes: 65 additions & 3 deletions src/attributes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,77 @@ function create_attribute(
end

# generic method
# Write `x` to `attr` using the HDF5 memory type `memtype`.
#
# Bitstype values are handed to `API.h5a_write` unchanged. Any other value is
# first converted to the Julia type returned by
# `get_mem_compatible_jl_type(memtype)` and then written through a `Ref`.
#
# Throws an `ArgumentError` when a `MethodError` is raised while converting or
# writing a non-bitstype value; every other exception is rethrown unchanged.
#
# This is the only definition for this signature: a second one-line definition
# with the same signature would merely be overwritten by this method.
function write_attribute(attr::Attribute, memtype::Datatype, x::T) where {T}
    isbitstype(T) && return API.h5a_write(attr, memtype, x)

    jl_type = get_mem_compatible_jl_type(memtype)
    try
        x_mem = convert(jl_type, x)
        return API.h5a_write(attr, memtype, Ref(x_mem))
    catch err
        # Only a missing method is reported as a conversion problem.
        err isa MethodError || rethrow()
        throw(
            ArgumentError(
                "Could not convert non-bitstype $T to $jl_type for writing to HDF5. Consider implementing `convert(::Type{$jl_type}, ::$T)`"
            ),
        )
    end
end
# Write the value referenced by `x` to `attr` using the HDF5 memory type
# `memtype`.
#
# A reference to a bitstype is handed to `API.h5a_write` unchanged. Otherwise
# the referenced value `x[]` is converted into a `Ref` of the Julia type
# returned by `get_mem_compatible_jl_type(memtype)` before writing.
#
# Throws an `ArgumentError` when a `MethodError` is raised while converting or
# writing a non-bitstype value; every other exception is rethrown unchanged.
function write_attribute(attr::Attribute, memtype::Datatype, x::Ref{T}) where {T}
    isbitstype(T) && return API.h5a_write(attr, memtype, x)

    jl_type = get_mem_compatible_jl_type(memtype)
    try
        converted_ref = convert(Ref{jl_type}, x[])
        return API.h5a_write(attr, memtype, converted_ref)
    catch err
        # Only a missing method is reported as a conversion problem.
        err isa MethodError || rethrow()
        throw(
            ArgumentError(
                "Could not convert non-bitstype $T to $jl_type for writing to HDF5. Consider implementing `convert(::Type{$jl_type}, ::$T)`"
            ),
        )
    end
end

# specific methods
# `VLen` payloads are handed to `API.h5a_write` unchanged; no bitstype check or
# conversion is applied to them.
function write_attribute(attr::Attribute, memtype::Datatype, x::VLen)
    return API.h5a_write(attr, memtype, x)
end
# Write the array `x` to `attr` using the HDF5 memory type `memtype`.
#
# The element count of `x` must equal `length(attr)`; otherwise an
# `ArgumentError` is thrown before anything is written.
#
# Arrays of bitstype elements are handed to `API.h5a_write` unchanged. Arrays
# of any other element type are first converted to an `Array` of the Julia type
# returned by `get_mem_compatible_jl_type(memtype)`. The data is written exactly
# once, after the element type has been inspected.
#
# Throws an `ArgumentError` when a `MethodError` is raised while converting or
# writing non-bitstype elements; every other exception is rethrown unchanged.
function write_attribute(attr::Attribute, memtype::Datatype, x::AbstractArray{T}) where {T}
    if length(x) != length(attr)
        throw(
            ArgumentError(
                "Invalid length: $(length(x)) != $(length(attr)), for attribute \"$(name(attr))\""
            ),
        )
    end

    isbitstype(T) && return API.h5a_write(attr, memtype, x)

    jl_type = get_mem_compatible_jl_type(memtype)
    try
        x_mem = convert(Array{jl_type}, x)
        return API.h5a_write(attr, memtype, x_mem)
    catch err
        # Only a missing method is reported as a conversion problem.
        err isa MethodError || rethrow()
        throw(
            ArgumentError(
                "Could not convert non-bitstype $T to $jl_type for writing to HDF5. Consider implementing `convert(::Type{$jl_type}, ::$T)`"
            ),
        )
    end
end
function write_attribute(attr::Attribute, memtype::Datatype, str::AbstractString)
strbuf = Base.cconvert(Cstring, str)
Expand Down
26 changes: 23 additions & 3 deletions src/datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,31 @@ end
# Write the array `buf` to `dataset` using the HDF5 memory type `memtype` and
# the transfer properties `xfer` (defaulting to `dataset.xfer`).
#
# `_check_invalid(dataset, buf)` runs first. Arrays of bitstype elements are
# then handed to `API.h5d_write` unchanged. Arrays of any other element type
# are converted to an `Array` of the Julia type returned by
# `get_mem_compatible_jl_type(memtype)` before writing. The data is written
# exactly once, with `API.H5S_ALL` for both dataspace arguments.
#
# Throws an `ArgumentError` when a `MethodError` is raised while converting or
# writing non-bitstype elements; every other exception is rethrown unchanged.
function write_dataset(
    dataset::Dataset,
    memtype::Datatype,
    buf::AbstractArray{T},
    xfer::DatasetTransferProperties=dataset.xfer
) where {T}
    _check_invalid(dataset, buf)
    if isbitstype(T)
        return API.h5d_write(dataset, memtype, API.H5S_ALL, API.H5S_ALL, xfer, buf)
    end

    # For non-bitstypes, we need to convert the buffer to a bitstype.
    # For mutable structs, this will usually be a NamedTuple.
    jl_type = get_mem_compatible_jl_type(memtype)
    try
        memtype_buf = convert(Array{jl_type}, buf)
        return API.h5d_write(
            dataset, memtype, API.H5S_ALL, API.H5S_ALL, xfer, memtype_buf
        )
    catch err
        # Only a missing method is reported as a conversion problem.
        err isa MethodError || rethrow()
        throw(
            ArgumentError(
                "Could not convert non-bitstype $T to $jl_type for writing to HDF5. Consider implementing `convert(::Type{$jl_type}, ::$T)`"
            ),
        )
    end
end
function write_dataset(
dataset::Dataset,
Expand Down
7 changes: 7 additions & 0 deletions src/dataspaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ dataspace(ds::Dataspace) = ds
The default `Dataspace` used for representing a Julia object `data`:
- strings or numbers: a scalar `Dataspace`
- arrays: a simple `Dataspace`
- `struct` types: a scalar `Dataspace`
- `nothing` or an `EmptyArray`: a null dataspace
"""
function dataspace(x::T) where {T}
    # Catch-all method: a value whose type is a struct type gets a scalar
    # dataspace.
    isstructtype(T) && return Dataspace(API.h5s_create(API.H5S_SCALAR))
    # Anything else is unsupported. `MethodError` takes the call arguments as
    # a tuple, so the single argument is wrapped in `(x,)`.
    throw(MethodError(dataspace, (x,)))
end
# Values of a `ScalarType`, and `Complex` numbers built on one, are described
# by a scalar dataspace.
function dataspace(x::Union{T,Complex{T}}) where {T<:ScalarType}
    space_id = API.h5s_create(API.H5S_SCALAR)
    return Dataspace(space_id)
end
# Any string is described by a scalar dataspace; its content is not inspected.
function dataspace(::AbstractString)
    space_id = API.h5s_create(API.H5S_SCALAR)
    return Dataspace(space_id)
end
Expand Down
19 changes: 19 additions & 0 deletions src/typeconversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ function datatype(str::VLen{C}) where {C<:CharType}
Datatype(API.h5t_vlen_create(type_id))
end

# Compound types

# These will use finalizers. Close them eagerly to avoid issues.
# Fallback for arbitrary values: only the value's type `T` is used, and the
# HDF5 type id is obtained from `hdf5_type_id(T)`.
# NOTE(review): the `true` flag presumably asks `Datatype` to close the id via
# a finalizer; confirm against the `Datatype` constructor.
function datatype(::T) where {T}
    type_id = hdf5_type_id(T)
    return Datatype(type_id, true)
end
# Fallback for arrays: the datatype describes the element type `T`, not the
# array itself; the HDF5 type id is obtained from `hdf5_type_id(T)`.
# NOTE(review): the `true` flag presumably asks `Datatype` to close the id via
# a finalizer; confirm against the `Datatype` constructor.
function datatype(x::AbstractArray{T}) where {T}
    elem_type_id = hdf5_type_id(T)
    return Datatype(elem_type_id, true)
end

# Entry point for types without a more specific `hdf5_type_id` method: lift
# `isstructtype(T)` into a `Val` and dispatch on it.
function hdf5_type_id(::Type{T}) where {T}
    struct_flag = Val(isstructtype(T))
    return hdf5_type_id(T, struct_flag)
end
# Build an HDF5 compound type for the struct type `T`.
#
# The compound type is created with size `sizeof(T)`. Each field of `T` is then
# inserted under its field name, at its `fieldoffset`, with a member type
# obtained by calling `hdf5_type_id` on the field's type.
# Returns the value produced by `API.h5t_create`.
function hdf5_type_id(::Type{T}, isstruct::Val{true}) where {T}
    compound_id = API.h5t_create(API.H5T_COMPOUND, sizeof(T))
    for field_idx in 1:fieldcount(T)
        member_name = fieldname(T, field_idx)
        member_type_id = hdf5_type_id(fieldtype(T, field_idx))
        API.h5t_insert(compound_id, member_name, fieldoffset(T, field_idx), member_type_id)
    end
    return compound_id
end
# Non-struct types that reach this method have no HDF5 mapping, so a
# `MethodError` naming `hdf5_type_id` and the arguments is thrown.
# Perhaps we need a custom error type here
function hdf5_type_id(::Type{T}, isstruct::Val{false}) where {T}
    call_args = (T, isstruct)
    throw(MethodError(hdf5_type_id, call_args))
end

# Opaque types
struct Opaque
data
Expand Down
82 changes: 82 additions & 0 deletions test/compound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,85 @@ end
@test HDF5.do_reclaim(TTTT) == false
@test HDF5.do_normalize(TTTT) == true
end

# Immutable test fixture with three fields of different types; used as the
# element type of the `bars` matrix in the "write_compound" testset.
struct Bar
a::Int32
b::Float64
c::Bool
end

# Mutable single-field test fixture. The "write_compound" testset expects
# writing it to throw an `ArgumentError` until a `convert` method to a
# `NamedTuple` has been defined for it.
mutable struct MutableBar
x::Int64
end

# Round-trip tests for writing compound data: immutable structs, NamedTuples,
# and mutable structs, both as datasets and as attributes.
@testset "write_compound" begin
    bars = [
        [Bar(1, 1.1, true) Bar(2, 2.1, false) Bar(3, 3.1, true)]
        [Bar(4, 4.1, false) Bar(5, 5.1, true) Bar(6, 6.1, false)]
    ]

    namedtuples = [(a=1, b=2.3), (a=4, b=5.6)]

    fn = tempname()
    h5open(fn, "w") do h5f
        write_dataset(h5f, "the/bars", bars)
        write_dataset(h5f, "the/namedtuples", namedtuples)
    end

    thebars = h5open(fn, "r") do h5f
        read(h5f, "the/bars")
    end

    # Floating-point fields are compared with `≈`, never with `==`.
    @test (2, 3) == size(thebars)
    @test thebars[1, 2].b ≈ 2.1
    @test thebars[2, 1].a == 4
    @test thebars[1, 3].c

    # The data read back can be reinterpreted as the original struct type.
    thebars_r = reinterpret(Bar, thebars)
    @test (2, 3) == size(thebars_r)
    @test thebars_r[1, 2].b ≈ 2.1
    @test thebars_r[2, 1].a == 4
    @test thebars_r[1, 3].c

    thenamedtuples = h5open(fn, "r") do h5f
        read(h5f, "the/namedtuples")
    end

    @test (2,) == size(thenamedtuples)
    @test thenamedtuples[1].a == 1
    @test thenamedtuples[1].b ≈ 2.3
    @test thenamedtuples[2].a == 4
    @test thenamedtuples[2].b ≈ 5.6

    mutable_bars = [MutableBar(1), MutableBar(2), MutableBar(3)]

    # Without a `convert` method, writing mutable structs must fail with an
    # `ArgumentError`.
    fn = tempname()
    @test_throws ArgumentError begin
        h5open(fn, "w") do h5f
            write_dataset(h5f, "the/mutable_bars", mutable_bars)
        end
    end

    # Supply the conversions, after which the same writes are expected to work.
    Base.convert(::Type{NamedTuple{(:x,),Tuple{Int64}}}, mb::MutableBar) = (x=mb.x,)
    Base.unsafe_convert(::Type{Ptr{Nothing}}, mb::MutableBar) = pointer_from_objref(mb)

    h5open(fn, "w") do h5f
        write_dataset(h5f, "the/mutable_bars", mutable_bars)
        write_dataset(h5f, "the/mutable_bar", first(mutable_bars))
    end

    h5open(fn, "r") do h5f
        @test [1, 2, 3] == [b.x for b in read(h5f, "the/mutable_bars")]
        @test 1 == read(h5f, "the/mutable_bar").x
    end

    # The same values must also round-trip through attributes.
    h5open(fn, "w") do h5f
        write_attribute(h5f, "mutable_bars", mutable_bars)
        write_attribute(h5f, "mutable_bar", first(mutable_bars))
    end

    h5open(fn, "r") do h5f
        @test [1, 2, 3] == [b.x for b in attrs(h5f)["mutable_bars"]]
        @test 1 == attrs(h5f)["mutable_bar"].x
    end
end

0 comments on commit 205518d

Please sign in to comment.