Skip to content

Commit 1c9ae1b

Browse files
committed
fix(ir): handle the case of non-overlapping data and add a test
1 parent 4af9bd7 commit 1c9ae1b

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

ibis/backends/tests/test_join.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,10 @@ def test_join_with_trivial_predicate(awards_players, predicate, how, pandas_valu
328328
assert len(result) == len(expected)
329329

330330

331+
@pytest.mark.notimpl(
332+
["druid"], raises=sa.exc.NoSuchTableError, reason="`win` table isn't loaded"
333+
)
334+
@pytest.mark.notimpl(["flink"], reason="`win` table isn't loaded")
331335
@pytest.mark.parametrize(
332336
("how", "nrows"),
333337
[
@@ -349,17 +353,30 @@ def test_join_with_trivial_predicate(awards_players, predicate, how, pandas_valu
349353
),
350354
],
351355
)
352-
@pytest.mark.notimpl(
353-
["druid"], raises=sa.exc.NoSuchTableError, reason="`win` table isn't loaded"
356+
@pytest.mark.parametrize(
357+
("gen_right", "keys"),
358+
[
359+
param(
360+
lambda left: left.filter(lambda t: t.x == 1).select(y=lambda t: t.x),
361+
[("x", "y")],
362+
id="non_overlapping",
363+
marks=[pytest.mark.notyet(["polars"], reason="renaming fails")],
364+
),
365+
param(
366+
lambda left: left.filter(lambda t: t.x == 1),
367+
"x",
368+
id="overlapping",
369+
marks=[pytest.mark.notimpl(["pyspark"], reason="overlapping columns")],
370+
),
371+
],
354372
)
355-
@pytest.mark.notimpl(["flink"], reason="`win` table isn't loaded")
356-
def test_outer_join_nullability(backend, how, nrows):
373+
def test_outer_join_nullability(backend, how, nrows, gen_right, keys):
357374
win = backend.win
358375
left = win.select(x=lambda t: t.x.cast(t.x.type().copy(nullable=False))).filter(
359376
lambda t: t.x.isin((1, 2))
360377
)
361-
right = left.filter(lambda t: t.x == 1)
362-
expr = left.join(right, "x", how=how)
378+
right = gen_right(left)
379+
expr = left.join(right, keys, how=how)
363380
assert all(typ.nullable for typ in expr.schema().types)
364381

365382
result = expr.to_pyarrow()

ibis/expr/operations/relations.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,29 @@ class InnerJoin(Join):
293293

294294
@public
295295
class LeftJoin(Join):
296-
pass
296+
@property
297+
def schema(self) -> Schema:
298+
return Schema(
299+
{name: typ.copy(nullable=True) for name, typ in super().schema.items()}
300+
)
297301

298302

299303
@public
300304
class RightJoin(Join):
301-
pass
305+
@property
306+
def schema(self) -> Schema:
307+
return Schema(
308+
{name: typ.copy(nullable=True) for name, typ in super().schema.items()}
309+
)
302310

303311

304312
@public
305313
class OuterJoin(Join):
306-
pass
314+
@property
315+
def schema(self) -> Schema:
316+
return Schema(
317+
{name: typ.copy(nullable=True) for name, typ in super().schema.items()}
318+
)
307319

308320

309321
@public
@@ -313,7 +325,11 @@ class AnyInnerJoin(Join):
313325

314326
@public
315327
class AnyLeftJoin(Join):
316-
pass
328+
@property
329+
def schema(self) -> Schema:
330+
return Schema(
331+
{name: typ.copy(nullable=True) for name, typ in super().schema.items()}
332+
)
317333

318334

319335
@public

0 commit comments

Comments
 (0)