Skip to content
Merged

Fix glm #34388

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions src/transformers/models/glm/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,26 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
-is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_glm import GlmConfig


-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
-
-from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
-from ...processing_utils import Unpack


-_CHECKPOINT_FOR_DOC = "dummy"


_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"


Expand Down
4 changes: 1 addition & 3 deletions src/transformers/models/glm/modular_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@
from .configuration_glm import GlmConfig


_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"

logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "dummy"
_CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b"


class GlmRMSNorm(Phi3RMSNorm):
Expand Down
5 changes: 1 addition & 4 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
Expand All @@ -39,17 +40,13 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
-is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_phi3 import Phi3Config


-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward
-
logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
Expand Down