-
Notifications
You must be signed in to change notification settings - Fork 104
JAX implementation of CMAES #6
base: main
Are you sure you want to change the base?
Conversation
Dear @moskomule, Thanks for the contribution! As discussed in the README file, can you please test your implementation on the 4 tasks, and report the results? |
Hi, thank you. As you may know, CMA-ES includes eigendecomposition of covariance matrices, which makes it too slow for MNIST's CNN policy (at least in my environment). So, can I skip MNIST results? |
Do you mind sharing some log to show how slow it is? |
evojax/algo/__init__.py
Outdated
from .base import NEAlgorithm | ||
from .cma_wrapper import CMA | ||
from .pgpe import PGPE | ||
from .cmaes import CMAES |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There can be multiple CMA-ES implementations in EvoJAX, we are considering to use the naming rule: {algo_name}_{src_name / contributor}
.
For example, we will change ours to something like CMAES_OriginalCPU
, and yours may be CMAES_CyberAgent
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will update so.
As for MNIST, I previously run a reduced CNN, but for the original CNN policy, even a single A6000 GPU throws OOM. So it's appreciated if you could run it. One question: can I modify the examples to select solvers as an option, or do you prefer to keep the examples as is? |
Thanks for the update. |
@moskomule I put together a small repository for running/logging the tests/benchmarks. Maybe this can be of help to: https://github.com/RobertTLange/evojax-benchmarks. I have also been thinking of porting the Let me know if you would like to work on something like that together. |
@RobertTLange Thanks! @moskomule |
Hi, I'm really amazed by this library.
Currently, CMAES is just a wrapper. I implemented a JAX CMAES based on https://github.com/CyberAgentAILab/cmaes/.