Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/mock/ray/raylet/worker_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,107 @@ class MockWorkerPool : public WorkerPoolInterface {
GetRegisteredWorker,
(const WorkerID &worker_id),
(const, override));
MOCK_METHOD(std::shared_ptr<WorkerInterface>,
GetRegisteredWorker,
(const std::shared_ptr<ClientConnection> &connection),
(const, override));
MOCK_METHOD(std::shared_ptr<WorkerInterface>,
GetRegisteredDriver,
(const WorkerID &worker_id),
(const, override));
MOCK_METHOD(std::shared_ptr<WorkerInterface>,
GetRegisteredDriver,
(const std::shared_ptr<ClientConnection> &connection),
(const, override));
MOCK_METHOD(void,
HandleJobStarted,
(const JobID &job_id, const rpc::JobConfig &job_config),
(override));
MOCK_METHOD(void, HandleJobFinished, (const JobID &job_id), (override));
MOCK_METHOD(void, Start, (), (override));
MOCK_METHOD(void, SetNodeManagerPort, (int node_manager_port), (override));
MOCK_METHOD(void,
SetRuntimeEnvAgentClient,
(std::unique_ptr<RuntimeEnvAgentClient> runtime_env_agent_client),
(override));
MOCK_METHOD((std::vector<std::shared_ptr<WorkerInterface>>),
GetAllRegisteredDrivers,
(bool filter_dead_drivers),
(const, override));
MOCK_METHOD(Status,
RegisterDriver,
(const std::shared_ptr<WorkerInterface> &worker,
const rpc::JobConfig &job_config,
std::function<void(Status, int)> send_reply_callback),
(override));
MOCK_METHOD(Status,
RegisterWorker,
(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid,
StartupToken worker_startup_token,
std::function<void(Status, int)> send_reply_callback),
(override));
MOCK_METHOD(Status,
RegisterWorker,
(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid,
StartupToken worker_startup_token),
(override));
MOCK_METHOD(void,
OnWorkerStarted,
(const std::shared_ptr<WorkerInterface> &worker),
(override));
MOCK_METHOD(void,
PushSpillWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override));
MOCK_METHOD(void,
PushRestoreWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override));
MOCK_METHOD(void,
DisconnectWorker,
(const std::shared_ptr<WorkerInterface> &worker,
rpc::WorkerExitType disconnect_type),
(override));
MOCK_METHOD(void,
DisconnectDriver,
(const std::shared_ptr<WorkerInterface> &driver),
(override));
MOCK_METHOD(void,
PrestartWorkers,
(const TaskSpecification &task_spec, int64_t backlog_size),
(override));
MOCK_METHOD(void,
StartNewWorker,
(const std::shared_ptr<PopWorkerRequest> &pop_worker_request),
(override));
MOCK_METHOD(std::string, DebugString, (), (const, override));

MOCK_METHOD(void,
PopSpillWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override));

MOCK_METHOD(void,
PopRestoreWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override));

MOCK_METHOD(void,
PushDeleteWorker,
(const std::shared_ptr<WorkerInterface> &worker),
(override));

MOCK_METHOD(void,
PopDeleteWorker,
(std::function<void(std::shared_ptr<WorkerInterface>)> callback),
(override));

boost::optional<const rpc::JobConfig &> GetJobConfig(
const JobID &job_id) const override {
RAY_CHECK(false) << "Not used.";
return boost::none;
}
};
} // namespace ray::raylet
120 changes: 120 additions & 0 deletions src/ray/raylet/local_task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,126 @@ class MockWorkerPool : public WorkerPoolInterface {
return 0;
}

std::shared_ptr<WorkerInterface> GetRegisteredWorker(
const std::shared_ptr<ClientConnection> &connection) const override {
RAY_CHECK(false) << "Not used.";
return nullptr;
}

std::shared_ptr<WorkerInterface> GetRegisteredDriver(
const std::shared_ptr<ClientConnection> &connection) const override {
RAY_CHECK(false) << "Not used.";
return nullptr;
}

void HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) override {
RAY_CHECK(false) << "Not used.";
}

void HandleJobFinished(const JobID &job_id) override {
RAY_CHECK(false) << "Not used.";
}

void Start() override { RAY_CHECK(false) << "Not used."; }

void SetNodeManagerPort(int node_manager_port) override {
RAY_CHECK(false) << "Not used.";
}

void SetRuntimeEnvAgentClient(
std::unique_ptr<RuntimeEnvAgentClient> runtime_env_agent_client) override {
RAY_CHECK(false) << "Not used.";
}

std::vector<std::shared_ptr<WorkerInterface>> GetAllRegisteredDrivers(
bool filter_dead_drivers) const override {
RAY_CHECK(false) << "Not used.";
return {};
}

Status RegisterDriver(const std::shared_ptr<WorkerInterface> &worker,
const rpc::JobConfig &job_config,
std::function<void(Status, int)> send_reply_callback) override {
RAY_CHECK(false) << "Not used.";
return Status::Invalid("Not used.");
}

Status RegisterWorker(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid,
StartupToken worker_startup_token,
std::function<void(Status, int)> send_reply_callback) override {
RAY_CHECK(false) << "Not used.";
return Status::Invalid("Not used.");
}

Status RegisterWorker(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid,
StartupToken worker_startup_token) override {
RAY_CHECK(false) << "Not used.";
return Status::Invalid("Not used.");
}

boost::optional<const rpc::JobConfig &> GetJobConfig(
const JobID &job_id) const override {
RAY_CHECK(false) << "Not used.";
return boost::none;
}

void OnWorkerStarted(const std::shared_ptr<WorkerInterface> &worker) override {
RAY_CHECK(false) << "Not used.";
}

void PushSpillWorker(const std::shared_ptr<WorkerInterface> &worker) override {
RAY_CHECK(false) << "Not used.";
}

void PushRestoreWorker(const std::shared_ptr<WorkerInterface> &worker) override {
RAY_CHECK(false) << "Not used.";
}

void DisconnectWorker(const std::shared_ptr<WorkerInterface> &worker,
rpc::WorkerExitType disconnect_type) override {
RAY_CHECK(false) << "Not used.";
}

void DisconnectDriver(const std::shared_ptr<WorkerInterface> &driver) override {
RAY_CHECK(false) << "Not used.";
}

void PrestartWorkers(const TaskSpecification &task_spec,
int64_t backlog_size) override {
RAY_CHECK(false) << "Not used.";
}

void StartNewWorker(
const std::shared_ptr<PopWorkerRequest> &pop_worker_request) override {
RAY_CHECK(false) << "Not used.";
}

std::string DebugString() const override {
RAY_CHECK(false) << "Not used.";
return "";
}

void PopSpillWorker(
std::function<void(std::shared_ptr<WorkerInterface>)> callback) override {
RAY_CHECK(false) << "Not used.";
}

void PopRestoreWorker(
std::function<void(std::shared_ptr<WorkerInterface>)> callback) override {
RAY_CHECK(false) << "Not used.";
}

void PushDeleteWorker(const std::shared_ptr<WorkerInterface> &worker) override {
RAY_CHECK(false) << "Not used.";
}

void PopDeleteWorker(
std::function<void(std::shared_ptr<WorkerInterface>)> callback) override {
RAY_CHECK(false) << "Not used.";
}

std::list<std::shared_ptr<WorkerInterface>> workers;
absl::flat_hash_map<int, std::list<PopWorkerCallback>> callbacks;
int num_pops;
Expand Down
Loading