Merged
129 changes: 84 additions & 45 deletions docs/design/module/dit_module.md
@@ -127,80 +127,117 @@ def step(self, requests: list[OmniDiffusionRequest]):

## 2. Scheduler

**Location**: `vllm_omni/diffusion/scheduler.py`
**Location**: `vllm_omni/diffusion/sched/`

### Architecture

The `Scheduler` is implemented as a **Singleton** pattern to ensure a single coordination point across the system, i.e., only one scheduler instance exists for coordination.
The scheduler is a **request-state scheduler**. It owns request lifecycle management and scheduling decisions, while execution stays in `DiffusionEngine` and the executor.

### Key Components

#### 2.1 Message Queue System
#### 2.1 Scheduler Interface

```python
class Scheduler:
    def initialize(self, od_config: OmniDiffusionConfig):
        # Broadcast queue: scheduler -> all workers
        self.mq = MessageQueue(
            n_reader=od_config.num_gpus,
            n_local_reader=od_config.num_gpus,
            local_reader_ranks=list(range(od_config.num_gpus)),
        )

        # Result queue: rank 0 worker -> scheduler
        self.result_mq = None # Initialized later
class SchedulerInterface(ABC):
    def add_request(self, request: OmniDiffusionRequest) -> str: ...
    def schedule(self) -> DiffusionSchedulerOutput: ...
    def update_from_output(
        self,
        sched_output: DiffusionSchedulerOutput,
        output: DiffusionOutput,
    ) -> set[str]: ...
```

**Communication Pattern**:
**Responsibilities**:

- **Broadcast Queue**: One-to-many communication (scheduler → all workers)
- **Lifecycle contract**: Defines how the engine adds requests, triggers one scheduling cycle, and feeds executor results back.

- **Result Queue**: One-to-one communication (rank 0 → scheduler)
- **Stable boundary**: `DiffusionSchedulerOutput` is the only scheduling result consumed by `DiffusionEngine`.

- **Shared Memory**: Uses `MessageQueue` (ZMQ-based) for efficient IPC
- **Pluggability**: Different scheduler policies can reuse the same engine integration path.

#### 2.2 Request Distribution
#### 2.2 Request State Model

```python
def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
    # Broadcast request to all workers
    self.mq.enqueue(requests)
class DiffusionRequestStatus(enum.IntEnum):
    WAITING = ...
    RUNNING = ...
    PREEMPTED = ...
    FINISHED_COMPLETED = ...
    FINISHED_ABORTED = ...
    FINISHED_ERROR = ...

@dataclass
class DiffusionRequestState:
    sched_req_id: str
    req: OmniDiffusionRequest
    status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING
```

    # Wait for result from Rank 0
    output = self.result_mq.dequeue()
    return output
**Design Features**:

- **Scheduler-owned ID**: Each `OmniDiffusionRequest` is tracked by an internal `sched_req_id`, separated from public `request_id` values.

- **Explicit lifecycle**: Requests move through waiting, running, optional preemption, and terminal states.

- **Centralized error handling**: Completion, abort, and error states are all normalized in the scheduler layer.
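
The lifecycle described above can be sketched as a small transition table. This is an illustrative sketch only: `Status`, `RequestState`, `ALLOWED`, and `is_finished` are hypothetical names, the enum values are placeholders, and the set of legal transitions is an assumption inferred from the bullet list, not taken from the actual `vllm_omni` code.

```python
import enum
from dataclasses import dataclass


class Status(enum.IntEnum):
    # Hypothetical stand-in for DiffusionRequestStatus; numeric values are assumptions.
    WAITING = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    FINISHED_COMPLETED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_ERROR = enum.auto()


_FINISHED = {Status.FINISHED_COMPLETED, Status.FINISHED_ABORTED, Status.FINISHED_ERROR}

# Assumed transition table: terminal states have no outgoing transitions.
ALLOWED = {
    Status.WAITING: {Status.RUNNING, Status.FINISHED_ABORTED},
    Status.RUNNING: {Status.PREEMPTED} | _FINISHED,
    Status.PREEMPTED: {Status.RUNNING, Status.FINISHED_ABORTED},
}


@dataclass
class RequestState:
    sched_req_id: str
    status: Status = Status.WAITING

    def transition(self, new_status: Status) -> None:
        # Reject any move that the assumed table does not list.
        if new_status not in ALLOWED.get(self.status, set()):
            raise ValueError(
                f"illegal transition {self.status.name} -> {new_status.name}"
            )
        self.status = new_status


def is_finished(status: Status) -> bool:
    return status in _FINISHED
```

Keeping the table in one place is one way to make the "explicit lifecycle" checkable: a policy that tries to revive a finished request fails loudly instead of silently corrupting state.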

#### 2.3 Shared Bookkeeping in `_BaseScheduler`

```python
class _BaseScheduler(SchedulerInterface):
    def __init__(self) -> None:
        self._request_states = {}
        self._request_id_to_sched_req_id = {}
        self._waiting = deque()
        self._running = []
        self._finished_req_ids = set()
        self._max_batch_size = 1
```

**Design Features**:

- **Broadcast Model**: All workers receive the same request (for tensor parallelism)
- **Common state storage**: Shared request maps and waiting/running sets live in the base class.

- **Single Response**: Only rank 0 sends results back (avoids duplicate outputs)
- **Shared cleanup logic**: Request-id registration, finish handling, and state removal are centralized instead of duplicated in each policy.

- **Synchronous**: Blocks until result is received (can be made async)
- **Current constraint**: `_max_batch_size` remains `1` because the current engine path is still synchronous request-mode execution.
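
To illustrate what centralized registration and cleanup might look like, here is a minimal sketch. The attribute names mirror the fields listed above, but the class `BookkeepingSketch`, its methods `register`, `finish`, and `pop_state`, and the counter-based id scheme are all assumptions made for illustration, not the real `_BaseScheduler` API.

```python
from collections import deque


class BookkeepingSketch:
    """Hypothetical stand-in for the shared state kept in _BaseScheduler."""

    def __init__(self) -> None:
        self._request_states = {}
        self._request_id_to_sched_req_id = {}
        self._waiting = deque()
        self._running = []
        self._finished_req_ids = set()
        self._next_id = 0  # assumption: scheduler ids come from a simple counter

    def register(self, request_id: str) -> str:
        # Map the public request id to a scheduler-owned id and enqueue it.
        sched_req_id = f"sched-{self._next_id}"
        self._next_id += 1
        self._request_states[sched_req_id] = {"request_id": request_id}
        self._request_id_to_sched_req_id[request_id] = sched_req_id
        self._waiting.append(sched_req_id)
        return sched_req_id

    def finish(self, sched_req_id: str) -> None:
        # Remove the request from whichever queue holds it, then mark it finished.
        if sched_req_id in self._running:
            self._running.remove(sched_req_id)
        if sched_req_id in self._waiting:
            self._waiting.remove(sched_req_id)
        self._finished_req_ids.add(sched_req_id)

    def pop_state(self, sched_req_id: str) -> dict:
        # Drop every trace of the request so the maps do not grow without bound.
        state = self._request_states.pop(sched_req_id)
        self._request_id_to_sched_req_id.pop(state["request_id"], None)
        self._finished_req_ids.discard(sched_req_id)
        return state
```

With this shape, a policy subclass only decides which request to admit next; it never has to remember to clean up the id maps itself.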

#### 2.3 Singleton Pattern
#### 2.4 Current `RequestScheduler` Policy

```python
class Scheduler:
    _instance = None
class RequestScheduler(_BaseScheduler):
    def schedule(self) -> DiffusionSchedulerOutput:
        # 1. keep existing RUNNING requests in the scheduling result
        # 2. pull WAITING requests while capacity remains
        # 3. move newly admitted requests into RUNNING
```

**Behavior**:

- **FIFO request scheduling**: Waiting requests are promoted in queue order.

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super().__new__(cls)
        return cls._instance
- **Single-request admission**: The current policy only admits one active request at a time.

# Global singleton instance
scheduler = Scheduler()
- **Executor result feedback**: `update_from_output()` converts executor output into `FINISHED_COMPLETED` or `FINISHED_ERROR` and returns finished scheduler ids.
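
The three commented steps of `schedule()` can be sketched as a standalone function. `schedule_fifo` is a hypothetical helper that works on plain ids; the real method returns a `DiffusionSchedulerOutput`, whose fields are not shown in this document, so a plain list stands in for it here.

```python
from collections import deque


def schedule_fifo(waiting: deque, running: list, max_batch_size: int = 1) -> list:
    """Return the ids scheduled this cycle (illustrative, not the real method)."""
    # 1. Requests that are already running stay in the scheduling result.
    scheduled = list(running)
    # 2. Admit waiting requests in queue order while capacity remains.
    while waiting and len(running) < max_batch_size:
        sched_req_id = waiting.popleft()
        # 3. Newly admitted requests move into the running list.
        running.append(sched_req_id)
        scheduled.append(sched_req_id)
    return scheduled
```

With the default `max_batch_size` of 1, a second waiting request is admitted only after the first has left `running`, which matches the single-request admission described above.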

#### 2.5 Engine-Driven Execution Loop

```python
sched_req_id = scheduler.add_request(request)
while True:
    sched_output = scheduler.schedule()
    output = executor.add_req(req)
    finished_req_ids = scheduler.update_from_output(sched_output, output)
```

**Benefits**:
**Design Decisions**:

- **Single Point of Control**: Ensures consistent state
- **Separation of concerns**: Scheduler manages state and policy; executor handles runtime execution.

- **Easy Access**: Global `scheduler` instance accessible everywhere
- **No scheduler-owned IPC**: Scheduler no longer talks to workers directly.

- **Resource Management**: Centralized queue management
- **Conservative concurrency**: The current request-mode implementation still allows only one active request at a time.
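
As a rough end-to-end illustration of this division of labor, the sketch below drives a toy scheduler from an engine-style loop. Everything in it is hypothetical: `TinyScheduler`, `FakeOutput` (including its `error` field), and `run_until_done` are stand-ins invented for this example, statuses are plain strings, and the executor is just a callable. It only mirrors the call order `add_request` → `schedule` → execute → `update_from_output`.

```python
from collections import deque
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeOutput:
    # Hypothetical stand-in for DiffusionOutput; `error` is an assumed field.
    error: Optional[str] = None


class TinyScheduler:
    """Toy scheduler: tracks state only and never touches workers."""

    def __init__(self) -> None:
        self.waiting = deque()
        self.running = []
        self.status = {}
        self._counter = 0

    def add_request(self, request) -> str:
        sched_req_id = f"sched-{self._counter}"
        self._counter += 1
        self.waiting.append((sched_req_id, request))
        self.status[sched_req_id] = "WAITING"
        return sched_req_id

    def schedule(self) -> list:
        # Admit at most one request at a time.
        if not self.running and self.waiting:
            self.running.append(self.waiting.popleft())
            self.status[self.running[0][0]] = "RUNNING"
        return list(self.running)

    def update_from_output(self, sched_output, output) -> set:
        finished = set()
        for sched_req_id, _ in sched_output:
            failed = output.error is not None
            self.status[sched_req_id] = (
                "FINISHED_ERROR" if failed else "FINISHED_COMPLETED"
            )
            finished.add(sched_req_id)
        self.running = [item for item in self.running if item[0] not in finished]
        return finished


def run_until_done(scheduler, execute, request):
    """Engine-side loop: the engine, not the scheduler, calls the executor."""
    target = scheduler.add_request(request)
    while True:
        sched_output = scheduler.schedule()
        if not sched_output:
            raise RuntimeError("nothing scheduled")
        _, req = sched_output[0]  # one request per cycle in this sketch
        output = execute(req)
        if target in scheduler.update_from_output(sched_output, output):
            return output
```

Because the scheduler never calls `execute` itself, swapping in a different policy or a different executor would not require changing the loop.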

---

@@ -880,8 +917,9 @@ def initialize_model_parallel(
└─> Model-specific transformations

3. Scheduling
└─> scheduler.add_req(requests)
└─> Broadcast via MessageQueue to all workers
└─> scheduler.add_request(request)
└─> scheduler.schedule()
└─> DiffusionEngine submits scheduled request to executor.add_req(req)

4. Worker Execution
└─> WorkerProc.worker_busy_loop()
@@ -895,8 +933,9 @@ def initialize_model_parallel(
└─> vae.decode()

5. Result Collection
└─> Rank 0 sends DiffusionOutput via result queue
└─> Scheduler receives and returns
└─> Executor returns DiffusionOutput
└─> scheduler.update_from_output(...)
└─> DiffusionEngine pops finished request state

6. Post-processing
└─> post_process_func(output)