```diff
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import random
 import jax
 import numpy as onp
```
```diff
@@ -559,11 +560,17 @@ def body_fun(vals):  # pylint: disable=invalid-name
 class MergedHashedCausalAttention(BaseCausalAttention):
   """Hash-based causal attention."""
 
-  def __init__(self, dropout, mode, n_bins=64, bin_by_time=False):
+  def __init__(self, dropout, mode, n_bins=64,
+               bin_by_time=False, one_rng=False):
     del dropout, mode
     super(MergedHashedCausalAttention, self).__init__()
     self.n_bins = n_bins
     self.bin_by_time = bin_by_time
+    seed = random.randint(0, 2**31 - 1)
+    self._one_rng = one_rng
+    self._prng = None
+    if one_rng:
+      self._prng = backend.random.get_prng(seed)
 
   def call(self, inputs, params=(), state=(), **kwargs):
     del params
```
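With `one_rng=True`, the constructor draws one seed from Python's `random` module and turns it into a single stored PRNG key, so every later call hashes with the same random rotation instead of a fresh one per call. A minimal sketch of the intent, assuming `backend.random.get_prng` is a thin wrapper around `jax.random.PRNGKey` (the names below are illustrative, not the module's API):

```python
import random

import jax

# Drawn once at construction time; stands in for self._prng.
seed = random.randint(0, 2**31 - 1)
fixed_prng = jax.random.PRNGKey(seed)

def rotation_rng(call_rng, one_rng):
  # one_rng=True: reuse the stored key, so the random rotation (and
  # therefore the hash buckets) is identical on every call.
  # one_rng=False: use the per-call rng, resampling the rotation each time.
  return fixed_prng if one_rng else call_rng
```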
```diff
@@ -604,8 +611,12 @@ def hash_vectors(self, vecs, rng):
     # It's not clear whether sampling a different random rotation for each head
     # and batch element matters here, but see MergedMultiHashedCausalAttention.
     assert self.n_bins % 2 == 0
+    rot_rng = rng
+    if self._one_rng:
+      rot_rng = jax.lax.tie_in(vecs, self._prng)
     random_rotation = jax.random.normal(
-        rng, (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
+        rot_rng,
+        (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
 
     # TODO(kitaev): making the vectors unit-length here is probably redundant.
     vecs = self.make_unit_length(vecs)
```
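The `jax.lax.tie_in` call gives the stored key a formal data dependence on `vecs`, so the constant key gets staged into the traced computation instead of being hoisted out (in later JAX releases `tie_in` became a no-op and was eventually removed; there, using `self._prng` directly is equivalent). Below is a sketch of what the sampled rotation feeds, assuming the usual angular-LSH bucket rule of taking the argmax over the concatenated rotations `[xR, -xR]`, with hypothetical shapes:

```python
import jax
import jax.numpy as jnp

n_bins = 64
vecs = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 8))  # (batch, seqlen, d)
rot_rng = jax.random.PRNGKey(0)  # stands in for self._prng / the per-call rng

# One random rotation per batch element, projecting d dims to n_bins//2.
random_rotation = jax.random.normal(
    rot_rng, (vecs.shape[0], vecs.shape[-1], n_bins // 2)).astype('float32')

# Bucket id = index of the largest coordinate among [xR, -xR], which
# yields n_bins buckets per position.
rotated = jnp.einsum('btf,bfh->bth', vecs, random_rotation)
buckets = jnp.argmax(jnp.concatenate([rotated, -rotated], axis=-1), axis=-1)
```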
```diff
@@ -735,12 +746,18 @@ def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
 class MergedMultiHashedCausalAttention(BaseCausalAttention):
   """Hash-based causal attention, with multiple hashes."""
 
-  def __init__(self, dropout, mode, n_bins=64, n_hashes=1, bin_by_time=False):
+  def __init__(self, dropout, mode, n_bins=64, n_hashes=1,
+               bin_by_time=False, one_rng=False):
     del dropout, mode
     super(MergedMultiHashedCausalAttention, self).__init__()
     self.n_bins = n_bins
     self.n_hashes = n_hashes
     self.bin_by_time = bin_by_time
+    seed = random.randint(0, 2**31 - 1)
+    self._one_rng = one_rng
+    self._prng = None
+    if one_rng:
+      self._prng = backend.random.get_prng(seed)
 
   def bin_vectors_by_time(self, vecs):
     seqlen = vecs.shape[-2]
```
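The multi-hash variant repeats the bucketing `n_hashes` times with independent rotations; two positions only need to collide in one round to become attention candidates, which is why more hashes raise the probability of landing in the same bin as relevant items. A sketch under the same assumed bucket rule as above:

```python
import jax
import jax.numpy as jnp

def buckets_one_round(vecs, rng, n_bins):
  # Same angular-LSH rule as the single-hash sketch.
  rot = jax.random.normal(rng, (vecs.shape[0], vecs.shape[-1], n_bins // 2))
  rotated = jnp.einsum('btf,bfh->bth', vecs, rot)
  return jnp.argmax(jnp.concatenate([rotated, -rotated], axis=-1), axis=-1)

vecs = jax.random.normal(jax.random.PRNGKey(1), (1, 16, 8))
rngs = jax.random.split(jax.random.PRNGKey(2), 4)  # n_hashes=4 rounds

all_buckets = jnp.stack([buckets_one_round(vecs, r, 64) for r in rngs])
# Positions i, j are candidates if they share a bucket in ANY round.
same_bin_any = (
    all_buckets[:, 0, :, None] == all_buckets[:, 0, None, :]).any(axis=0)
```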
```diff
@@ -770,8 +787,12 @@ def hash_vectors(self, vecs, rng):
     # of vecs. Applying multiple hashes to the same input is important because
     # it increases the probability of being in the same bin as relevant items.
     assert self.n_bins % 2 == 0
+    rot_rng = rng
+    if self._one_rng:
+      rot_rng = jax.lax.tie_in(vecs, self._prng)
     random_rotation = jax.random.normal(
-        rng, (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
+        rot_rng,
+        (vecs.shape[0], vecs.shape[-1], self.n_bins//2)).astype('float32')
 
     # TODO(kitaev): making the vectors unit-length here is probably redundant.
     vecs = self.make_unit_length(vecs)
```
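On the TODO above: the bucket assignment is an argmax over rotated coordinates, and argmax is invariant to positive rescaling of its input, so normalizing vectors to unit length should not change which bin they land in. A small check of that claim, with hypothetical shapes:

```python
import jax
import jax.numpy as jnp

rot = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 32))
x = jax.random.normal(jax.random.PRNGKey(1), (1, 16, 8))

def buckets(v):
  r = jnp.einsum('btf,bfh->bth', v, rot)
  return jnp.argmax(jnp.concatenate([r, -r], axis=-1), axis=-1)

# Positive rescaling (e.g. length normalization) leaves the argmax, and
# hence the buckets, unchanged.
assert (buckets(x) == buckets(3.0 * x)).all()
```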