import rofunc as rf
from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
from rofunc.learning.RofuncRL.agents.mixline.amp_agent import AMPAgent
+from rofunc.learning.RofuncRL.models.misc_models import ASEDiscEnc
from rofunc.learning.RofuncRL.models.base_models import BaseMLP
from rofunc.learning.RofuncRL.utils.memory import Memory
@@ -72,6 +73,13 @@ def __init__(self,
        self._enc_reward_weight = cfg.Agent.enc_reward_weight

        '''Define ASE specific models except for AMP'''
+        # self.discriminator = ASEDiscEnc(cfg.Model,
+        #                                 input_dim=amp_observation_space.shape[0],
+        #                                 enc_output_dim=self._ase_latent_dim,
+        #                                 disc_output_dim=1,
+        #                                 cfg_name='encoder').to(device)
+        # self.encoder = self.discriminator
+
        self.encoder = BaseMLP(cfg.Model,
                               input_dim=amp_observation_space.shape[0],
                               output_dim=self._ase_latent_dim,
@@ -95,10 +103,11 @@ def __init__(self,

    def _set_up(self):
        super()._set_up()
-        self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
-        if self._lr_scheduler is not None:
-            self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
-        self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
+        if self.encoder is not self.discriminator:
+            self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
+            if self._lr_scheduler is not None:
+                self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
+            self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc

    def act(self, states: torch.Tensor, deterministic: bool = False, ase_latents: torch.Tensor = None):
        if self._current_states is not None:
@@ -173,7 +182,10 @@ def update_net(self):
            style_rewards *= self._discriminator_reward_scale

            # Compute encoder reward
-            enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
+            if self.encoder is self.discriminator:
+                enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
+            else:
+                enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
            enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
            enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
            enc_reward *= self._enc_reward_scale
@@ -311,7 +323,10 @@ def update_net(self):
                discriminator_loss *= self._discriminator_loss_scale

                # encoder loss
-                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
+                if self.encoder is self.discriminator:
+                    enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
+                else:
+                    enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
                enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
                enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
                enc_loss = torch.mean(enc_err)
@@ -357,17 +372,21 @@ def update_net(self):

                # Update discriminator network
                self.optimizer_disc.zero_grad()
-                discriminator_loss.backward()
+                if self.encoder is self.discriminator:
+                    (discriminator_loss + enc_loss).backward()
+                else:
+                    discriminator_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
                self.optimizer_disc.step()

                # Update encoder network
-                self.optimizer_enc.zero_grad()
-                enc_loss.backward()
-                if self._grad_norm_clip > 0:
-                    nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
-                self.optimizer_enc.step()
+                if self.encoder is not self.discriminator:
+                    self.optimizer_enc.zero_grad()
+                    enc_loss.backward()
+                    if self._grad_norm_clip > 0:
+                        nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
+                    self.optimizer_enc.step()

                # update cumulative losses
                cumulative_policy_loss += policy_loss.item()
@@ -382,7 +401,8 @@ def update_net(self):
            self.scheduler_policy.step()
            self.scheduler_value.step()
            self.scheduler_disc.step()
-            self.scheduler_enc.step()
+            if self.encoder is not self.discriminator:
+                self.scheduler_enc.step()

        # update AMP replay buffer
        self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
@@ -407,4 +427,5 @@ def update_net(self):
            self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
            self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
            self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
-            self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
+            if self.encoder is not self.discriminator:
+                self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
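
For context, the commented-out ASEDiscEnc construction and the get_enc() calls above imply a single module with a shared trunk and two heads: the discriminator logit and the encoder's latent prediction. The following is a minimal sketch of that shape, assuming a simple MLP trunk; the class name, layer sizes, and trunk/head layout are illustrative guesses, not the actual rofunc implementation (only the constructor keywords and get_enc appear in the diff).

import torch
import torch.nn as nn


class SharedDiscEncSketch(nn.Module):
    """Hypothetical shared discriminator/encoder: one trunk, two linear heads."""

    def __init__(self, input_dim: int, enc_output_dim: int, disc_output_dim: int = 1, hidden_dim: int = 1024):
        super().__init__()
        # shared trunk consumes AMP observations
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.disc_head = nn.Linear(hidden_dim, disc_output_dim)  # real/fake logit for the style reward
        self.enc_head = nn.Linear(hidden_dim, enc_output_dim)    # prediction compared against the ASE latent

    def forward(self, amp_obs: torch.Tensor) -> torch.Tensor:
        # discriminator logits, as used when the agent calls self.discriminator(...)
        return self.disc_head(self.trunk(amp_obs))

    def get_enc(self, amp_obs: torch.Tensor) -> torch.Tensor:
        # encoder output, later normalized and dotted with the sampled latent
        return self.enc_head(self.trunk(amp_obs))

Because both heads share one set of trunk parameters in this configuration, the agent folds enc_loss into the discriminator's backward pass and skips the separate encoder optimizer, scheduler step, and learning-rate logging, which is what the `self.encoder is self.discriminator` branches in the diff guard.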