
Commit

Fix linting error (#437)
* fix: linting error and testing error;
WenjieDu authored Jun 18, 2024
1 parent d1d4e6d commit 13b04b4
Showing 3 changed files with 5 additions and 13 deletions.
1 change: 0 additions & 1 deletion .github/workflows/testing_ci.yml
@@ -42,7 +42,6 @@ jobs:
   run: |
     which python
     which pip
-    pip install --upgrade
     pip install numpy==1.24 torch==${{ matrix.pytorch-version }} -f https://download.pytorch.org/whl/cpu
     python -c "import torch; print('PyTorch:', torch.__version__)"
8 changes: 1 addition & 7 deletions pypots/nn/modules/reformer/local_attention.py
@@ -43,10 +43,6 @@ def apply_rotary_pos_emb(q, k, freqs, scale=1):
     return q, k
 
 
-def exists(val):
-    return val is not None
-
-
 def default(value, d):
     return d if not exists(value) else value
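The deleted exists is the usual not-None helper that default builds on; a minimal standalone sketch of the pair (assuming a single definition of exists remains in scope for this module), with illustrative usage values:

    def exists(val):
        # a value counts as provided when it is not None
        return val is not None

    def default(value, d):
        # fall back to d only when value was not provided
        return d if not exists(value) else value

    # illustrative usage
    print(default(None, 64))   # 64
    print(default(128, 64))    # 128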

@@ -186,7 +182,6 @@ def forward(
         ), "cannot perform window size extrapolation if xpos is not turned on"
 
         (
-            shape,
             autopad,
             pad_value,
             window_size,
@@ -195,7 +190,6 @@
             look_forward,
             shared_qk,
         ) = (
-            q.shape,
             self.autopad,
             -1,
             default(window_size, self.window_size),
@@ -216,7 +210,7 @@
             lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
         )
 
-        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
+        b, n, dim_head, device = *q.shape, q.device
 
         scale = default(self.scale, dim_head**-0.5)
9 changes: 4 additions & 5 deletions pypots/nn/modules/reformer/lsh_attention.py
@@ -139,9 +139,9 @@ def merge_dims(ind_from, ind_to, tensor):
 
 def split_at_index(dim, index, t):
     pre_slices = (slice(None),) * dim
-    l = (*pre_slices, slice(None, index))
-    r = (*pre_slices, slice(index, None))
-    return t[l], t[r]
+    l_ = (*pre_slices, slice(None, index))
+    r_ = (*pre_slices, slice(index, None))
+    return t[l_], t[r_]
 
 
 class FullQKAttention(nn.Module):
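Renaming l and r to l_ and r_ avoids ambiguous single-character names, which linters flag (flake8's E741 covers l, O, and I). The helper itself just splits a tensor into two pieces along one dimension; a standalone sketch:

    import torch

    def split_at_index(dim, index, t):
        # keep every dimension before `dim` intact, then split at `index` along `dim`
        pre_slices = (slice(None),) * dim
        l_ = (*pre_slices, slice(None, index))
        r_ = (*pre_slices, slice(index, None))
        return t[l_], t[r_]

    # illustrative usage
    x = torch.arange(12).reshape(3, 4)
    left, right = split_at_index(1, 2, x)  # shapes (3, 2) and (3, 2)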
@@ -608,10 +608,9 @@ def forward(
         **kwargs,
     ):
         device, dtype = x.device, x.dtype
-        b, t, e, h, dh, m, l_h = (
+        b, t, e, h, m, l_h = (
             *x.shape,
             self.heads,
-            self.dim_head,
             self.num_mem_kv,
             self.n_local_attn_heads,
         )
