feature(pu): adapt to unizero-multitask ddp, and adapt ppo to support jericho config #858
Conversation
…n mask), e.g. detective env
| if param.grad is not None: | ||
| allreduce(param.grad.data) | ||
| else: | ||
| # If the gradient is None, create a zero tensor with the same size as the parameter's gradient and perform allreduce on it |
Remove the commented-out code and add an English comment; then these modifications can be merged.
done
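The pattern discussed above can be sketched as follows. This is a minimal illustration, not the PR's actual code: the function and its `allreduce` callable parameter are assumptions made so the logic can be shown without an initialized process group (in real DDP code you would pass `torch.distributed.all_reduce`).

```python
import torch
import torch.nn as nn


def sync_gradients(model: nn.Module, allreduce) -> None:
    """Illustrative sketch: allreduce every parameter's gradient.

    Every rank must issue the same number of collective calls, so a
    parameter whose gradient is None contributes a zero tensor of the
    same shape instead of being skipped.
    """
    for param in model.parameters():
        if param.grad is not None:
            allreduce(param.grad.data)
        else:
            # Gradient is None: allreduce zeros so all ranks stay in lockstep.
            allreduce(torch.zeros_like(param.data))
```

Without the zero-tensor branch, ranks that skip a parameter would issue fewer collectives than ranks that do not, and the job would hang.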
| # dist.init_process_group(backend=backend, rank=rank, world_size=world_size) | ||
| # TODO: | ||
| import datetime | ||
| dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=60000)) |
Why add this?
This is because environment variables set before launching the program did not take effect; the timeout only works when passed in explicitly here. I changed it to a default value, exposed as the last parameter.
| # if self._rank == 0: | ||
| # self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) | ||
| self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) |
Add an argument named only_monitor_rank0 to control this logic, defaulting to True.
done
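A minimal sketch of the requested gating, assuming the flag lives on the learner. The class and helper names here are illustrative, not DI-engine's actual API; `_setup_monitor` stands in for `get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)`.

```python
class LearnerSketch:
    """Illustrative: create the monitor only on rank 0 by default."""

    def __init__(self, rank: int, only_monitor_rank0: bool = True):
        self._rank = rank
        self._only_monitor_rank0 = only_monitor_rank0
        self._monitor = None
        # With only_monitor_rank0=True (the default), only rank 0 gets a
        # monitor; passing False restores a monitor on every rank.
        if self._rank == 0 or not self._only_monitor_rank0:
            self._monitor = self._setup_monitor()

    def _setup_monitor(self):
        # Placeholder for the real monitor construction.
        return object()
```

This keeps the old single-rank behaviour as the default while making the all-ranks behaviour opt-in rather than commented out.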
| for k in engine.log_buffer: | ||
| engine.log_buffer[k].clear() | ||
| return | ||
| # if engine.rank != 0: |
also pass the only_monitor_rank0 argument to the hook class
done
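Forwarding the same flag into the hook might look like this sketch. The hook class name and the `engine` attributes (`rank`, `log_buffer`) follow the quoted diff; everything else is an assumption for illustration.

```python
class LogShowHookSketch:
    """Illustrative: the log hook receives the same only_monitor_rank0 flag."""

    def __init__(self, only_monitor_rank0: bool = True):
        self._only_monitor_rank0 = only_monitor_rank0

    def __call__(self, engine) -> None:
        # When only rank 0 is monitored, non-rank-0 processes drop their
        # buffered logs and return early instead of printing.
        if self._only_monitor_rank0 and engine.rank != 0:
            for k in engine.log_buffer:
                engine.log_buffer[k].clear()
            return
        # ... rank-0 (or all-rank) log display would happen here ...
```

Clearing the buffers on the early-return path keeps non-monitoring ranks from accumulating stale log entries.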
| self._global_state_encoder = nn.Identity() | ||
| elif len(global_obs_shape) == 3: | ||
| self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation) | ||
| self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN') |
Why use BN rather than LN as the default here?
| agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0) | ||
| agent_state = agent_state.unsqueeze(0) | ||
| if single_step and len(global_state.shape) == 2: | ||
| global_state = global_state.unsqueeze(0) |
add shape comments
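The requested shape comments might look like the sketch below. The dimension names (T = time, B = batch, A = agents, N = agent-obs dim, M = global-obs dim) and the concrete sizes are assumptions chosen to illustrate the `unsqueeze` calls in the quoted diff.

```python
import torch

# Illustrative sizes; T is the time axis added in single-step mode.
B, A, N, M = 4, 3, 16, 32
single_step = True

agent_state = torch.randn(B, A, N)   # (B, A, N)
global_state = torch.randn(B, M)     # (B, M)

if single_step:
    agent_state = agent_state.unsqueeze(0)       # (B, A, N) -> (T=1, B, A, N)
if single_step and len(global_state.shape) == 2:
    # A 2-dim global state lacks the time axis; a 3-dim one already has it,
    # so only the 2-dim case is unsqueezed.
    global_state = global_state.unsqueeze(0)     # (B, M) -> (T=1, B, M)
```

Guarding the second `unsqueeze` on `len(global_state.shape) == 2` is what lets the same code path accept both per-step and already-time-stacked global states.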
| agent_q_act = agent_q_act.squeeze(-1) # T, B, A | ||
| if self.mixer: | ||
| global_state_embedding = self._global_state_encoder(global_state) | ||
| if len(global_state.shape) == 5: |
add some comments
| """ | ||
| if self.share_encoder: | ||
| x = self.encoder(x) | ||
| # import ipdb;ipdb.set_trace() |
Modify the corresponding API comments, and use isinstance(x, dict) to control the logic.
done
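The type-based dispatch suggested above, replacing the leftover `ipdb` trace, can be sketched like this. The function name and the `'obs'` key are assumptions for illustration; the point is routing on `isinstance(x, dict)` rather than on debug breakpoints.

```python
def encode(x, encoder):
    """Illustrative sketch: encode dict observations entry-wise.

    Dict inputs (e.g. {'obs': ..., 'action_mask': ...}) have only their
    'obs' entry passed through the encoder, with all other keys kept
    untouched; plain tensors are encoded directly.
    """
    if isinstance(x, dict):
        out = dict(x)
        out['obs'] = encoder(x['obs'])
        return out
    return encoder(x)
```

This keeps mask-carrying environments (such as Jericho's action masks) working through the shared encoder without special-casing each caller.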
We have a new polished PR: #860
Description
Related Issue
TODO
Check List