Skip to content

Commit 1b610a6

Browse files
committed
feat: connect: collect
1 parent 4f1210c commit 1b610a6

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/connect/test_collect.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
3+
import time
4+
5+
import pytest
6+
from pyspark.sql import SparkSession
7+
8+
9+
@pytest.fixture
10+
def spark_session():
11+
"""Fixture to create and clean up a Spark session."""
12+
from daft.daft import connect_start
13+
14+
# Start Daft Connect server
15+
server = connect_start("sc://localhost:50051")
16+
17+
# Initialize Spark Connect session
18+
session = SparkSession.builder.appName("DaftConfigTest").remote("sc://localhost:50051").getOrCreate()
19+
20+
yield session
21+
22+
# Cleanup
23+
server.shutdown()
24+
session.stop()
25+
time.sleep(2) # Allow time for session cleanup
26+
27+
28+
def test_range_collect(spark_session):
29+
# Create a range using Spark
30+
# For example, creating a range from 0 to 9
31+
spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9
32+
33+
# Collect the data
34+
collected_rows = spark_range.collect()
35+
36+
# Verify the collected data has expected values
37+
assert len(collected_rows) == 10, "Should have 10 rows"
38+
assert [row["id"] for row in collected_rows] == list(range(10)), "Should contain values 0-9"

0 commit comments

Comments
 (0)