
Commit

mla python part
yuanlehome committed Feb 17, 2025
1 parent 235c24e commit 7ec7f02
Showing 6 changed files with 324 additions and 64 deletions.
9 changes: 3 additions & 6 deletions csrc/setup_cuda.py
@@ -142,12 +142,9 @@ def get_gencode_flags():
if cc >= 80:
sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"]

sources += [
"./gpu/append_attention.cu",
"./gpu/append_attn/get_block_shape_and_split_kv_block.cu",
"./gpu/append_attn/decoder_write_cache_with_rope_kernel.cu",
"./gpu/append_attn/speculate_write_cache_with_rope_kernel.cu",
]
sources += ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu"]

sources += find_end_files("./gpu/append_attn", ".cu")
sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu")


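Note on the setup_cuda.py change: the explicit listing of the append_attn kernel sources is replaced by globbing every .cu file under ./gpu/append_attn (and its template_instantiation subfolder), with the new multi_head_latent_attention.cu entry point added explicitly. A minimal sketch of what a helper like find_end_files typically does, assuming it lists files in a single directory (non-recursively), which would be consistent with the subfolder being globbed separately; the real helper is defined elsewhere in csrc/setup_cuda.py:

import os

def find_end_files(directory, end_str):
    # Hypothetical sketch: return files directly inside `directory` whose names end with `end_str`.
    return [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(end_str)
    ]

sources = ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu"]
sources += find_end_files("./gpu/append_attn", ".cu")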
42 changes: 24 additions & 18 deletions llm/predict/predictor.py
Expand Up @@ -165,6 +165,8 @@ class PredictorArgument:
)
return_full_hidden_states: int = field(default=False, metadata={"help": "whether return full hidden_states"})

mla_use_matrix_absorption: bool = field(default=False, metadata={"help": "implement mla with matrix-absorption."})

def __post_init__(self):
if self.speculate_method is not None:
self.append_attn = True
@@ -418,7 +420,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
self.tgt_pos = None
else:
self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape]
self.num_layers, self.num_attention_heads, self.head_dim = (
self.num_layers, self.num_key_value_heads, self.head_dim = (
len(self.cache_kvs),
self.cache_kvs[0].shape[-3],
self.cache_kvs[0].shape[-1],
@@ -454,7 +456,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
self.num_layers,
2,
config.batch_size,
self.num_attention_heads,
self.num_key_value_heads,
prefix_cache.shape[-2],
self.head_dim,
],
@@ -464,7 +466,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
]
else:
prefix_cache = paddle.zeros(
[self.num_layers, 2, config.batch_size, self.num_attention_heads, 128, self.head_dim],
[self.num_layers, 2, config.batch_size, self.num_key_value_heads, 128, self.head_dim],
dtype=self.dtype,
)
self.pre_caches = [
@@ -759,7 +761,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
BasePredictor.__init__(self, config, tokenizer)

self.num_layers = len(self.cache_kvs_shape) // 2
self.num_attention_heads = self.cache_kvs_shape[0][-3]
self.num_key_value_heads = self.cache_kvs_shape[0][-3]
self.head_dim = self.cache_kvs_shape[0][-1]
self.max_block_nums = self.cache_kvs_shape[0][0]
self.batch_size = config.batch_size
@@ -780,7 +782,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
config.max_length -= self.pre_cache_length
self.pre_caches = [
paddle.zeros(
[config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim],
[config.batch_size, self.num_key_value_heads, self.pre_cache_length, self.head_dim],
dtype=self.dtype,
)
for _ in range(2 * self.num_layers)
@@ -804,19 +806,19 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):

if config.cachekv_int8_type == "dynamic":
self.k_quant_scales = [
paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
paddle.zeros([config.batch_size, self.num_key_value_heads], dtype="float32")
for _ in range(self.num_layers)
]
self.v_quant_scales = [
paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
paddle.zeros([config.batch_size, self.num_key_value_heads], dtype="float32")
for _ in range(self.num_layers)
]
self.k_dequant_scales = [
paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
paddle.zeros([config.batch_size, self.num_key_value_heads], dtype="float32")
for _ in range(self.num_layers)
]
self.v_dequant_scales = [
paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
paddle.zeros([config.batch_size, self.num_key_value_heads], dtype="float32")
for _ in range(self.num_layers)
]
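These scale tensors are shaped [batch_size, num_key_value_heads] because with dynamic cache-KV int8 every KV head of every request carries its own quant/dequant scale, one tensor pair per layer. A rough sketch of how such a per-head scale is typically produced and applied (illustrative only; in this code path the actual work happens inside the CUDA kernels):

import paddle

block_size, head_dim = 64, 128
k = paddle.randn([block_size, head_dim], dtype="float32")   # one head's keys for one cache block

scale = k.abs().max() / 127.0                                # dynamic quant scale for this head
k_int8 = paddle.clip(paddle.round(k / scale), -127, 127).astype("int8")
k_restored = k_int8.astype("float32") * scale                # dequant with the matching scale
print(float((k - k_restored).abs().max()))                   # small quantization error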

@@ -884,7 +886,7 @@ def init_model_inputs(self, config: PredictorArgument):
shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype
)
arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype)
alibi_slopes = llm_utils.get_alibi_slopes(self.num_attention_heads)
alibi_slopes = llm_utils.get_alibi_slopes(self.num_key_value_heads)
alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1])
alibi_decoder = alibi.tile(
@@ -912,7 +914,7 @@ def init_model_inputs(self, config: PredictorArgument):
shape=[config.batch_size, 1, 1, config.total_max_length], fill_value=1, dtype=self.dtype
)
arange_tensor_encoder = paddle.arange(config.total_max_length).astype(self.dtype)
alibi_slopes = llm_utils.get_alibi_slopes(self.num_attention_heads)
alibi_slopes = llm_utils.get_alibi_slopes(self.num_key_value_heads)
alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder
alibi_encoder = alibi.tile([config.batch_size, 1, config.total_max_length, 1])
alibi_decoder = alibi.tile(
@@ -1025,7 +1027,9 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = N
BlockInferencePredictorMixin.__init__(self, config, tokenizer)

cachekv_dtype = self.dtype if config.cachekv_int8_type is None else "uint8"
self.cache_kvs = [paddle.zeros(shape, dtype=cachekv_dtype) for shape in self.cache_kvs_shape]
self.cache_kvs = [
paddle.zeros(shape, dtype=cachekv_dtype) if shape is not None else None for shape in self.cache_kvs_shape
]

self.model = model

@@ -1149,12 +1153,14 @@ def __init__(

cachekv_dtype = config.dtype if config.cachekv_int8_type is None else "uint8"
for i in range(len(self.cache_kvs_shape) // 2):
self.model_inputs["key_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i], dtype=cachekv_dtype
)
self.model_inputs["value_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i + 1], dtype=cachekv_dtype
)
if self.cache_kvs_shape[2 * i] is not None:
self.model_inputs["key_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i], dtype=cachekv_dtype
)
if self.cache_kvs_shape[2 * i + 1] is not None:
self.model_inputs["value_caches_{}".format(i)] = paddle.zeros(
self.cache_kvs_shape[2 * i + 1], dtype=cachekv_dtype
)

for i in range(self.num_layers):
if self.config.cachekv_int8_type == "dynamic":
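Both predictor paths above now tolerate None entries in cache_kvs_shape: with matrix absorption enabled, each layer contributes a single latent-cache shape plus a None placeholder where the separate value cache used to be, so allocation must skip the None slots. A condensed, self-contained sketch of that pattern (the shape values are illustrative, not taken from a real config):

import paddle

# Illustrative per-layer entries: [latent_cache_shape, None] when matrix absorption is on.
cache_kvs_shape = [[4096, 1, 64, 576], None, [4096, 1, 64, 576], None]

cache_kvs = [
    paddle.zeros(shape, dtype="float32") if shape is not None else None
    for shape in cache_kvs_shape
]
print(sum(t is not None for t in cache_kvs))  # 2 allocated latent caches, 2 empty value slots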
75 changes: 57 additions & 18 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -199,6 +199,8 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4"
), f"Expected quant_type equal to 'weight_only_int8' or 'weight_only_int4', but received {self.quant_type}"

assert config.append_attn is True

self.first_k_dense_replace = config.first_k_dense_replace
self.n_routed_experts = config.n_routed_experts

@@ -435,7 +437,32 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
for idx in range(self.num_layers)
]

self.prefill_cache_k_buffer: paddle.Tensor = None
self.prefill_cache_v_buffer: paddle.Tensor = None
if self.config.mla_use_matrix_absorption:
max_batch_size = 32
max_block_nums = max_batch_size * (self.max_seq_len + config.block_size - 1) // config.block_size
cache_k_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.qk_nope_head_dim + config.qk_rope_head_dim,
]
cache_v_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.v_head_dim,
]
self.prefill_cache_k_buffer = paddle.empty(shape=cache_k_shape, dtype=paddle.get_default_dtype())
self.prefill_cache_v_buffer = paddle.empty(shape=cache_v_shape, dtype=paddle.get_default_dtype())
self.register_buffer("prefill_cache_k_buffer", self.prefill_cache_k_buffer, persistable=False)
self.register_buffer("prefill_cache_v_buffer", self.prefill_cache_v_buffer, persistable=False)

mla_config = MLAConfig(
use_matrix_absorption=self.config.mla_use_matrix_absorption,
prefill_cache_k_buffer=self.prefill_cache_k_buffer,
prefill_cache_v_buffer=self.prefill_cache_v_buffer,
q_lora_rank=self.config.q_lora_rank,
kv_lora_rank=self.config.kv_lora_rank,
qk_nope_head_dim=self.config.qk_nope_head_dim,
@@ -507,7 +534,7 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
rank_id=config.tensor_parallel_rank,
moe_config=moe_config,
mla_config=mla_config,
append_attn=config.append_attn,
append_attn=True,
speculate_config=speculate_config,
)
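use_matrix_absorption controls the core MLA trick: because k_nope = W_UK·c_kv and v = W_UV·c_kv are linear in the cached latent c_kv, W_UK can be folded into the query and W_UV applied after attending over the latents, so decode only ever reads the narrow latent cache. A small numerical check of that identity with toy sizes (a sketch, not the fused kernel; RoPE dims, scaling, and multi-head batching are omitted):

import paddle

kv_lora_rank, qk_nope_head_dim, v_head_dim, seq_len = 16, 8, 8, 5

c_kv = paddle.randn([seq_len, kv_lora_rank])            # cached latent, one row per token
W_UK = paddle.randn([kv_lora_rank, qk_nope_head_dim])   # up-projection for keys
W_UV = paddle.randn([kv_lora_rank, v_head_dim])         # up-projection for values
q = paddle.randn([1, qk_nope_head_dim])                 # one decode-step query (nope part)

# Naive path: expand the latent into full K/V, then attend.
K = c_kv @ W_UK
V = c_kv @ W_UV
p = paddle.nn.functional.softmax(q @ K.T, axis=-1)
out_naive = p @ V

# Absorbed path: fold W_UK into the query and apply W_UV after the latent attention,
# so only c_kv (kv_lora_rank wide) is ever read from the cache.
q_latent = q @ W_UK.T
p2 = paddle.nn.functional.softmax(q_latent @ c_kv.T, axis=-1)
out_absorbed = (p2 @ c_kv) @ W_UV

print(paddle.allclose(out_naive, out_absorbed, atol=1e-5))  # expect True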

@@ -977,7 +1004,7 @@ def get_cache_kvs_shape(
max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.
Returns:
list[paddle.Tensor]: the list tensor shape for cache
list[list[int]]: the list tensor shape for cache
"""
max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size
if max_batch_size == -1:
@@ -986,22 +1013,34 @@ def get_cache_kvs_shape(
max_block_nums = max_batch_size * max_block_per_seq

cache_kvs = []
for _ in range(config.num_hidden_layers):
cache_k_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.qk_nope_head_dim + config.qk_rope_head_dim,
]
cache_v_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.v_head_dim,
]
cache_kvs.append(cache_k_shape)
cache_kvs.append(cache_v_shape)
return cache_kvs
if config.mla_use_matrix_absorption:
for _ in range(config.num_hidden_layers):
cache_latent_shape = [
max_block_nums,
1,
config.block_size,
config.kv_lora_rank + config.qk_rope_head_dim,
]
cache_kvs.append(cache_latent_shape)
cache_kvs.append(None)
return cache_kvs
else:
for _ in range(config.num_hidden_layers):
cache_k_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.qk_nope_head_dim + config.qk_rope_head_dim,
]
cache_v_shape = [
max_block_nums,
config.num_key_value_heads // max(config.tensor_parallel_degree, 1),
config.block_size,
config.v_head_dim,
]
cache_kvs.append(cache_k_shape)
cache_kvs.append(cache_v_shape)
return cache_kvs

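The practical effect of the matrix-absorption branch in get_cache_kvs_shape: instead of per-layer K and V caches with num_key_value_heads heads and full head dims, each layer stores one latent cache of width kv_lora_rank + qk_rope_head_dim with a single head (the value slot becomes None). A back-of-the-envelope comparison, assuming illustrative DeepSeek-V2-style sizes rather than values read from a real config:

# Elements cached per token per layer (illustrative DeepSeek-V2-style sizes).
num_key_value_heads = 128
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
kv_lora_rank = 512

standard_kv = num_key_value_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)
latent_kv = kv_lora_rank + qk_rope_head_dim  # single "head" holding the compressed KV

print(standard_kv, latent_kv, round(standard_kv / latent_kv, 1))  # 40960 576 71.1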

def prepare_inputs_for_generation(self, **kwargs):
# only last token for inputs_ids if cache is defined in kwargs