-
Notifications
You must be signed in to change notification settings - Fork 4
/
dynamics.py
52 lines (44 loc) · 1.62 KB
/
dynamics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Dict, Any
import jax
import jax.numpy as jnp
import flax.linen as nn
from utils.nn import STTransformer
class DynamicsMaskGIT(nn.Module):
"""MaskGIT dynamics model"""
model_dim: int
num_latents: int
num_blocks: int
num_heads: int
dropout: float
mask_limit: float
def setup(self):
self.dynamics = STTransformer(
self.model_dim,
self.num_latents,
self.num_blocks,
self.num_heads,
self.dropout,
)
self.patch_embed = nn.Embed(self.num_latents, self.model_dim)
self.mask_token = self.param(
"mask_token",
nn.initializers.lecun_uniform(),
(1, 1, 1, self.model_dim),
)
self.action_up = nn.Dense(self.model_dim)
def __call__(self, batch: Dict[str, Any], training: bool = True) -> Dict[str, Any]:
# --- Mask videos ---
vid_embed = self.patch_embed(batch["video_tokens"])
if training:
rng1, rng2 = jax.random.split(batch["mask_rng"])
mask_prob = jax.random.uniform(rng1, minval=self.mask_limit)
mask = jax.random.bernoulli(rng2, mask_prob, vid_embed.shape[:-1])
mask = mask.at[:, 0].set(False)
vid_embed = jnp.where(jnp.expand_dims(mask, -1), self.mask_token, vid_embed)
else:
mask = None
# --- Predict transition ---
act_embed = self.action_up(batch["latent_actions"])
vid_embed += jnp.pad(act_embed, ((0, 0), (1, 0), (0, 0), (0, 0)))
logits = self.dynamics(vid_embed)
return dict(token_logits=logits, mask=mask)