Skip to content

Commit 259acd9

Browse files
authored
Merge pull request #4491 from vicentebolea/python-stream-numpy
python: stream.write accept np array w/o extra args
2 parents f6645c5 + 8ad0af9 commit 259acd9

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

python/adios2/stream.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""License:
2-
Distributed under the OSI-approved Apache License, Version 2.0. See
3-
accompanying file Copyright.txt for details.
2+
Distributed under the OSI-approved Apache License, Version 2.0. See
3+
accompanying file Copyright.txt for details.
44
"""
55

66
from functools import singledispatchmethod
@@ -296,14 +296,17 @@ def _(self, name, content, shape=[], start=[], count=[], operations=None):
296296

297297
if not variable:
298298
# Sequence variables
299-
if isinstance(content, np.ndarray):
300-
variable = self._io.define_variable(name, content, shape, start, count)
301-
elif isinstance(content, list):
302-
if shape == [] and count == []:
303-
shape = [len(content)]
304-
count = shape
305-
start = [0]
299+
if isinstance(content, (list, np.ndarray)):
300+
if isinstance(content, list):
301+
content = np.asarray(content)
302+
303+
# If shape, start, and count is not specified, use the numpy array's shape
304+
if shape == [] and start == [] and count == []:
305+
shape = list(content.shape)
306+
start = [0] * content.ndim
307+
count = shape[:]
306308
variable = self._io.define_variable(name, content, shape, start, count)
309+
307310
# Scalar variables
308311
elif isinstance(content, str):
309312
variable = self._io.define_variable(name, content)

testing/adios2/python/TestStream.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from adios2 import Stream, LocalValueDim
22
from random import randint
3+
import numpy as np
34

45
import unittest
56

@@ -22,14 +23,16 @@ def test_basic(self):
2223
s.write("Wind", [5], shape=[LocalValueDim])
2324
# Local Array
2425
s.write("Coords", [38, -46], [], [], [2])
26+
s.write("humidity", np.random.rand(3, 1))
2527

2628
with Stream("pythonstreamtest.bp", "r") as s:
2729
for _ in s.steps():
2830
for var_name in s.available_variables():
2931
print(f"var:{var_name}\t{s.read(var_name)}")
30-
self.assertEqual(s.read("Wind", block_id=0), 5)
31-
self.assertEqual(s.read("Coords", block_id=0)[0], 38)
32-
self.assertEqual(s.read("Coords", block_id=0)[1], -46)
32+
self.assertEqual(s.read("Wind", block_id=0), 5)
33+
self.assertEqual(s.read("Coords", block_id=0)[0], 38)
34+
self.assertEqual(s.read("Coords", block_id=0)[1], -46)
35+
self.assertEqual(s.read("humidity", block_id=0).ndim, 2)
3336

3437
def test_start_count(self):
3538
with Stream("pythonstreamtest.bp", "w") as s:

0 commit comments

Comments
 (0)