Skip to content

Commit a3e489b

Browse files
alnotranslunar
authored andcommitted
Preserve matrix dtype on NMatrix#repeat call (#515)
1 parent be10094 commit a3e489b

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

lib/nmatrix/nmatrix.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,7 @@ def repeat(count, axis)
10251025
raise(ArgumentError, 'Matrix should be repeated at least 2 times.') if count < 2
10261026
new_shape = shape
10271027
new_shape[axis] *= count
1028-
new_matrix = NMatrix.new(new_shape)
1028+
new_matrix = NMatrix.new(new_shape, dtype: dtype)
10291029
slice = new_shape.map { |axis_size| 0...axis_size }
10301030
start = 0
10311031
count.times do

spec/00_nmatrix_spec.rb

+5
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,11 @@
726726
expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
727727
expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
728728
end
729+
730+
it "preserves dtype" do
731+
expect(@sample_matrix.repeat(2, 0).dtype).to eq(@sample_matrix.dtype)
732+
expect(@sample_matrix.repeat(2, 1).dtype).to eq(@sample_matrix.dtype)
733+
end
729734
end
730735

731736
context "#meshgrid" do

0 commit comments

Comments
 (0)