Commit 6807aaa
WIP: add dask tests

CJ-Wright committed Dec 18, 2018
1 parent 7869276
Showing 3 changed files with 114 additions and 4 deletions.
2 changes: 1 addition & 1 deletion rapidz/core.py
@@ -1182,12 +1182,12 @@ class unique(Stream):

    def __init__(self, upstream, history=None, key=identity, **kwargs):
        self.seen = dict()
        self.non_hash_seen = []
        self.key = key
        if history:
            from zict import LRU

            self.seen = LRU(history, self.seen)
            # TODO: pull this out from history
            self.non_hash_seen = deque(maxlen=history)

        Stream.__init__(self, upstream, **kwargs)
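
For context, a minimal sketch of how the bounded history behaves on a plain stream, assuming the streamz-style registered unique API shown above (history caps how many seen keys the zict.LRU retains):

from rapidz import Stream

source = Stream()
L = source.unique(history=2).sink_to_list()

for x in [1, 2, 1, 3, 1]:
    source.emit(x)

# The repeat of 1 right after [1, 2] is suppressed, but once 3 evicts 1
# from the two-entry LRU, the final 1 may be emitted again.
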
12 changes: 9 additions & 3 deletions rapidz/parallel.py
@@ -3,9 +3,9 @@
from builtins import zip as szip

from rapidz import apply
-from rapidz.core import _truthy, args_kwargs, move_to_first
+from rapidz.core import _truthy, args_kwargs, move_to_first, identity
from rapidz.core import get_io_loop
-from rapidz.clients import DEFAULT_BACKENDS, FILL_COLOR_LOOKUP
+from rapidz.clients import DEFAULT_BACKENDS, FILL_COLOR_LOOKUP, result_maybe
from operator import getitem

from tornado import gen
@@ -447,7 +447,8 @@ def update(self, x, who=None):
        if len(L) == 1 and all(self.buffers.values()):
            client = self.default_client()

-            tup = tuple(client.submit(future_chain, self.buffers[up][0], self.future_buffers[up]) for up in self.upstreams)
+            tup = tuple(client.submit(future_chain, self.buffers[up][0],
+                                      self.future_buffers[up]) for up in self.upstreams)
            for buf in self.buffers.values():
                buf.popleft()
            for t, up in szip(tup, self.upstreams):
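
The rewrapped call above relies on a useful dask property: futures passed as arguments, including futures inside containers like self.future_buffers[up], are resolved before the submitted function runs, which orders each new result after everything already in flight. A hedged sketch of a helper in the spirit of future_chain (rapidz's actual implementation may differ):

def future_chain(value, prior):
    # By the time this executes on a worker, dask has already resolved
    # every future in ``prior``, so the future returned for ``value`` is
    # ordered after all previously issued ones.
    return value
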
@@ -480,6 +481,11 @@ def is_unique(x, past):
        return NULL_COMPUTE
    return x


+# TODO: this doesn't work on dask (and might never work)
+# The main issue here is that we need a central storage place for all the
+# history, but dask might not be so friendly to this idea.
+# This (or something very like it) works when run by hand in IPython against dask.
@args_kwargs
@ParallelStream.register_api()
class unique(ParallelStream):
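
One reading of the TODO above: the history would have to live in a single central place that every worker consults. A loose sketch of that idea using distributed's Variable as the central store; this is illustrative only, not rapidz code, and the non-atomic get/set pair shows why dask is "not so friendly" here:

from distributed import Variable, get_client

NULL_COMPUTE = "~NULL~"  # stand-in for rapidz's sentinel

def is_unique_central(x, name="unique-history"):
    # Runs inside a task: read the shared history, test membership,
    # then write the updated history back.  The get/set pair is not
    # atomic, so two workers racing on the same value can both pass.
    var = Variable(name, client=get_client())
    seen = var.get(timeout=5)
    if x in seen:
        return NULL_COMPUTE
    var.set(seen + [x])
    return x

# Client-side setup before any tasks run:
#     Variable("unique-history").set([])
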
104 changes: 104 additions & 0 deletions rapidz/tests/test_parallel_filter_dask.py
@@ -0,0 +1,104 @@
from distributed.utils_test import (
    gen_cluster,
)  # flake8: noqa
from rapidz import Stream
from rapidz.parallel import scatter, NULL_COMPUTE


@gen_cluster(client=True)
def test_filter_combine_latest(c, s, a, b):
    source = Stream(asynchronous=True)

    # Rebinding ``s`` shadows the scheduler fixture, which these tests
    # never touch after this point.
    s = scatter(source)
    futures = s.filter(lambda x: x % 2 == 0).combine_latest(s)
    L = futures.gather().sink_to_list()

    presents = source.filter(lambda x: x % 2 == 0).combine_latest(source)

    LL = presents.sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == LL


@gen_cluster(client=True)
def test_filter_combine_latest_odd(c, s, a, b):
    source = Stream(asynchronous=True)

    s = scatter(source)
    futures = s.filter(lambda x: x % 2 == 1).combine_latest(s)
    L = futures.gather().sink_to_list()

    presents = source.filter(lambda x: x % 2 == 1).combine_latest(source)

    LL = presents.sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == LL


@gen_cluster(client=True)
def test_filter_combine_latest_emit_on(c, s, a, b):
    source = Stream(asynchronous=True)

    s = scatter(source)
    futures = s.filter(lambda x: x % 2 == 1).combine_latest(s, emit_on=0)
    L = futures.gather().sink_to_list()

    presents = source.filter(lambda x: x % 2 == 1).combine_latest(source,
                                                                  emit_on=0)

    LL = presents.sink_to_list()

    for i in range(5):
        yield source.emit(i)

    assert L == LL


@gen_cluster(client=True)
def test_filter_combine_latest_triple(c, s, a, b):
    source = Stream(asynchronous=True)

    s = scatter(source)
    futures = s.filter(lambda x: x % 3 == 1).combine_latest(s)
    L = futures.gather().sink_to_list()

    presents = source.filter(lambda x: x % 3 == 1).combine_latest(source)

    LL = presents.sink_to_list()

    for i in range(10):
        yield source.emit(i)

    assert L == LL


@gen_cluster(client=True)
def test_unique(c, s, a, b):
    source = Stream(asynchronous=True)

    def acc_func(state, x):
        # ``state`` carries the history of seen values; repeats are
        # replaced with the NULL_COMPUTE sentinel so they can be dropped.
        if x in state:
            return state, NULL_COMPUTE
        state.append(x)
        return state, x

    s = scatter(source)
    futures = s.accumulate(acc_func, start=[], returns_state=True)
    L = futures.gather().sink_to_list()

    presents = source.accumulate(acc_func, start=[], returns_state=True)

    LL = presents.sink_to_list()

    # Emit each even number twice: 0, 0, 2, 2, ..., 8, 8.
    for i in range(10):
        if i % 2 == 1:
            i = i - 1
        yield source.emit(i)

    assert L == LL
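
test_unique sidesteps the central-storage problem by threading the seen history through accumulate's state instead of storing it on the node, which keeps the scattered and local pipelines symmetric. A minimal local-only sketch of the same pattern, assuming just the Stream API used above:

from rapidz import Stream
from rapidz.parallel import NULL_COMPUTE

def acc_func(state, x):
    if x in state:
        return state, NULL_COMPUTE
    state.append(x)
    return state, x

source = Stream()
out = source.accumulate(acc_func, start=[], returns_state=True).sink_to_list()

for x in [0, 0, 2, 2, 4]:
    source.emit(x)

# First occurrences flow through; repeats surface as the NULL_COMPUTE
# sentinel, which the comparison in the test above treats the same way
# on both the local and the gathered side.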
