Use different serialization context for each driver in worker.py #2357
Conversation
… times in cluster mode

Test FAILed.
# Continue because FunctionsToRun are the only things
# that the driver should import.
elif key.startswith(b"RegisterType"):
    with log_span(
log_span should be changed to profile.
Also, I'm hoping duplicate code can be reduced. This function is already very lengthy.
# we don't need to export it again.
return

if (len(pickled_function) >
there's a check_oversized_pickle helper function that can do this check.
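For reference, a minimal sketch of what such a helper might look like. The name check_oversized_pickle comes from the comment above; the threshold value, warning text, and signature here are assumptions, not Ray's actual implementation:

```python
import warnings

# Hypothetical threshold; Ray's actual limit may differ.
OVERSIZED_PICKLE_THRESHOLD = 100 * 1024 * 1024  # 100 MiB

def check_oversized_pickle(pickled, name, obj_type):
    """Warn when a pickled payload is large enough to hurt export performance.

    Sketch of the helper mentioned above, not Ray's implementation.
    """
    if len(pickled) > OVERSIZED_PICKLE_THRESHOLD:
        warnings.warn(
            "The {} {} is {} bytes when pickled; objects this large can "
            "slow down export.".format(obj_type, name, len(pickled)))
```

Centralizing the size check in one helper is exactly what avoids the inline `len(pickled_function) > ...` comparison quoted above.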
})
self.redis_client.rpush("Exports", key)

def register_class_on_all_workers(self, function):
It's okay to implement RegisterType using the same mechanism as FunctionToRun.
But since the purpose of this function is to register a serializer for a type, I think it'd make more sense to take the type, the serializer, etc. as parameters, instead of taking a function.
Also, it looks like some code is copy-pasted from run_function_on_all_workers; let's unify the two functions?
function_name = self.function_execution_info[self.task_driver_id.id()][
    function_id.id()].function_name

if not self.serialization_context_map.has_key(self.task_driver_id):
has_key was removed in Python 3.
We should use self.task_driver_id not in self.serialization_context_map instead.
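A quick standalone sketch of the suggested fix; serialization_context_map and task_driver_id stand in for the worker's attributes, and SerializationContext is a placeholder for the real context class:

```python
class SerializationContext:
    """Placeholder for the worker's per-driver serialization context."""
    pass

serialization_context_map = {}
task_driver_id = "driver-1"

# Python 2 only (removed in Python 3):
#     serialization_context_map.has_key(task_driver_id)
# Works in both Python 2 and 3:
if task_driver_id not in serialization_context_map:
    serialization_context_map[task_driver_id] = SerializationContext()
```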
def register_existing_class(worker=global_worker):
    export_keys = worker.redis_client.lrange("Exports", 0, -1)
Does this query all existing Exports?
I suspect that will hurt performance.
Is it possible to query only RegisterType entries, or only RegisterType entries for a given driver id?
It is a list in Redis; we can't query a subset of it.
I think we can move RegisterType out of the Exports list; that will make the code more readable and improve performance.
Agree.
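One possible shape for that refactor, sketched against an in-memory stand-in for the Redis client so it runs anywhere. The per-driver key name `RegisterType:<driver_id>` is an assumption for illustration, not the final design:

```python
class FakeRedis:
    """Minimal in-memory stand-in for the subset of redis-py used here."""
    def __init__(self):
        self.lists = {}

    def rpush(self, key, value):
        self.lists.setdefault(key, []).append(value)

    def lrange(self, key, start, stop):
        data = self.lists.get(key, [])
        stop = len(data) if stop == -1 else stop + 1
        return data[start:stop]

def export_register_type(client, driver_id, payload):
    # Keep RegisterType entries under a per-driver key instead of the
    # shared "Exports" list, so a worker only reads what it needs.
    client.rpush("RegisterType:{}".format(driver_id), payload)

def fetch_register_types(client, driver_id):
    # Cost is proportional to this driver's entries, not to every
    # export in the cluster.
    return client.lrange("RegisterType:{}".format(driver_id), 0, -1)
```

With per-driver keys, `LRANGE key 0 -1` on the driver's own list replaces scanning the whole Exports list for matching entries.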
@robertnishihara, could you please take a look?

Hi @ericl @richardliaw, would you please take a look? Appreciate your help!
robertnishihara left a comment
Thanks @surehb! I think this is a good idea.
I think this PR can be much simpler. E.g., can't we just make serialization_context_map a dictionary (like you've done) and that's about it? I don't think we need to introduce register_class_on_all_workers.
Note that down the road, we may need to just have different workers correspond to different drivers, since certain things will leak between the workers like module-level global variables.
if not self.serialization_context_map.has_key(self.task_driver_id):
    _initialize_serialization()
    register_existing_class()
Why do we need to add this? Shouldn't the regular import mechanism be sufficient?
To separate the content for different driver ids, we have to do that; otherwise it is hard to GC the context.
})
self.redis_client.rpush("Exports", key)

def register_class_on_all_workers(self, function):
What is the difference between this code path and the "run function on all workers" code path? Why separate them?
We must separate the normal "run function on all workers" case from the "register class" case. The "register class" export should only run when a task for the related driver id runs on the target worker; we separate them to make sure the functions in Exports only run once.
@robertnishihara, creating a different worker for each driver would resolve it, but it would change the whole Ray design. We use a map just to avoid the leak: when the job's function is done, we can release the context in the map by its job_id (driver_id).
I think this PR can be much, much simpler. All that really needs to happen is the following:

Note that I agree that solving this problem is super important.
@robertnishihara, register_class_for_serialization (which you mentioned) is called via run_function_on_all_workers, so the function is called on all the workers in the cluster, but we don't actually need that on workers that don't handle tasks from the current driver. The code was written this way to save memory.
@robertnishihara, the fix was designed with two considerations. The first is memory cost: we don't want to push the types to all the workers, including ones that don't run tasks created by the driver. The second is failover: when a task created by driver A creates a custom class foo, it is sent to all workers and works well. But when a worker crashes during execution and is restarted, the register function will not be called, which will cause problems.
@surehb @eric-jj, I agree that we should fix the memory issue. However, that will require a lot more work than what is done in this PR. E.g., this PR may reduce the number of custom serializers that each worker has to register, but it does not reduce the number of remote function definitions. I'm more concerned about latency and complexity in this PR. I'd rather do a very simple fix (I think it can just be a couple of lines) for the bug and then address the memory issue afterwards by redesigning the entire import mechanism.
@robertnishihara, I created another PR (the code base of the previous one was very old on my local machine) for this issue; please check #2406. I will close this one.
What do these changes do?
Previously, we used a shared serialization context for tasks generated by different drivers. This can cause problems in some cases. For example, suppose we have a task that generates a new class (A) on a worker during execution. The worker will try to register A if it fails to serialize it, see here. All the workers receive the register message, and those with the same driver id register A into their context. Everything goes fine the first time we run such a task. However, since the worker runs in a while loop and never exits, the second time we run the same thing (actually starting from a different driver), the worker already has A registered in its context (by the previous execution), so it will not publish the register message, and the driver will never learn how to deserialize it.
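The failure mode can be illustrated with a toy model of the worker's class registry. The names here are illustrative, not Ray's code:

```python
class Worker:
    """Toy long-lived worker that caches which classes it has registered."""
    def __init__(self):
        self.registered = set()  # shared across drivers: the bug
        self.published = []      # register messages sent out to drivers

    def maybe_register(self, cls_name):
        # Only publish a register message the first time a class is seen.
        if cls_name not in self.registered:
            self.registered.add(cls_name)
            self.published.append(cls_name)

worker = Worker()  # never exits; serves tasks from many drivers in turn

worker.maybe_register("A")  # first driver: message published, all is well
worker.maybe_register("A")  # second driver: cache hit, nothing published,
                            # so the new driver never learns about class A
```

Keying the cache by (driver_id, class) instead of class alone, as this change proposes, makes the second driver trigger its own register message.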
In this change:
PS.
Related issue number
2165
2288