Skip to content

Commit

Permalink
Merge branch 'main' into nlp/decode-yield-eot-token
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Khuu authored Jan 24, 2025
2 parents 2ab9092 + 42c52bf commit aef0b8b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 21 deletions.
6 changes: 3 additions & 3 deletions install/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ echo "Using pip executable: $PIP_EXECUTABLE"
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
PYTORCH_NIGHTLY_VERSION=dev20250119
PYTORCH_NIGHTLY_VERSION=dev20250124

# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20250119
VISION_NIGHTLY_VERSION=dev20250124

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20250119
TUNE_NIGHTLY_VERSION=dev20250124

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
Expand Down
2 changes: 2 additions & 0 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,8 @@ def callback(x, *, done_generating=False):
max_seq_length=max_seq_length,
attention_backend=self.builder_args.attention_backend,
)
if generator_args.chat_mode:
start_pos += encoded.size(0)
for token_tensor, metrics in generator_func:
if token_tensor is not None:
start_pos += token_tensor.size(0)
Expand Down
20 changes: 3 additions & 17 deletions torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def __init__(self, config: TransformerArgs) -> None:
self.layers[str(layer_id)] = TransformerBlock(config)

if config.stage_idx == config.n_stages - 1:
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.output.weight = self.tok_embeddings.weight
Expand Down Expand Up @@ -751,8 +751,8 @@ def __init__(self, config: TransformerArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.attention_norm = nn.RMSNorm(config.dim, config.norm_eps)
# None for llama architecture, set for granite architectures
self.residual_multiplier = (
config.residual_multiplier
Expand Down Expand Up @@ -928,20 +928,6 @@ def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight


def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
# Check for the presence of the required keys
required_keys = {
Expand Down
2 changes: 1 addition & 1 deletion torchchat/utils/scripts/install_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ install_executorch_python_libs() {
echo "Building and installing python libraries"
if [ "${ENABLE_ET_PYBIND}" = false ]; then
echo "Not installing pybind"
bash ./install_requirements.sh
bash ./install_requirements.sh --pybind off
else
echo "Installing pybind"
bash ./install_requirements.sh --pybind xnnpack
Expand Down

0 comments on commit aef0b8b

Please sign in to comment.