diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index aad4da282b78..26ea74a5fa2e 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -38,6 +38,7 @@
 )
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
+    add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -54,6 +55,9 @@
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
 
+_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
+
+
 class GlmRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -1222,6 +1226,11 @@ def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
     @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py
index 55bf89d1c56b..2be64b60cd46 100644
--- a/src/transformers/models/glm/modular_glm.py
+++ b/src/transformers/models/glm/modular_glm.py
@@ -44,6 +44,8 @@
 from .configuration_glm import GlmConfig
 
 
+_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"
+
 logger = logging.get_logger(__name__)
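For context, a minimal sketch of the kind of usage example this change causes `add_code_sample_docstrings` to render into `GlmForTokenClassification.forward` for the `THUDM/glm-4-9b` checkpoint. The exact text comes from the library's token-classification docstring template; the snippet below is an approximation of it, not a copy:

```python
# Sketch of the usage example the decorator injects into the forward docstring.
# Assumes the standard transformers token-classification sample shape; the
# rendered docstring may differ in wording and in the exact tokenizer kwargs.
import torch
from transformers import AutoTokenizer, GlmForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b")
model = GlmForTokenClassification.from_pretrained("THUDM/glm-4-9b")

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York",
    return_tensors="pt",
)

with torch.no_grad():
    logits = model(**inputs).logits

# One predicted class id per input token
predicted_token_class_ids = logits.argmax(-1)
```

Defining `_CHECKPOINT_FOR_DOC` in `modular_glm.py` as well keeps the modular source consistent with the generated `modeling_glm.py`, so the modular conversion check does not flag a missing symbol.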