diff --git a/CMakeLists.txt b/CMakeLists.txt index c8f19de94e59b..92b18924b80ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -294,6 +294,46 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) + +# +# _vmm extension (VMM) +# + +set(VLLM_VMM_EXT_SRC + "csrc/vmm/torch_bindings.cpp" + "csrc/vmm/vmm.cu" + ) + +define_gpu_extension_target( + _vmm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_VMM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + +# +# dattn extension (DATTN) +# + +set(VLLM_DATTN_EXT_SRC + "csrc/dattn/torch_bindings.cpp" + "csrc/dattn/dattn.cu" + ) + +define_gpu_extension_target( + _dattn_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_DATTN_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) + + # If CUTLASS is compiled on NVCC >= 12.5, it by default uses # cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the # driver API. This causes problems when linking with earlier versions of CUDA. diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index bcd170411e7cb..fe3104eec9e07 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -21,7 +21,8 @@ #include #include #include - +#include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -81,6 +82,39 @@ inline __device__ float block_sum(float* red_smem, float sum) { return VLLM_SHFL_SYNC(sum, 0); } +template +inline __device__ float propogate_qk_max(float* red_smem, float qk_max) { + // Decompose the thread index into warp / lane. + int warp_idx = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + return qk_max; +} + // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). template ( k_ptr + offset1 * BLOCK_SIZE * x + offset2); @@ -293,6 +334,13 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); + +#if 0 + if(head_idx == 0 && seq_idx == 0) { + printf("[%d, %d, %d]: scale %f token_idx-%d, physical_block_offset %d, qk-%f\n", blockIdx.x, blockIdx.y, threadIdx.x, scale, token_idx, physical_block_offset, qk); + } +#endif + // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -306,28 +354,32 @@ __device__ void paged_attention_kernel( } } } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). 
- // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = VLLM_SHFL_SYNC(qk_max, 0); + qk_max = propogate_qk_max(&red_smem[0], qk_max); + //if(threadIdx.x == 0) { + // printf("[%d, %d, %d]: qk_max %f, kv_head_stride %d num_queries_per_kv %d, q_stride-%d\n", blockIdx.x, blockIdx.y, threadIdx.x, qk_max, kv_head_stride, num_queries_per_kv, q_stride); + //} + +// // Perform reduction across the threads in the same warp to get the +// // max qk value for each "warp" (not across the thread block yet). +// // The 0-th thread of each thread group already has its max qk value. +// #pragma unroll +// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +// qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +// } +// if (lane == 0) { +// red_smem[warp_idx] = qk_max; +// } +// __syncthreads(); +// +// // TODO(woosuk): Refactor this part. +// // Get the max qk value for the sequence. +// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +// #pragma unroll +// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +// qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +// } +// // Broadcast the max qk value to all threads. +// qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -668,6 +720,414 @@ __global__ void paged_attention_v2_reduce_kernel( } } +template +__global__ void dattention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + int64_t layer_offset, // layer offset in the units + int64_t whole_block_size, // whole block size (bytes), including KV of all layers together + int64_t max_seq_len, + const int64_t* cache_row_mapping, // [num_tokens] record cache ptr for this token + const int64_t* cache_col_mapping, // [num_tokens] record token index of the sequence + const int* __restrict__ seq_lens, // [num_seqs] + const int64_t q_stride, + const int64_t num_kv_heads, // [num_heads] + const float scale, + const float* __restrict__ alibi_slopes, // [num_heads] + const float kv_scale +) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? 
partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int KV_HEAD_STRIDE = HEAD_SIZE * BLOCK_SIZE; + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. 
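+  // (e.g., x == 8 when cache_t is half: 16 bytes / 2 bytes per element.)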
+ constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + +#if 0 + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } +#endif + + // NOTE: cache_row_idx or cache_col_idx can be -1 if the token is padded + cache_t * cache_start = reinterpret_cast(cache_row_mapping[seq_idx]); + + // Iterate over the key blocks. + // Each thread block will process one request's one head and one partition (up to 512 tokens) + // Each warp will process a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes dot product with the query. + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + #if 0 + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + #endif + // computing the starting address of the block for the given layer + cache_t * key_cache = cache_start + block_idx*whole_block_size + layer_offset; + + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + // Load a key to registers. Inside a block, each thread group will fetch lane/THREAD_GROUP_SIZe + // Each thread in a thread group has a different part of the key. 
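+      // Keys inside a block are laid out as [num_kv_heads, head_size/x, block_size, x], so each thread gathers NUM_VECS_PER_THREAD strided K_vec chunks for its token.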
+ const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; // token index inside the block + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + + cache_t* k_ptr = key_cache + kv_head_idx * KV_HEAD_STRIDE + physical_block_offset * x; + + #pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; +#if 0 + if(head_idx == 0 && seq_idx == 0) { + const cache_t * test_ptr = k_ptr + offset1 * BLOCK_SIZE * x + offset2; + printf("[%d, %d, %d]: token_idx-%d, physical_block_offset %d, j-%d, offset1-%d, offset2-%d, addr-%p\n", blockIdx.x, blockIdx.y, threadIdx.x, token_idx, physical_block_offset, j, offset1, offset2,test_ptr); + } +#endif + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); +#if 0 + if(head_idx == 0 && seq_idx == 0) { + printf("[%d, %d, %d]: token_idx-%d, physical_block_offset %d, qk-%f\n", blockIdx.x, blockIdx.y, threadIdx.x, token_idx, physical_block_offset, qk); + } +#endif + + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across all threads in the same thread block + qk_max = propogate_qk_max(&red_smem[0], qk_max); + //if(threadIdx.x == 0) { + // printf("[%d, %d, %d]: scale %f qk_max %f. layer_offset %ld, kv_head_stride %d - %d. q_stride %ld\n", blockIdx.x, blockIdx.y, threadIdx.x, scale, qk_max, layer_offset, KV_HEAD_STRIDE, kv_head_stride, q_stride); + //} + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. 
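+  // (e.g., V_VEC_SIZE == 8 for half, so a warp covers NUM_ROWS_PER_ITER value rows of a block per iteration.)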
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { +#if 0 + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } +#endif + + // Load a key to registers. Inside a block, each thread group will fetch lane/THREAD_GROUP_SIZe + // Each thread in a thread group has a different part of the key. + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + // computing the starting address of the block + cache_t* v_ptr = cache_start + block_idx*whole_block_size + whole_block_size/2 + layer_offset + kv_head_idx * KV_HEAD_STRIDE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. 
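+    // Tree reduction over warps: in each round the upper half of the remaining warps spills its partial sums to shared memory and the lower half accumulates them, until warp 0 holds the final rows.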
+ if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + //if(seq_idx < 1) + // printf("[%d, %d, %d], qk_max %f: row_idx %d accs[%d]:%f\n", blockIdx.x, blockIdx.y, threadIdx.x, qk_max, row_idx, i, accs[i]); + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + //printf("[%d, %d, %d], seq_idx-%d, num_heads-%d, max_num_partitions-%d, HEAD_SIZE-%d, head_idx-%d, partition_idx-%d, NUM_ROWS_PER_THREAD-%d, NUM_V_VECS_PER_ROW-%d\n", blockIdx.x, blockIdx.y, threadIdx.x,seq_idx, num_heads, max_num_partitions, HEAD_SIZE, head_idx, partition_idx, NUM_ROWS_PER_THREAD, NUM_V_VECS_PER_ROW); + + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + //printf("[%d, %d, %d]: seq_idx-%d, row_idx-%d, accs[%d]:%f\n",blockIdx.x, blockIdx.y, threadIdx.x,seq_idx,row_idx,i,accs[i]); + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + } // namespace vllm #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ @@ -996,6 +1456,191 @@ void paged_attention_v2( CALL_V2_LAUNCHER_BLOCK_SIZE) } +#define LAUNCH_DATTENTION(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::dattention_kernel), \ + shared_mem_size); \ + if(max_num_partitions > 1) { \ + vllm::dattention_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr,\ + query_ptr, \ + layer_offset, \ + whole_block_size, max_seq_len, \ + row_ptr, \ + col_ptr, \ + seq_lens_ptr, \ + q_stride, num_kv_heads, scale, \ + alibi_slopes_ptr, kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); \ + } \ + else \ + vllm::dattention_kernel \ + <<>>( \ + nullptr, nullptr, out_ptr,\ + query_ptr, \ + layer_offset, \ + whole_block_size, max_seq_len, \ + row_ptr, \ + col_ptr, \ + seq_lens_ptr, \ + q_stride, num_kv_heads, scale, \ + alibi_slopes_ptr, kv_scale); + +template +void dattention_launcher( + torch::Tensor& output, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, // [num_seqs, num_heads, head_size] + int64_t layer_idx, + int64_t num_layers, + int64_t max_seq_len, + torch::Tensor & seq_lens, + torch::Tensor & cache_row_mapping, + torch::Tensor & cache_col_mapping, + int64_t num_kv_heads, + double scale, + const c10::optional& alibi_slopes, + double kv_scale +) { + int64_t num_seqs = query.size(0); + int64_t num_heads = query.size(1); + int64_t head_size = query.size(2); + + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + + int64_t key_block_size = 
(num_heads * head_size) * BLOCK_SIZE; + int64_t layer_block_size = key_block_size * 2; + int64_t whole_block_size = layer_block_size * num_layers; + int64_t layer_offset = layer_idx * key_block_size; + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + scalar_t* out_ptr = reinterpret_cast(output.data_ptr()); + float* exp_sums_ptr = nullptr; + float* max_logits_ptr = nullptr; + scalar_t* tmp_out_ptr = nullptr; + if(max_num_partitions > 1) { + exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + } + scalar_t* query_ptr = reinterpret_cast(query.data_ptr()); + int* seq_lens_ptr = seq_lens.data_ptr(); + int64_t * row_ptr = reinterpret_cast(cache_row_mapping.data_ptr()); + int64_t * col_ptr = reinterpret_cast(cache_col_mapping.data_ptr()); + + int64_t q_stride = query.stride(0); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, max_num_partitions); + + // each thread block will be 128 threads + dim3 block(NUM_THREADS); + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. 
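+    // Each case expands LAUNCH_DATTENTION for a compile-time HEAD_SIZE; the macro launches the single-pass kernel when max_num_partitions == 1, and otherwise the partitioned kernel followed by paged_attention_v2_reduce_kernel.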
+ case 64: + LAUNCH_DATTENTION(64); + break; + case 80: + LAUNCH_DATTENTION(80); + break; + case 96: + LAUNCH_DATTENTION(96); + break; + case 112: + LAUNCH_DATTENTION(112); + break; + case 128: + LAUNCH_DATTENTION(128); + break; + case 192: + LAUNCH_DATTENTION(192); + break; + case 256: + LAUNCH_DATTENTION(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +void dattention( + torch::Tensor& output, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + int64_t layer_idx, + int64_t num_layers, + int64_t block_size, + int64_t max_seq_len, + torch::Tensor & seq_lens, + torch::Tensor & cache_row_mapping, + torch::Tensor & cache_col_mapping, + const std::string& kv_cache_dtype, + int64_t num_kv_heads, + double scale, + const c10::optional& alibi_slopes, + double kv_scale + ) { + assert(block_size == 16 || block_size == 32); + + if (kv_cache_dtype == "auto" && block_size == 16) { + if (query.dtype() == at::ScalarType::Float) { + dattention_launcher( + output, exp_sums, max_logits, tmp_out, query, layer_idx, num_layers, max_seq_len, + seq_lens, cache_row_mapping, cache_col_mapping, + num_kv_heads, scale, alibi_slopes, kv_scale); + } else if (query.dtype() == at::ScalarType::Half) { + dattention_launcher( + output, exp_sums, max_logits, tmp_out, query, layer_idx, num_layers, max_seq_len, + seq_lens, cache_row_mapping, cache_col_mapping, + num_kv_heads, scale, alibi_slopes, kv_scale); + } + } + else if (kv_cache_dtype == "auto" && block_size == 32) { + if (query.dtype() == at::ScalarType::Float) { + dattention_launcher( + output, exp_sums, max_logits, tmp_out, query, layer_idx, num_layers, max_seq_len, + seq_lens, cache_row_mapping, cache_col_mapping, + num_kv_heads, kv_scale, alibi_slopes, scale); + } else if (query.dtype() == at::ScalarType::Half) { + dattention_launcher( + output, exp_sums, max_logits, tmp_out, query, layer_idx, num_layers, max_seq_len, + seq_lens, cache_row_mapping, cache_col_mapping, + num_kv_heads, kv_scale, alibi_slopes, scale); + } + } + else { + printf("errors for dattention_launcher: dtype: %s, block_size %ld!!\n", kv_cache_dtype.c_str(), block_size); + exit(0); + } +} + #undef WARP_SIZE #undef MAX #undef MIN diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..cb862c9074891 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -31,3 +31,24 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); + +// new add for vmm +void reshape_and_cache_vmm( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + torch::Tensor& value_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + torch::Tensor& cache_row_mapping, // [num_tokens] record key/value write to which batch row in cache + torch::Tensor& cache_col_mapping, // [num_tokens] record key/value write to which token col in cache + const std::string& kv_cache_dtype); + +// new add for dAttention +void reshape_and_cache_dattn( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + 
torch::Tensor& value, // [num_tokens, num_heads, head_size] + int64_t layer_idx, // which layer to reshape + int64_t num_layers, // number of layers + int64_t block_size, // size for each layer's cache block (including kv cache) + torch::Tensor& cache_row_mapping, // [num_tokens] record key/value write to which batch row in cache + torch::Tensor& cache_col_mapping, // [num_tokens] record key/value write to which token col in cache + const std::string& kv_cache_dtype); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..ca43562790015 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -245,6 +245,125 @@ __global__ void reshape_and_cache_flash_kernel( } } } + +template +__global__ void reshape_and_cache_vmm_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + const int64_t* __restrict__ cache_row_mapping, // [num_tokens] record key/value write to which batch row in cache + const int64_t* __restrict__ cache_col_mapping, // [num_tokens] record key/value write to which token col in cache + const int cache_batch_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size) { + const int64_t token_idx = blockIdx.x; + // // NOTE: cache_row_idx or cache_col_idx can be -1 if the token is padded + const int64_t cache_row_idx = cache_row_mapping[token_idx]; + const int64_t cache_col_idx = cache_col_mapping[token_idx]; + if (cache_row_idx < 0 || cache_col_idx < 0) { + return; + } + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int64_t tgt_idx = cache_row_idx * cache_batch_stride + cache_col_idx * n + i; + // const int head_idx = i / head_size; + // const int head_offset = i % head_size; + + // const int64_t tgt_idx = cache_row_idx * cache_batch_stride + cache_col_idx * num_heads * head_size + head_idx * head_size + head_offset; + + k_cache[tgt_idx] = key[src_key_idx]; + v_cache[tgt_idx] = value[src_value_idx]; + } +} + + +template +// TODO: make block_size, kv_block_size, head_size, key_stride to be constant number (constexpr) +// Then we can save some computation overhead during execution +__global__ void reshape_and_cache_dattn_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + int heads_per_thread_block, // number of heads for each thread block + int64_t block_size, // number of tokens inside a block + int64_t kv_block_size, // key or value block size in number of bytes for each layer + int64_t layer_offset, // layer offset in the units + int64_t whole_block_size, // whole block size (bytes), including KV of all layers together + const int64_t* cache_row_mapping, // [num_tokens] record cache ptr for this token + const int64_t* cache_col_mapping, // [num_tokens] record token index of the sequence + const int key_stride, const int value_stride, + const int head_size) { + // The index of the token + const int64_t index = blockIdx.x; + // The total number of heads for the model + const int64_t num_heads = gridDim.y * 
heads_per_thread_block; + constexpr int x = 16 / sizeof(cache_t); + + // NOTE: cache_row_idx or cache_col_idx can be -1 if the token is padded + const int64_t cache_address = cache_row_mapping[index]; + // token index in the sequence, which determins the position of kv cache + const int64_t token_idx = cache_col_mapping[index]; + if (cache_address <= 0 || token_idx < 0) { + return; + } + + // Note: each thread block is in charge of 4 heads of the same token + // Therefore, each warp is in charge of 1 head of the same token + int64_t block_idx = token_idx / block_size; + + // token index inside the current block: [0, 16) + int64_t block_offset = token_idx % block_size; + + // compute the block index for the current thread block + const int64_t head_block_idx = blockIdx.y; + + // Each head will be handled will be handled by each warp + int64_t warp_idx = threadIdx.x / WARP_SIZE; //[0,3] + assert (warp_idx <= 3); + + // head_idx for the token, should be less than num_heads. + int64_t head_idx = head_block_idx * heads_per_thread_block + warp_idx; + + // kv_block_size == head_size * block_size + // Compute the start address of the head of the block for KV cache + int64_t head_start = block_idx * whole_block_size + layer_offset + head_idx * kv_block_size ; + + int64_t thread_idx_in_warp = threadIdx.x % WARP_SIZE; + + cache_t* dest_key = reinterpret_cast(cache_address) + head_start; + + // whole_block_size: 2 * (num_heads * head_size * block_size * layers) + cache_t* dest_value = dest_key + whole_block_size/2; + + // Each thread block will copy one token's 4 heads, while each warp will copy one token's one head only + // since key: [num_tokens, num_heads, head_size] + //int64_t src_offset = index * num_heads * head_size + head_idx * head_size; + int64_t src_offset = index * key_stride + head_idx * head_size; + scalar_t* src_key = const_cast(key + src_offset); + scalar_t* src_value = const_cast(value + src_offset); + + // Each warp will handle only one token's one head + for (int i = thread_idx_in_warp; i < head_size; i += WARP_SIZE) { + // i == head_offset + // We are going to transfer [0,head_size) to [head_size/x, block_size, x] + int x_idx = i / x; + int x_offset = i % x; + + // [num_blocks, num_heads, head_size/x, block_size, x] + int64_t tgt_key_idx = x_idx * block_size * x + block_offset * x + x_offset; + dest_key[tgt_key_idx] = src_key[i]; + + // [num_blocks, num_heads, head_size, block_size] + int64_t tgt_value_idx = i * block_size + block_offset; + //if(head_idx == 0 && i < 8) + // printf("[%d, %d, %d]: index %ld offset [%d, %ld, %d] at %p\n", blockIdx.x, blockIdx.y, threadIdx.x, tgt_value_idx, x_idx, block_offset, x_offset, &dest_value[i]); + dest_value[tgt_value_idx] = src_value[i]; + } +} + } // namespace vllm // KV_T is the stored data type of kv-cache. 
@@ -329,6 +448,116 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +void reshape_and_cache_vmm( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + torch::Tensor& v_cache, // [max_batch_size, max_seq_len, num_heads, head_size] + torch::Tensor& cache_row_mapping, // [num_tokens] record key/value write to which batch row in cache + torch::Tensor& cache_col_mapping, // [num_tokens] record key/value write to which token col in cache + const std::string& kv_cache_dtype) { + + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int cache_batch_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_vmm", [&] { + vllm::reshape_and_cache_vmm_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + cache_row_mapping.data_ptr(),cache_col_mapping.data_ptr(), + cache_batch_stride, key_stride, value_stride, num_heads, head_size); + }); + +} + +#define CALL_RESHAPE_AND_CACHE_DATTN(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_dattn_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + heads_per_thread_block, block_size, \ + kv_block_size, layer_offset, \ + whole_block_size, \ + cache_row_mapping.data_ptr(), \ + cache_col_mapping.data_ptr(), \ + key_stride, value_stride, \ + head_size); + +void reshape_and_cache_dattn( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + int64_t layer_idx, // which layer to reshape + int64_t num_layers, // number of layers + int64_t block_size, // the number of tokens inside a block + torch::Tensor& cache_row_mapping, // [num_tokens] record key/value write to which batch row in cache + torch::Tensor& cache_col_mapping, // [num_tokens] record key/value write to which token col in cache + const std::string& kv_cache_dtype) { + + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + const int head_size = key.size(2); + + const int key_stride = key.stride(0); + const int value_stride = value.stride(0); + + //printf("hihihi, num_tokens %d, head_size %d, key_stride %d, value_stride %d\n", num_tokens, head_size, key_stride, value_stride); + // We will dynamically decide heads_per_thread_block + int heads_per_thread_block = 4; + assert(num_heads % heads_per_thread_block == 0); + + int sm_for_heads = num_heads/heads_per_thread_block; + + int64_t kv_block_size = head_size * block_size; + int64_t whole_block_size = kv_block_size * num_heads * num_layers * 2; + int64_t layer_offset = layer_idx * kv_block_size * num_heads; + + //printf("key_block_size-%d, whole_block_size-%ld, num_layers-%d\n", key_block_size, whole_block_size, num_layers); + dim3 grid(num_tokens, sm_for_heads); + + // each thread block 
will be 128 threads + dim3 block(128); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_DATTN) + /* + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_dattn", [&] { + vllm::reshape_and_cache_dattn_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + heads_per_thread_block, block_size, + kv_block_size, layer_offset, + whole_block_size, + cache_row_mapping.data_ptr(), + cache_col_mapping.data_ptr(), + key_stride, value_stride, + head_size); + }); + */ +} + namespace vllm { template diff --git a/csrc/dattn/dattn.cu b/csrc/dattn/dattn.cu new file mode 100644 index 0000000000000..a7de097d372ca --- /dev/null +++ b/csrc/dattn/dattn.cu @@ -0,0 +1,470 @@ +/* + Copyright (c) ByteDance Inc. + Authors: + - Tongping Liu (tongping.liu@bytedance.com) + - https://github.com/vllm-project/vllm/pull/6102/commits + */ + +#include +#include +#include +#include +#include +#include "dattn.h" + + +#define KV_UTILIZATION_RATE (0.9) + +static CUmemAllocationProp _prop = {}; +static CUmemAccessDesc _accessDescr = {}; +/* + In this allocator, we only have the following concepts, but without the concept of tokens. + The python portion should convert the number of tokens to tokens depending on their block_size (e.g., 16) + Region: virtual address space for a request. Currently, we support the space for max_seq_len. + */ +static uint64_t roundup(uint64_t size, uint64_t align_size) { + return ((size + align_size - 1)/align_size) * align_size; +} + +static int allocatePhyPages(void * ptr, uint64_t size) { + CUdeviceptr dptr = (CUdeviceptr)ptr; + + CUdevice dev; // device + CHECK_DRV(cuCtxGetDevice(&dev)); + _prop.location.id = dev; + _accessDescr.location = _prop.location; + + CUresult status = CUDA_SUCCESS; + CUmemGenericAllocationHandle allocationHandle; + if ((status = cuMemCreate(&allocationHandle, size, &_prop, 0)) == CUDA_SUCCESS) { + if ((status = cuMemMap(dptr, size, 0ULL, allocationHandle, 0ULL)) == CUDA_SUCCESS) { + if ((status = cuMemSetAccess(dptr, size, &_accessDescr, 1)) != CUDA_SUCCESS) { + fprintf(stderr, "cuMemMap success,but cuMemSetAccess failed!, err code: %d\n", status); + cuMemUnmap(dptr, size); + } + } + // always release the handle, but the memory is accessible util cuMemUnmap + if((status = cuMemRelease(allocationHandle)) != CUDA_SUCCESS) { + fprintf(stderr, "cuMemRelease failed, err code: %d\n", status); + } + } else { + fprintf(stderr, "cuMemCreate failed!, err code: %d\n", status); + } + return status == CUDA_SUCCESS ? 
0 : -1; +} + +// Free the physical memory [ptr, ptr + size] +static void freePhysicalMemory(void* ptr, size_t size) { + CUdeviceptr dptr = (CUdeviceptr)ptr; + CHECK_DRV(cuMemUnmap(dptr, size)); + CHECK_DRV(cuMemAddressFree(dptr, size)); +} + +/* +** kvCacheRegion functions implementation +*/ +kvCacheRegion::kvCacheRegion(uint64_t region_size, uint64_t block_size, uint64_t page_size, CUdeviceptr ptr) { + this->region_size = region_size; + this->block_size = block_size; + this->page_size = page_size; + this->dptr = reinterpret_cast(ptr); + this->nextUnmapedAddr = reinterpret_cast(ptr); + + this->offset = 0; + this->total_pages = 0; + this->used_pages = 0; +} + +// Decontructor: release all physical pages of this region +kvCacheRegion::~kvCacheRegion() { + uint64_t size = this->total_pages * this->page_size; + freePhysicalMemory(this->dptr, size); + + // Note that since the region is detroyed, + // no need to clear other counters. +} + +/* +// get CUdeviceptr dptr +CUdeviceptr kvCacheRegion::getDeviceDptr(void) { + return reinterpret_cast(this->dptr); +} + +// get void * type pointer +void* kvCacheRegion::getVoidDptr(void) { + return reinterpret_cast(this->dptr); +} +*/ + +uint64_t kvCacheRegion::getAllocPhyPages(void) { + return this->total_pages; +} + +uint64_t kvCacheRegion::getUsedPhysicalPages(void) { + return this->used_pages; +} + +/* + kvCacheRegion function: allocate cached blocks + if the return value > 0, then it is succesful. + */ +int64_t kvCacheRegion::allocCacheBlocks(uint64_t blocks, uint64_t * used_pages) { + uint64_t size = blocks * this->block_size; + + int64_t toallocPages = -1; + + // Align the new offset to page_size + uint64_t alignedOffset = roundup(this->offset + size, this->page_size); + + // Check how many pages should we allocated this time + char * alignedAddr = this->dptr + alignedOffset; + if( alignedAddr > this->nextUnmapedAddr) { + + // Check whether alignedAddr is actually aligned well + assert((alignedAddr - this->nextUnmapedAddr)%this->page_size == 0); + + toallocPages = (alignedAddr - this->nextUnmapedAddr)/this->page_size; + + assert(toallocPages >= 0); + + uint64_t allocSize = toallocPages * this->page_size; + + // Allocate physical pages, which will exit if can't allocate successfully + if (toallocPages > 0 && allocatePhyPages(this->nextUnmapedAddr, allocSize) == 0) { + this->nextUnmapedAddr = alignedAddr; + + // Update the used pages correspondingly. The statement works even when this->offset is not aligned to page_size + *used_pages += toallocPages; + + // Update the offset after allocating these blocks. + this->offset += size; + assert(this->offset <= alignedOffset); + } + } + + return toallocPages; +} + +// freeUnusedPages from a region, and return freed pages +int kvCacheRegion::freeUnusedPages(void) { + int freedPages = 0; + + // Free pages only when total_pages is larger than used_pages + if(this->total_pages > this->used_pages) { + assert(this->nextUnmapedAddr > (this->dptr + offset)); + + // Get the offset of next page, since we can't collect a page if its partialy used + uint64_t alignedOffset = roundup(offset, this->page_size); + + // startAddr points to the beginning of the next page + char * startAddr = this->dptr + alignedOffset; + + uint64_t size = this->nextUnmapedAddr - startAddr; + assert((size % this->page_size) == 0); + + freedPages = size/this->page_size; + // free all unused pages of this region. 
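+    // (the mapped pages between the page-aligned current offset and nextUnmapedAddr)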
+ // If a page is partially used, then it cannot be freed + if(size > 0) { + freePhysicalMemory(startAddr, size); + this->total_pages -= freedPages; + this->nextUnmapedAddr = startAddr; + // No need to change offset here. + } + } + + return freedPages; +} + +/* +** kvCacheAllocator functions implementation +*/ +kvCacheAllocator::kvCacheAllocator(int64_t max_seq_length, int64_t layers_num, int64_t heads_num, int64_t head_size, int64_t tokens_per_block, int64_t dtype_size) { + uint64_t key_cache_block_per_layer = tokens_per_block * heads_num * head_size * dtype_size; + uint64_t value_cache_block_per_layer = key_cache_block_per_layer; + uint64_t cache_block_size = (key_cache_block_per_layer + value_cache_block_per_layer) * layers_num; + + //fprintf(stderr, "kvCacheAllocator initialization: key_cache_block_per_layer-%d, cache_block_size-%d\n", key_cache_block_per_layer, cache_block_size); + // Getting the cuda device and force the initialization + CUdevice dev; // device + CHECK_RT(cudaFree(0)); // Force and check the initialization of the runtime + CHECK_DRV(cuCtxGetDevice(&dev)); + + size_t aligned_sz; + _prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + _prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + _prop.location.id = dev; + _accessDescr.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + _accessDescr.location = _prop.location; + + CHECK_DRV(cuMemGetAllocationGranularity(&aligned_sz, &_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + uint64_t max_blocks = roundup(max_seq_length, tokens_per_block)/tokens_per_block; + uint64_t region_size = max_blocks * cache_block_size; + + this->page_size = aligned_sz; + this->region_size = ((region_size + aligned_sz - 1) / aligned_sz) * aligned_sz; + this->block_size = cache_block_size; + + printf("kvCacheAllocator: page_size-%d, region_size-%d, block_size-%d\n", this->page_size, this->region_size, this->block_size); + + // TODO: finding out how much physical blocks it includes. This is just for the reference or watermark, as + // there is no need to rely on pre-assigned values if physical blocks are allocated on-demand + size_t freeMem, totalMem; + CHECK_RT(cudaMemGetInfo(&freeMem, &totalMem)); + + this->watermark_pages = (((uint64_t)(freeMem * KV_UTILIZATION_RATE))/this->page_size); + + // Doing other initialization + this->total_pages = 0; + this->used_pages = 0; + this->active_regions = 0; +} + +int64_t kvCacheAllocator::getPageSize() { + return this->page_size; +} + + +// reserve function, reserve virtual address space for a request, and also allocate the first physical block +int64_t kvCacheAllocator::reserveRegion(int64_t req_id) { + CUdeviceptr ptr; + kvCacheRegion * region = nullptr; + + // Check whether there are some cached regions + if(this->cached_regions.size()) { + // Pop the latest region from cached vector, which is more efficient and therefore it is the default method + region = _getLastCachedRegion(); + } + else { + //printf("region_size == %d bytes == %d MB\n", this->region_size, this->region_size%2097152); + // The expensive way to get a new region. 
Only invoked when no cached regions + // Allocate the virtual address for this region + CHECK_DRV(cuMemAddressReserve(&ptr, this->region_size, 0ULL, 0ULL, 0ULL)); + + // Create a new region from the scratch + region = new kvCacheRegion(this->region_size, this->block_size, this->page_size, ptr); + } + + std::lock_guard lock(this->mutex); + + // Record the region information + this->active_regions += 1; + //fprintf(stderr, "Reserve region for req_id - %d, with total %d\n", req_id, this->active_regions); + this->active_regions_map[req_id] = region; + + return static_cast(ptr); +} + +// Release the region with the given req_id +void kvCacheAllocator::_releaseRegion(int64_t req_id) { + // Find the region corresponding to the given req_id + if(this->active_regions_map.count(req_id) != 0) { + fprintf(stderr, "ERROR: req_id-%d does not exist at all.!\n", req_id); + exit(-1); + } + + std::lock_guard lock(this->mutex); + + kvCacheRegion * region = this->active_regions_map[req_id]; + // Delete this region from active_regions_map that only keep + this->active_regions_map.erase(req_id); + this->active_regions--; + // Note that as we don't actually release physical cache blocks. + // Therefore, we don't need to change the active_blocks here. + + // Cache the given region, as it can be used for the future ideally. + // In order to reduce the overhead of memory management, we did not + // reclaim physical blocks until necessary. + _cacheReleasedRegion(region); +} + +// Cache the released region. Don't release the virtual address and physical cache blocks +void kvCacheAllocator::_cacheReleasedRegion(kvCacheRegion * region) { + this->cached_regions.push_back(region); +} + +// Get the lastly-released region. If the region has some physical blocks, +// they will be re-utilized as well. +// Note that using cached regions is way more efficient than allocating a new region +kvCacheRegion * kvCacheAllocator::_getLastCachedRegion(void) { + assert(!this->cached_regions.empty()); + + kvCacheRegion * region = this->cached_regions.back(); + this->cached_regions.pop_back(); + + return region; +} + +// This function is invoked when the number of physical pages is above +// the preset threshold. It performs the garbage collecton of physical pages +void kvCacheAllocator::_gcPhyPages(int64_t toCollectPages) { + + assert(toCollectPages > 0); + + // first, collect the pages in cached regions. + kvCacheRegion * region; + + // First, collect pages from cached_regions as it won't affect active requests. 
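+  // Cached regions belong to already-released requests, so unmapping their pages cannot disturb any live KV cache.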
+ while(!this->cached_regions.empty() && toCollectPages > 0) { + // Release Least-Recently-Used regions at first + region = this->cached_regions.front(); + this->cached_regions.pop_front(); + + int pages = region->getAllocPhyPages(); + if(pages > 0) { + this->total_pages -= pages; + toCollectPages -= pages; + } + + // deconstruct this region, which will collect all physical pages inside + delete region; + } + + // Check active regions if necessary + while(toCollectPages > 0) { + // Collect pages from active regions + for(auto it = this->active_regions_map.begin(); it != this->active_regions_map.end(); it++) { + // it->second points to the region + region = it->second; + + int pages = region->freeUnusedPages(); + if(pages > 0) { + // Update the total_pages for the allocator + this->total_pages -= pages; + + toCollectPages -= pages; + } + + // Exit the loop if we collect enough pages + if(toCollectPages <= 0) { + break; + } + } + } + +} + +// alloc function, allocate physical memory, map to the reserved virtual address +// This function is designed for both prefill and decoding phase, where prefill may +// require to save KV cache of multiple tokens, which should not invoke this function multiple times. +// Similarly, the python code may get the physical blocks for multiple tokens during the decoding phase +// Note that the allocator doesn't care about tokens (which should be handled by the python code), but only blocks here. +int64_t kvCacheAllocator::_allocCacheBlocksForRequest(int64_t req_id, int64_t blocks) { + int64_t pages = -1; + + // Find the region corresponding to the given req_id, which should reserveRegion before + // If the req_id doesn't exist at all, it is the bug that should be fixed. + if(this->active_regions_map.count(req_id) == 0) { + fprintf(stderr, "ERROR: req_id %d does not exist at all.!\n", req_id); + exit(-1); + } + + std::lock_guard lock(this->mutex); + + kvCacheRegion * region = this->active_regions_map[req_id]; + + pages = region->allocCacheBlocks(blocks, &this->used_pages); + + if(pages > 0) { + this->total_pages += pages; + + // check whether we need to purge physical memory + if(this->total_pages >= this->watermark_pages && this->total_pages > this->used_pages) { + int toCollectPages = std::min(this->total_pages - this->used_pages, this->total_pages - this->watermark_pages); + + // Garbage collection for physical pages. + _gcPhyPages(toCollectPages); + } + } + + return pages; +} + +// Allocate cache blocks for a range of requests. Each request information will be an vector, with +// the request id as the first, and then number of blocks as the second. +int64_t kvCacheAllocator::allocCacheBlocks(std::vector> req_blocks) { + int64_t pages = 0; + + for(auto row : req_blocks) { + uint64_t req_id = row[0]; + uint64_t blocks = row[1]; + + //fprintf(stderr, "allocating req_id-%d and blocks-%d\n", req_id, blocks); + pages += _allocCacheBlocksForRequest(req_id, blocks); + } + + return pages; +} +// Release regions specified in the vector +void kvCacheAllocator::releaseRegions(std::vector regions) { + for(auto region : regions) { + _releaseRegion(region); + } +} + + +int64_t kvCacheAllocator::getAllocPhyPages(int64_t req_id) { + int64_t pages = 0; + + if(req_id == 0) { + pages = this->total_pages; + } + else { + // Find the region corresponding to the given req_id, which should reserveRegion before + // If the req_id doesn't exist at all, it is the bug that should be fixed. 
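+    // (req_id == 0 was handled above as a query for the allocator-wide page count.)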
+ if(this->active_regions_map.count(req_id) != 0) { + fprintf(stderr, "ERROR: req_id does not exist at all.!"); + exit(-1); + } + + std::lock_guard lock(this->mutex); + + kvCacheRegion * region = this->active_regions_map[req_id]; + pages = region->getAllocPhyPages(); + } + + return pages; +} + +void kvCacheAllocator::collectPhyPages(int64_t pages) { + if(pages == 0) { + // Collect pages defined by watermark + pages = std::min(this->total_pages - this->used_pages, this->total_pages - this->watermark_pages); + } + + _gcPhyPages(pages); + return; +} + +#if 0 +// TODO: we need to delete this function!!! +// free function, unmap the virtual address spaceļ¼Œrelease physical memory +// handles and free virtual address space +int64_t kvCacheAllocator::freeCacheBlock(const c10::intrusive_ptr& ptr) { + CUresult status = CUDA_SUCCESS; + if (ptr->dptr != 0) { + status = cuMemUnmap(ptr->dptr, ptr->reservedPageNum * pageSize); + // status = cuMemUnmap(ptr.dptr, ptr.allocatedPageNum * pageSize); + if (status != CUDA_SUCCESS) { + printf("cuMemUnmap failed! error-code: %d\n", status); + } else { + for (int i = 0; i < ptr->handles.size(); i++) { + status = cuMemRelease(ptr->handles[i]); + if (status != CUDA_SUCCESS) { + printf("cuMemRelease failed! error-code: %d\n", status); + return status; + } + } + ptr->handles.clear(); + + status = cuMemAddressFree(ptr->dptr, ptr->reservedPageNum * pageSize); + if (status != CUDA_SUCCESS) { + printf("cuMemAddressFree failed! error-code: %d\n", status); + } + } + } + return status; +} + +#endif \ No newline at end of file diff --git a/csrc/dattn/dattn.h b/csrc/dattn/dattn.h new file mode 100644 index 0000000000000..89770ff3056e3 --- /dev/null +++ b/csrc/dattn/dattn.h @@ -0,0 +1,167 @@ +#pragma once +/* + Copyright (c) ByteDance Inc. + Authors: + - Tongping Liu (tongping.liu@bytedance.com) + - https://github.com/vllm-project/vllm/pull/6102/commits + */ +//#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define _MB (1 << 20) + +using namespace std; + +static inline void +checkRtError(cudaError_t res, const char *tok, const char *file, unsigned line) { + if (res != cudaSuccess) { + std::cerr << file << ':' << line << ' ' << tok + << "failed (" << (unsigned)res << "): " << cudaGetErrorString(res) << std::endl; + abort(); + } +} + +#define CHECK_RT(x) checkRtError(x, #x, __FILE__, __LINE__); + +static inline void +checkDrvError(CUresult res, const char *tok, const char *file, unsigned line) { + if (res != CUDA_SUCCESS) { + const char *errStr = NULL; + (void)cuGetErrorString(res, &errStr); + std::cerr << file << ':' << line << ' ' << tok + << "failed (" << (unsigned)res << "): " << errStr << std::endl; + abort(); + } +} + +#define CHECK_DRV(x) checkDrvError(x, #x, __FILE__, __LINE__); + + +// kvCacheRegion class, to warp CUdeviceptr, used to kv-cache tensor +// record the reserved virtual address size and allocated physical memory size. +// TODO: we may avoid expose this class externally in the future. +class kvCacheRegion : public torch::CustomClassHolder{ +private: + char * dptr; + + // the number of bytes for the request's virtual address space (region) + uint64_t region_size; + + // the size of a kv cache block in bytes, which is NOT the number tokens inside a block + uint64_t block_size; + + // The real page_size supported by the hardware, which could be larger or smaller than block_size + uint64_t page_size; + + // The number of allocated physical pages for the current region. 
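+  // (i.e., pages that have been cuMemMap'ed into this region's reserved address range.)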
+ uint64_t total_pages; + + // Note that total_pages can be larger than use_pages, as it may inherit a + // live region that already has many allocated physical pages. + uint64_t used_pages; + + // virtual address of the next page that needs to be mapped. + // Typically, (nextUnmapedAddr - dptr)/page_size == total_pagees + char * nextUnmapedAddr; + + // The difference between the used address (the end of invoking allocBlocks) and the starting pointer of the region + uint64_t offset; + +public: + + kvCacheRegion(uint64_t region_size, uint64_t block_size, uint64_t page_size, CUdeviceptr ptr); + + ~kvCacheRegion(); + + // get CUdeviceptr dptr + CUdeviceptr getStartDptr(); + + // get the number of physical pages + uint64_t getAllocPhyPages(void); + uint64_t getUsedPhysicalPages(void); + int64_t allocCacheBlocks(uint64_t blocks, uint64_t * used_pages); + int freeUnusedPages(void); +}; + + +// kvCacheAllocator class, used for memory allocation of kv-cachemanager, memory allocation is based on page granularity, +class kvCacheAllocator : public torch::CustomClassHolder{ +private: + + /* + The following information are about physical blocks. + + total_pages is the total number of physical pages that have been assigned from the allocator. + used_pages can be less than total_pages, as used_pages will be incremented only when allocCacheBlock is invoked. + */ + uint64_t total_pages; + uint64_t used_pages; + + // How many regions (requests) in this allocator + uint64_t active_regions; + + // If total_pages is larger than the watermark, then we will start to garbage collect physical pages + // More specifically, we will reclaim pages from cached regions at first, and then from active regions + uint64_t watermark_pages; + + uint64_t region_size; + uint64_t block_size; + uint64_t page_size; + CUdevice device; + std::mutex mutex; + + // the hashtable to record the relationship between regions and ptrs + unordered_map active_regions_map; + std::deque cached_regions; + + // Internal functions + void _cacheReleasedRegion(kvCacheRegion * region); + kvCacheRegion * _getLastCachedRegion(void); + void _gcPhyPages(int64_t toCollectPages); + void _initializeAllocHandles(void); + // Release the virtual address space for a region that is related to one request + void _releaseRegion(int64_t req_id); + // Allocate physical memory, map to the reserved virtual address space of dptr, and set access permission + int64_t _allocCacheBlocksForRequest(int64_t req_id, int64_t blocks = 1); + + +public: + + //kvCacheAllocator(); + // The default contructor. Otherwise, torch bindings will complain it. 
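+  // Illustrative sizing example (the concrete numbers are assumptions, not
+  // part of this patch): with max_seq_length = 4096, layers_num = 32,
+  // heads_num = 32, head_size = 128 and dtype_size = 2 (fp16), each region
+  // reserves 2 * 4096 * 32 * 32 * 128 * 2 bytes = 2 GiB of virtual address
+  // space per request (the leading 2 covers K and V together), while physical
+  // pages are only mapped on demand at page_size granularity (typically the
+  // 2 MB CUDA VMM granularity).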
+ kvCacheAllocator(int64_t max_seq_length, int64_t layers_num, int64_t heads_num, int64_t head_size, int64_t tokens_per_block, int64_t dtype_size); + // { + // Nothing to do + //} + + ~kvCacheAllocator() = default; + + //void initialization(int64_t max_seq_length, int64_t layers_num, int64_t heads_num, int64_t head_size, int64_t tokens_per_block, int64_t dtype_size); + + + // get the granularity of the physical memory allocation + int64_t getPageSize(void); + + // Reserve the virtual address space for a region that is related to one request + // In particular, the regionSize == 2 * max_seq_length * layers_num * heads_num * head_size * dtype_size + // "2" here is to allocate Key and Value cache together, which helps to reduce the fragmentation + int64_t reserveRegion(int64_t req_id); + + void releaseRegions(std::vector regions); + + int64_t allocCacheBlocks(std::vector> reqs_blocks); + + // Allow the python code to know the physical memory used for the whole + // kv cache or the memory for the specified request (when req_id is not 0). + int64_t getAllocPhyPages(int64_t req_id = 0); + void collectPhyPages(int64_t pages = 0); +}; \ No newline at end of file diff --git a/csrc/dattn/torch_bindings.cpp b/csrc/dattn/torch_bindings.cpp new file mode 100644 index 0000000000000..a5f22d3b673d5 --- /dev/null +++ b/csrc/dattn/torch_bindings.cpp @@ -0,0 +1,24 @@ +/* + Copyright (c) ByteDance Inc. + Authors: + - Tongping Liu (tongping.liu@bytedance.com) + - https://github.com/vllm-project/vllm/pull/6102/commits + */ +#include "registration.h" +#include "dattn.h" + + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // kvCacheAllocator class bind + m.class_("kvCacheAllocator") + //.def(torch::init<>()) + .def(torch::init()) + .def("reserveRegion", &kvCacheAllocator::reserveRegion) + .def("releaseRegions", &kvCacheAllocator::releaseRegions) + .def("allocCacheBlocks", &kvCacheAllocator::allocCacheBlocks) + .def("getAllocPhyPages", &kvCacheAllocator::getAllocPhyPages) + .def("collectPhyPages", &kvCacheAllocator::collectPhyPages); + //.def("freeCacheBlock", &kvCacheAllocator::freeCacheBlock) +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 15e9ebe87408a..3f4934a4c42d6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -26,6 +26,25 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); +void dattention( + torch::Tensor& output, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + int64_t layer_idx, + int64_t num_layers, + int64_t block_size, + int64_t max_seq_len, + torch::Tensor & seq_lens, + torch::Tensor & cache_row_mapping, + torch::Tensor & cache_col_mapping, + const std::string& kv_cache_dtype, + int64_t num_kv_heads, + double kv_scale, + const c10::optional& alibi_slopes, + double scale); + void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 045203c3de8a8..ce568d01adc21 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -47,6 +47,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); + ops.def( + "dattention(" + " Tensor! 
output, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query," + " int layer_idx, int num_layers, int block_size, int max_seq_len," + " Tensor seq_lens, Tensor row_mapping, Tensor col_mapping," + " str kv_cache_dtype, int num_kv_heads, float kv_scale," + " Tensor? alibi_slopes, float scale) -> ()" + ); + ops.impl("dattention", torch::kCUDA, &dattention); + // Activation ops // Activation function used in SwiGLU. ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); @@ -387,6 +398,29 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "str kv_cache_dtype) -> ()"); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); + + // new add for vmm + cache_ops.def( + "reshape_and_cache_vmm(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor cache_row_mapping," + " Tensor cache_col_mapping," + " str kv_cache_dtype) -> ()"); + cache_ops.impl("reshape_and_cache_vmm", torch::kCUDA, + &reshape_and_cache_vmm); +// new add for dAttention + cache_ops.def( + "reshape_and_cache_dattn(Tensor key, Tensor value," + " int layer_idx," + " int num_layers," + " int block_size," + " Tensor cache_row_mapping," + " Tensor cache_col_mapping," + " str kv_cache_dtype) -> ()"); + cache_ops.impl("reshape_and_cache_dattn", torch::kCUDA, + &reshape_and_cache_dattn); + } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/csrc/vmm/torch_bindings.cpp b/csrc/vmm/torch_bindings.cpp new file mode 100644 index 0000000000000..1e1cb3b45ca19 --- /dev/null +++ b/csrc/vmm/torch_bindings.cpp @@ -0,0 +1,28 @@ +#include "registration.h" +#include "vmm.h" + + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + + // CacheDevicePtr class bind + m.class_("CacheDevicePtr") + .def(torch::init<>()) + // .def_readwrite("dptr", &CacheDevicePtr::dptr); + // dptr cann't bind success because of the type of dptr is CUdeviceptr(=unsigned long long), which is not supported in torch + .def_readwrite("reservedPageNum", &CacheDevicePtr::reservedPageNum) + .def_readwrite("allocatedPageNum", &CacheDevicePtr::allocatedPageNum); + + // CacheAllocator class bind + m.class_("CacheAllocator") + .def(torch::init<>()) + .def("setPageSize", &CacheAllocator::setPageSize) + .def("reserveCachePtr", &CacheAllocator::reserveCachePtr) + .def("allocCachePtr", &CacheAllocator::allocCachePtr) + .def("freeCachePtr", &CacheAllocator::freeCachePtr) + .def("releaseCachePtr", &CacheAllocator::releaseCachePtr); + + // other util functions bind + m.def("wrap_cache_ptr_to_tensor", &wrap_cache_ptr_to_tensor); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/csrc/vmm/vmm.cu b/csrc/vmm/vmm.cu new file mode 100644 index 0000000000000..e54339eec1410 --- /dev/null +++ b/csrc/vmm/vmm.cu @@ -0,0 +1,242 @@ +#include "vmm.h" + +#include + +#include +#include +#include +#include +#include + +/* +** CacheDevicePtr functions implementation +*/ + +CacheDevicePtr::CacheDevicePtr() + : dptr(0), reservedPageNum(0), allocatedPageNum(0) {} + +CacheDevicePtr::~CacheDevicePtr() { + if (dptr != 0) { + auto status = cuMemUnmap(dptr, reservedPageNum * pageSize); + + for (int i = 0; i < handles.size(); i++) { + auto status = cuMemRelease(handles[i]); + } + + status = cuMemAddressFree(dptr, reservedPageNum * pageSize); + } +} + +void CacheDevicePtr::setPageSize(int64_t num) { pageSize = num * 2 * _MB; } + +// get CUdeviceptr dptr +CUdeviceptr CacheDevicePtr::get_dptr() { return 
dptr; } + +// get void * type pointer +void* CacheDevicePtr::get_void_ptr() { return reinterpret_cast(dptr); } + + + +/* +** CacheAllocator functions implementation +*/ + +CacheAllocator::CacheAllocator() { + // get current device gpu id + int currentDevice; + auto cudaStatus = cudaGetDevice(¤tDevice); + TORCH_CHECK(cudaStatus == cudaSuccess, "cudaGetDevice failed!"); + + // set memory allocation property struct CUmemAllocationProp, + // which is used to control the specific behavior of memory allocation + prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = currentDevice; + + // set memory access descriptor struct CUmemAccessDesc, + // which is used to control the access permission of memory + accessDescr = {}; + accessDescr.location.id = prop.location.id; + accessDescr.location.type = prop.location.type; + accessDescr.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + // get the granularity of device memory allocation + // getGranurality(); + + // set the page size of memory allocation, default is equal to granurality + // pageSize = granurality; +} + +int64_t CacheAllocator::getGranurality() { + cuMemGetAllocationGranularity(&granurality, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM); + printf("Granularity: %ld Bytes\n", granurality); + return granurality; +} + +void CacheAllocator::setPageSize(int64_t num) { pageSize = num * granurality; } + +// reserve function, reserve virtual address space +int64_t CacheAllocator::reserveCachePtr(const c10::intrusive_ptr& ptr, int64_t pageNum) { + if (pageNum == 0) { + return CUDA_SUCCESS; + } + + size_t size = pageNum * pageSize; + auto status = cuMemAddressReserve(&(ptr->dptr), size, 0, 0, 0); + + if (status != CUDA_SUCCESS) { + printf("cuMemAddressReserve failed! 
error-code: %d\n", status); + } else { + ptr->reservedPageNum += pageNum; + } + + return status; +} + +// alloc function, allocate physical memory, map to the reserved virtual address +// space of dptr, and set access permission +int64_t CacheAllocator::allocCachePtr(const c10::intrusive_ptr& ptr, + int64_t pageNum, int64_t offset) { + if (pageNum == 0) { + return CUDA_SUCCESS; + } + // size = ((size - 1) / pageSize + 1) * pageSize; + size_t size = pageNum * pageSize; + auto start_dptr = ptr->dptr + offset; + + CUresult status = CUDA_SUCCESS; + CUmemGenericAllocationHandle allocationHandle; + if ((status = cuMemCreate(&allocationHandle, size, &prop, 0)) == + CUDA_SUCCESS) { + if ((status = cuMemMap(start_dptr, size, 0, allocationHandle, 0)) == + CUDA_SUCCESS) { + if ((status = cuMemSetAccess(start_dptr, size, &accessDescr, 1)) == + CUDA_SUCCESS) { + // ptr.handles.push_back(allocationHandle); // handles is unused now + ptr->allocatedPageNum += pageNum; + } else { + printf("cuMemMap success,but cuMemSetAccess failed!, err code: %d\n", + status); + cuMemUnmap(start_dptr, size); + } + } + // if (status != CUDA_SUCCESS) { + // printf("cuMemMap or cuMemsetAccess failed!, err code: %d\n", status); + // cuMemRelease(allocationHandle); + // } + cuMemRelease( + allocationHandle); // always release the handle, but the memory is + // still can access util cuMemUnmap + } else { + printf("cuMemCreate failed!, err code: %d\n", status); + } + return status; +} + +// free function, unmap the virtual address spaceļ¼Œrelease physical memory +// handles and free virtual address space +int64_t CacheAllocator::freeCachePtr(const c10::intrusive_ptr& ptr) { + CUresult status = CUDA_SUCCESS; + if (ptr->dptr != 0) { + status = cuMemUnmap(ptr->dptr, ptr->reservedPageNum * pageSize); + // status = cuMemUnmap(ptr.dptr, ptr.allocatedPageNum * pageSize); + if (status != CUDA_SUCCESS) { + printf("cuMemUnmap failed! error-code: %d\n", status); + } else { + for (int i = 0; i < ptr->handles.size(); i++) { + status = cuMemRelease(ptr->handles[i]); + if (status != CUDA_SUCCESS) { + printf("cuMemRelease failed! error-code: %d\n", status); + return status; + } + } + ptr->handles.clear(); + + status = cuMemAddressFree(ptr->dptr, ptr->reservedPageNum * pageSize); + if (status != CUDA_SUCCESS) { + printf("cuMemAddressFree failed! error-code: %d\n", status); + } + } + } + return status; +} + +// releaseCachePtrPages function, unmap the virtual address spaceļ¼Œrelease +// physical memory handles but not free virtual address space +int64_t CacheAllocator::releaseCachePtr(const c10::intrusive_ptr& ptr, int64_t pageNum, int64_t offset) { + if (pageNum == 0) { + return CUDA_SUCCESS; + } + auto start_dptr = ptr->dptr + offset; + CUresult status = CUDA_SUCCESS; + if (ptr->dptr != 0) { + status = cuMemUnmap(start_dptr, pageNum * pageSize); + // status = cuMemUnmap(ptr.dptr, ptr.allocatedPageNum * pageSize); + if (status != CUDA_SUCCESS) { + printf("cuMemUnmap failed! error-code: %d\n", status); + } else { + for (int i = 0; i < ptr->handles.size(); i++) { + status = cuMemRelease(ptr->handles[i]); + if (status != CUDA_SUCCESS) { + printf("cuMemRelease failed! 
error-code: %d\n", status); + return status; + } + } + ptr->handles.clear(); + } + } + return status; +} + + + +/* +** vmm other util functions implementation +*/ + +torch::Tensor wrap_dptr_to_tensor(CUdeviceptr d_ptr, const std::string dtype, + at::ArrayRef shape) { + // get current device gpu id + int currentDevice; + auto cudaStatus = cudaGetDevice(¤tDevice); + TORCH_CHECK(cudaStatus == cudaSuccess, "cudaGetDevice failed!"); + + auto _type = c10::kFloat; + + const std::unordered_map typeMap = { + // float data type + {"float64", c10::kDouble}, + {"float32", c10::kFloat}, + {"float16", c10::kHalf}, + {"float", c10::kFloat}, + {"double", c10::kDouble}, + {"half", c10::kHalf}, + {"bfloat16", c10::kBFloat16}, + // integer data type + {"int64", c10::kLong}, + {"int32", c10::kInt}, + {"int16", c10::kShort}, + {"int8", c10::kChar}, + {"int", c10::kInt}, + {"uint8", c10::kByte}}; + + _type = typeMap.at(dtype); + + // set the data type and device of the Tensor + auto options = + torch::TensorOptions().dtype(_type).device(torch::kCUDA, currentDevice); + + // create a Tensor from the CUdeviceptr + torch::Tensor tensor = + torch::from_blob(reinterpret_cast(d_ptr), shape, options); + + return tensor; +} + +torch::Tensor wrap_cache_ptr_to_tensor(const c10::intrusive_ptr& ptr, + const std::string dtype, + at::ArrayRef shape) { + return wrap_dptr_to_tensor(ptr->dptr, dtype, shape); +} \ No newline at end of file diff --git a/csrc/vmm/vmm.h b/csrc/vmm/vmm.h new file mode 100644 index 0000000000000..a1b39320d20b8 --- /dev/null +++ b/csrc/vmm/vmm.h @@ -0,0 +1,94 @@ +#pragma once + +#include +#include "torch/custom_class.h" +#include "c10/util/intrusive_ptr.h" + +#include +#include + +#include +#include +#include +#include +#include + +#define _MB (1 << 20) + + +// CacheDevicePtr class, to warp CUdeviceptr, used to kv-cache tensor +// record the reserved virtual address size and allocated physical memory size. 
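+//
+// A minimal usage sketch of this VMM path (illustrative only: error handling
+// is omitted, the page counts are made-up, and the real call sites live on
+// the Python side via the torch bindings):
+//
+//   CacheAllocator allocator;
+//   auto ptr = c10::make_intrusive<CacheDevicePtr>();
+//   allocator.reserveCachePtr(ptr, /*pageNum=*/1024);  // reserve virtual address space only
+//   allocator.allocCachePtr(ptr, /*pageNum=*/4);       // back the first 4 pages with physical memory
+//   torch::Tensor t = wrap_cache_ptr_to_tensor(ptr, "float16", {4 * 2 * _MB / 2});
+//   // ... use t as a kv-cache tensor ...
+//   allocator.releaseCachePtr(ptr, /*pageNum=*/4);     // unmap the physical pages, keep the reservation
+//   allocator.freeCachePtr(ptr);                       // unmap and free the virtual address space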
+class CacheDevicePtr : public torch::CustomClassHolder{ +private: + +public: + CUdeviceptr dptr; + int64_t reservedPageNum; + int64_t allocatedPageNum; + + int64_t pageSize = 2*_MB; // page size, the minimum unit of kvcache memory allocation + + std::vector handles; // used to store the allocated physical memory handles + + // std::vector handle_sizes; // used to store the number of pages of each allocated handle + + CacheDevicePtr(); + + ~CacheDevicePtr(); + + // set the page size, the page size must be a multiple of the granularity + void setPageSize(int64_t num = 1); + + // get CUdeviceptr dptr + CUdeviceptr get_dptr(); + + // get void * type pointer + void *get_void_ptr(); + +}; + + +// CacheAllocator class, used for memory allocation of kv-cachemanager, memory allocation is based on page granularity, +class CacheAllocator : public torch::CustomClassHolder{ +private: + CUmemAllocationProp prop; + CUmemAccessDesc accessDescr; + + size_t granurality = 2*_MB; // memory allocation granularity supported by the gpu, 2MB for nvidia gpu + size_t pageSize = 2*_MB; // page size, the minimum unit of kvcache memory allocation, default 2MB + +public: + + CacheAllocator(); + + ~CacheAllocator() = default; + + // get the granularity of the memory allocation + int64_t getGranurality(); + + // set the page size, the page size must be a multiple of the granularity + void setPageSize(int64_t num = 1); + + // reserve function, reserve virtual address space + // int reserveCachePtr(CacheDevicePtr& ptr, int64_t pageNum=1); + int64_t reserveCachePtr(const c10::intrusive_ptr& ptr, int64_t pageNum=1); + + // alloc function, allocate physical memory, map to the reserved virtual address space of dptr, and set access permission + // int allocCachePtr(CacheDevicePtr& ptr, int64_t pageNum=1, int64_t offset=0); + int64_t allocCachePtr(const c10::intrusive_ptr& ptr, int64_t pageNum=1, int64_t offset=0); + + // free function, unmap the virtual address spaceļ¼Œrelease physical memory handles and free virtual address space + // int freeCachePtr(CacheDevicePtr& ptr); + int64_t freeCachePtr(const c10::intrusive_ptr& ptr); + + // releaseCachePtrPages function, unmap the virtual address spaceļ¼Œrelease physical memory handles but not free virtual address space + // int releaseCachePtr(CacheDevicePtr& ptr); + int64_t releaseCachePtr(const c10::intrusive_ptr& ptr, int64_t pageNum=0, int64_t offset=0); + +}; + +// warp CUdeviceptr to torch tensor +torch::Tensor wrap_dptr_to_tensor(CUdeviceptr d_ptr, const std::string dtype, at::ArrayRef shape); + +// torch::Tensor wrap_cache_ptr_to_tensor(CacheDevicePtr &ptr, const std::string dtype, at::ArrayRef shape); +torch::Tensor wrap_cache_ptr_to_tensor(const c10::intrusive_ptr&ptr, const std::string dtype, at::ArrayRef shape); \ No newline at end of file diff --git a/vllm/_dattn_ops.py b/vllm/_dattn_ops.py new file mode 100644 index 0000000000000..96cc56dd51050 --- /dev/null +++ b/vllm/_dattn_ops.py @@ -0,0 +1,104 @@ +''' + Copyright (c) ByteDance Inc. 
+ Authors: + - Tongping Liu (tongping.liu@bytedance.com) + - https://github.com/vllm-project/vllm/pull/6102/commits +''' + +import torch +from vllm.logger import init_logger +from typing import List, Optional, Tuple, Type + +logger = init_logger(__name__) + +try: + import vllm._dattn_C # noqa: F401 +except ImportError as e: + logger.warning("Import dattn error msg: %s", e.msg) + +""" +# It seems that there is no need for this function, since we will utilize the +# same +# cache device ptr, used for kv cache tensor +class kvCacheRegion: + def __init__(self): + self._ptr = torch.classes._dattn_C.kvCacheRegion() + + @property + def reserved_page_num(self): + return self._ptr.revervedPageNum + + @reserved_page_num.setter + def reserved_page_num(self, value:int): + self._ptr.reservedPageNum = value + + @property + def allocated_page_num(self): + return self._ptr.allocatedPageNum + + @allocated_page_num.setter + def allocated_page_num(self, value:int): + self._ptr.allocatedPageNum = value +""" + + +# cache allocator based dAttention, used to manage kv cache tensor +class kvCacheAllocator: + def __init__(self, + max_seq_length, + layers_num, + heads_num, + head_size, + block_size, + dtype_size, + ): + self.block_size = block_size + self._allocator = torch.classes._dattn_C.kvCacheAllocator(max_seq_length, + layers_num, + heads_num, + head_size, + block_size, + dtype_size + ) + # self.page_size = self._allocator.getPageSize() + + # def reserve_cache_ptr(self, ptr:CacheDevicePtr, page_num:int = 1): + def reserve_cache_region(self, req_id: int = 1): + # print(f"NOOW, in reserve_cache_region, with req_id:{req_id}") + ptr = self._allocator.reserveRegion(req_id) + # print(f"NOOW, in reserve_cache_region, with req_id:{req_id}, ptr:{ptr}") + # TODO: wrap the ptr to a tensor + # return wrapDptr2Tensor() + return ptr + + # def alloc_cache_ptr(self, ptr:CacheDevicePtr, page_num:int = 1, offset:int = 0): + + # def free_cache_ptr(self, ptr:CacheDevicePtr): + # def release_cache_ptr(self, ptr:CacheDevicePtr, page_num: int = 0, offset: int = 0): + + def release_cache_regions(self, req_ids: List[int]): + self._allocator.releaseRegions(req_ids) + return + + def alloc_cache_blocks(self, req_blocks: List[List[int]]): + return self._allocator.allocCacheBlocks(req_blocks) + + # If the memory is not sufficient, then the python code (as the major control part) + + # can instruct the native library to release some memory. 
If pages is not specified, + # then the library will collect as much as possible (based on the predefined watermark) + # Otherwise, it will at least collect the specified ``pages'' here + def collect_cached_pages(self, pages: int = 0): + self._allocator.collectPhyPages(pages) + return + + # Get the memory usage for a specific request (when req_id is 0) or the whole allocator + + def get_kvcache_memory_usage(self, req_id: int = 0): + pages = self._allocator.getAllocPhyPages(req_id) + + return pages * self.page_size + +# other utils +# def wrap_cache_ptr_to_tensor(ptr:CacheDevicePtr, dtype_str:str, shape:Tuple[int, ...]): +# return torch.ops._vmm_C.wrap_cache_ptr_to_tensor(ptr._ptr, dtype_str, shape) \ No newline at end of file diff --git a/vllm/_vmm_ops.py b/vllm/_vmm_ops.py new file mode 100644 index 0000000000000..ad3bdbac7e500 --- /dev/null +++ b/vllm/_vmm_ops.py @@ -0,0 +1,59 @@ +import torch +from vllm.logger import init_logger +from typing import List, Optional, Tuple, Type + +logger = init_logger(__name__) + +try: + import vllm._vmm_C # noqa: F401 +except ImportError as e: + logger.warning("Import vmm error msg: %s", e.msg) + + +# cache device ptr, used for kv cache tensor +class CacheDevicePtr: + def __init__(self): + self._ptr = torch.classes._vmm_C.CacheDevicePtr() + + @property + def reserved_page_num(self): + return self._ptr.revervedPageNum + + @reserved_page_num.setter + def reserved_page_num(self, value: int): + self._ptr.reservedPageNum = value + + @property + def allocated_page_num(self): + return self._ptr.allocatedPageNum + + @allocated_page_num.setter + def allocated_page_num(self, value: int): + self._ptr.allocatedPageNum = value + + +# cache allocator based vmm, used to manage kv cache tensor +class CacheAllocator: + def __init__(self): + self._allocator = torch.classes._vmm_C.CacheAllocator() + + def set_page_size(self, page_size: int): + return self._allocator.setPageSize(page_size) + + def reserve_cache_ptr(self, ptr: CacheDevicePtr, page_num: int = 1): + return self._allocator.reserveCachePtr(ptr._ptr, page_num) + + def alloc_cache_ptr(self, ptr: CacheDevicePtr, page_num: int = 1, offset: int = 0): + return self._allocator.allocCachePtr(ptr._ptr, page_num, offset) + + def free_cache_ptr(self, ptr: CacheDevicePtr): + return self._allocator.freeCachePtr(ptr._ptr) + + def release_cache_ptr(self, ptr: CacheDevicePtr, page_num: int = 0, + offset: int = 0): + return self._allocator.releaseCachePtr(ptr._ptr, page_num, offset) + + +# other utils +def wrap_cache_ptr_to_tensor(ptr: CacheDevicePtr, dtype_str: str, shape: Tuple[int, ...]): + return torch.ops._vmm_C.wrap_cache_ptr_to_tensor(ptr._ptr, dtype_str, shape) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01d..25c37f5735a79 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -144,6 +144,17 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None + # Added for dAttention + use_dattn: bool = False + num_layers: int = 0 + block_size: int = 0 + + # Added for both dAttention and vmm + cache_batch_idx: Optional[torch.Tensor] = None # (batch_size, ) the index of batch in cache + cache_row_mapping: Optional[torch.Tensor] = None # (num_tokens,) record key/value write to which seq row in cache + cache_col_mapping: Optional[ + torch.Tensor] = None # (num_tokens,) record key/value write to which 
token col in cache + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -203,21 +214,36 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) + if not self.use_dattn: + assert self.block_tables is not None + block_tables=self.block_tables[:self.num_prefills] + slot_mapping=self.slot_mapping[:self.num_prefill_tokens] + else: + block_tables = None + slot_mapping = None + # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, + slot_mapping=slot_mapping, + block_tables=block_tables, query_start_loc=query_start_loc, context_lens_tensor=context_lens_tensor, - block_tables=block_tables, + # block_tables=block_tables, use_cuda_graph=False, + use_dattn=self.use_dattn, + num_layers=self.num_layers, + block_size=self.block_size, + cache_batch_idx=self.cache_batch_idx, + cache_row_mapping=self.cache_row_mapping, + cache_col_mapping=self.cache_col_mapping, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -230,11 +256,21 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - + assert self.seq_lens_tensor is not None if self._cached_decode_metadata is not None: # Recover cached decode-phase attention # metadata structure return self._cached_decode_metadata + + if not self.use_dattn: + assert self.block_tables is not None + block_tables=self.block_tables[:self.num_prefills] + slot_mapping=self.slot_mapping[:self.num_prefill_tokens] + else: + block_tables = None + slot_mapping = None + + assert ((self.seq_lens_tensor is not None) or (self.encoder_seq_lens_tensor is not None)) @@ -257,6 +293,13 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=self.max_decode_seq_len, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, + + use_dattn=self.use_dattn, + num_layers=self.num_layers, + block_size=self.block_size, + cache_batch_idx=self.cache_batch_idx, + cache_row_mapping=self.cache_row_mapping, + cache_col_mapping=self.seq_lens_tensor[self.num_prefills:], # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -433,6 +476,7 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.print = 0 suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: @@ -544,15 +588,59 @@ def forward( # Update self-attention KV cache (prefill/decode) updated_slot_mapping = attn_metadata.slot_mapping - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. 
- PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - updated_slot_mapping, - self.kv_cache_dtype, - k_scale, v_scale) + # # Reshape the input keys and values and store them in the cache. + # # If kv_cache is not provided, the new key and value tensors are + # # not cached. This happens during the initial memory + # # profiling run. + # PagedAttention.write_to_paged_cache(key, value, key_cache, + # value_cache, + # updated_slot_mapping, + # self.kv_cache_dtype, + # k_scale, v_scale) + + if kv_cache is not None: + if attn_metadata.use_dattn != True: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # print(f"key.shape:{key.shape}, key_cache.shape:{key_cache.shape} before write_to_paged_cache") + # print(f"value.shape:{value.shape}, value_cache.shape:{value_cache.shape} before write_to_paged_cache") + # print(f"attn_metadata.slot_mapping:{attn_metadata.slot_mapping}") + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) + # torch.set_printoptions(precision=2, sci_mode=False) + # print(f"key.shape:{key.shape}, key_cache.shape:{key_cache.shape}") + # print(f"key.shape:{key[:2,:1,].shape}\n{key[:2,:1,]}") + # print(f"key_cache.shape:{key_cache[-1,-1,:,:,:].shape}\n{key_cache[-1,-1,:,:,:]}") + # print(f"key_cache.shape:{key_cache[-1,:1,:,:,:].shape}\n{key_cache[-1,:1,:,:,:]}") + # print(f"value.shape:{value.shape}, newshape:{value[:,-1,].shape}, value:{value[:,-1,]}") + # print(f"value_cache:{value_cache.shape}, value-shape:{value_cache[-1:,-1:,:,:].shape}, value: {value_cache[-1:,-1:,:,:]}") + # torch.set_printoptions(profile="default") + # exit(0) + else: + # print(f"before write_to_paged_cache now with kv_cache:{kv_cache}") + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + layer_idx = kv_cache.item() + # torch.set_printoptions(precision=2, sci_mode=False) + # print(f"key.shape:{key.shape}, type:{key.dtype} \n{key[:2,:1,]}") + # print(f"value.shape:{value.shape}, type:{key.dtype} \n{value[:3,:1,]}") + # if layer_idx == 0: + # print(f"attn_metadata:{attn_metadata}") + # print(f"attn_metadata.cache_row_mapping:{attn_metadata.cache_row_mapping}") + # print(f"attn_metadata.cache_col_mapping:{attn_metadata.cache_col_mapping}") + PagedAttention.write_to_paged_cache_dattn(key, value, layer_idx, + attn_metadata.num_layers, + attn_metadata.block_size, + attn_metadata.cache_row_mapping, + attn_metadata.cache_col_mapping, + self.kv_cache_dtype) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. 
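        # Illustrative example of the mapping tensors consumed by
        # write_to_paged_cache_dattn above (a sketch based only on the field
        # comments in XFormersMetadata; the tensors are actually built in the
        # model runner, which is not part of this hunk):
        #
        #   a decode batch of two sequences that own cache buffers 3 and 7 and
        #   are writing their 10th and 5th token would use roughly
        #
        #     cache_row_mapping = torch.tensor([3, 7])  # which per-request cache row/buffer
        #     cache_col_mapping = torch.tensor([9, 4])  # which token slot inside that buffer
        #
        #   so the kernel can scatter this layer's key/value into the right
        #   region directly, which is why the dattn path can leave block_tables
        #   as None.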
@@ -633,20 +721,67 @@ def forward( block_tables_arg, ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) - output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - k_scale, - v_scale, - ) + # output[num_prefill_tokens:] = PagedAttention.forward_decode( + # decode_query, + # key_cache, + # value_cache, + # block_tables_arg, + # seq_lens_arg, + # max_seq_len_arg, + # self.kv_cache_dtype, + # self.num_kv_heads, + # self.scale, + # self.alibi_slopes, + # k_scale, + # v_scale, + # ) + + if not attn_metadata.use_dattn: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + + # print(f"original, output.shape:{output[:,:1,]}", file=sys.stderr) + else: + assert attn_metadata.use_dattn == True + layer_idx = kv_cache.item() + + # print(f"decoding: layer_idx:{layer_idx}, decode_meta.num_layers:{decode_meta.num_layers}, decode_meta.block_size:{decode_meta.block_size}, decode_meta.cache_row_mapping:{decode_meta.cache_row_mapping.shape}, decode_meta.cache_col_mapping:{decode_meta.cache_col_mapping}") + # print(f"decoding: layer_idx:{layer_idx}, decode_meta.num_layers:{decode_meta.num_layers}") + # print(f"decoding: layer_idx:{layer_idx}, cache_row_mapping:{decode_meta.cache_row_mapping.shape}, cache_col_mapping:{decode_meta.cache_col_mapping}") + # print(f"decode_meta.seq_lens_tensor:{decode_meta.seq_lens_tensor}, decode_meta.max_decode_seq_len:{decode_meta.max_decode_seq_len}") + output[num_prefill_tokens:] = PagedAttention.forward_decode_dattn( + decode_query, + layer_idx, + decode_meta.num_layers, + decode_meta.block_size, + decode_meta.max_decode_seq_len, + decode_meta.seq_lens_tensor, + decode_meta.cache_row_mapping, + decode_meta.cache_col_mapping, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) + # exit(0) + + # print(f"layer-{layer_idx}: output.shape:{output[:,:1,]}", file=sys.stderr) + # if layer_idx == 11: + # exit(0) + + self.print += 1 # Reshape the output tensor. 
return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 92023d5b75f5a..a49060eb725d2 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -83,6 +83,28 @@ def write_to_paged_cache( v_scale, ) + @staticmethod + def write_to_paged_cache_dattn( + key: torch.Tensor, + value: torch.Tensor, + layer_idx: int, + num_layers: int, + block_size: int, + cache_row_mapping: torch.Tensor, + cache_col_mapping: torch.Tensor, + kv_cache_dtype: str, + ) -> None: + ops.reshape_and_cache_dattn( + key, + value, + layer_idx, + num_layers, + block_size, + cache_row_mapping, + cache_col_mapping, + kv_cache_dtype + ) + @staticmethod def forward_decode( query: torch.Tensor, @@ -189,6 +211,67 @@ def forward_decode( ) return output + @staticmethod + def forward_decode_dattn( + query: torch.Tensor, + layer_idx: int, + num_layers: int, + block_size: int, + max_seq_len: int, + seq_lens: torch.Tensor, + cache_row_mapping: torch.Tensor, + cache_col_mapping: torch.Tensor, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + kv_scale: float, + ) -> torch.Tensor: + output = torch.empty_like(query) + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + if max_num_partitions > 1: + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + # import sys + # print(f"forward_decode_dattn on long sequence: Query shape: {output.shape}", file=sys.stderr) + else: + tmp_output = None + exp_sums = None + max_logits = None + # print(f"forward_decode_dattn: Query shape: {output.shape}") + ops.dattention( + output, + exp_sums, + max_logits, + tmp_output, + query, + layer_idx, + num_layers, + block_size, + max_seq_len, + seq_lens, + cache_row_mapping, + cache_col_mapping, + kv_cache_dtype, + num_kv_heads, + scale, + alibi_slopes, + kv_scale, + ) + + return output + @staticmethod def forward_prefix( query: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 7a15606836dcc..45c5de1b6815e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -30,6 +30,8 @@ logger = init_logger(__name__) +_GB = 1 << 30 +_MB = 1 << 20 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 @@ -589,6 +591,9 @@ def __init__( num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, + use_vmm: bool = False, # new add for vmm + use_dattn: bool = False, # new add for dattention + block_bytes_size: int = 2 * _MB, cpu_offload_gb: float = 0, ) -> None: self.block_size = block_size @@ -607,6 +612,11 @@ def __init__( self.num_gpu_blocks = None self.num_cpu_blocks = None + # new add for vmm + self.block_bytes_size = block_bytes_size + self.use_vmm = use_vmm + self.use_dattn = use_dattn + def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info diff --git a/vllm/core/block_manager_dattn.py b/vllm/core/block_manager_dattn.py new file mode 100644 index 0000000000000..27357fea4c723 --- /dev/null +++ b/vllm/core/block_manager_dattn.py @@ -0,0 +1,384 @@ +''' + Copyright (c) ByteDance Inc. 
+ Authors: + - Tongping Liu (tongping.liu@bytedance.com) + + Adopted from https://github.com/vllm-project/vllm/pull/6102/commits +''' +from collections import deque +from typing import Dict, List, Optional, Tuple + +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device, Counter +from collections import deque + +logger = init_logger(__name__) + + +class CacheBufferAllocator: + def __init__(self, num_cache_buffers: int): + self.num_cache_buffers = num_cache_buffers + self.free_buffers = deque(range(num_cache_buffers)) + + def allocate(self) -> int: + buffer_id = self.free_buffers.popleft() + return buffer_id + + def free(self, buffer_id: int): + self.free_buffers.append(buffer_id) + + def reset(self): + self.free_buffers = deque(range(self.num_cache_buffers)) + + def get_num_free_buffers(self): + return len(self.free_buffers) + + def get_num_total_buffers(self): + return self.num_cache_buffers + + +class BlockSpaceManagerDAttn(BlockSpaceManager): + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + num_cache_buffers: int = 0, + ) -> None: + + if enable_caching or (sliding_window is not None): + raise NotImplementedError("Prefix Caching or Sliding window is not supported in VMM now.") + + self.enable_caching = enable_caching + + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.num_free_gpu_blocks = num_gpu_blocks + self.num_free_cpu_blocks = num_cpu_blocks + + self.num_cache_buffers = num_cache_buffers # == self.scheduler_config.max_num_seqs + + print(f"self.num_cache_buffers:{self.num_cache_buffers} inside BlockSpaceManagerDAttn") + # use to alloc cache buffer id for seq + self.gpu_allocator = CacheBufferAllocator(num_cache_buffers) + + # Watermark indicates that the least amount of blocks should be free. + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + # Mapping from cache buffer ID to the number of allocated blocks. + self.allocated_block_counts: Dict[int, int] = {} + self.modified_block_counts: Dict[int, int] = {} + self.waiting_free_buffers: List[Tuple[int, int]] = [] + self.waiting_free_blocks: int = 0 + self.free_buffer_ids: List[int] = [] + self.free_latency: int = 10 + + # TODO: this is very confusing + self.iter_counter = Counter() + + self._init_alloc() + + def _init_alloc(self) -> None: + # we init alloc one block for warp in cache_engine_vmm + self.allocated_block_counts[0] = 1 + self.num_free_gpu_blocks -= 1 + + def _predict_gen_len(self, seq: Sequence) -> int: + # TODO:this function is used to predict the generated content length, + # which can used to pre allocate the memory handles + return 1 + + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + return 0 if seq is None else seq.n_blocks + + # This will be invoked in the prefill phase + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. 
This may not be true for preempted sequences. + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # get_seqs will collect a list of sequence with status equalling to SequenceStatus.WAITING + # then we will get the first sequence in this group + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = self._get_seq_num_required_blocks(seq) + num_required_blocks += self._predict_gen_len(seq) + + # If the sequence is not allocated yet, its cache_buffer_id must be -1. + assert seq.cache_buffer_id == -1 + + num_free_gpu_blocks = self.num_free_gpu_blocks + \ + self.waiting_free_blocks + + # Ensure that one request should not use more than 90% or 99% of memory + # This can avoid frequent cache eviction + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + # Make sure that we are not holding more than schedule_config.max_num_seqs + # TODO: is there a potential issue? If self.gpu_allocator.get_num_free_buffers()=0, and + # waiting_free_buffer exists (only one will exit soon), then can multiple requests get admitted? + if self.gpu_allocator.get_num_free_buffers() > 0 or self.waiting_free_buffers: + return AllocStatus.OK + else: + return AllocStatus.LATER + else: + return AllocStatus.LATER + + # This function is only invoked by _allocate_and_set_running (invoked by _schedule_prefills) + # That is, it is allocated when admitting a new request in prefill phase. + # Therefore, it will invoke self._allocate_buffer() to allocate a request and then + # update the seq.cache_buffer_id, seq.data.cache_buffer_id, self.allocated_block_counts[buffer_id] + # TODO: for instance, if there is a request with 26 tokens, then it will need two + # blocks?? + def allocate(self, seq_group: SequenceGroup) -> None: + # No need to do this, as we have checked before + # if seq_group.is_encoder_decoder(): + # raise NotImplementedError("Encoder-decoder is not supported in VMM now.") + + # check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + need_blocks_num = self._get_seq_num_required_blocks(seq) + + # TODO: Don't know why we will need this _predict_gen_len?? 
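+        # Note: _predict_gen_len() currently returns 1, so this line reserves
+        # one extra block up front, presumably for the first generated token.
+        # For example (assuming seq.n_blocks == ceil(prompt_len / block_size)),
+        # a 26-token prompt with block_size 16 asks for ceil(26/16) = 2 prompt
+        # blocks plus 1 predicted block, i.e. 3 blocks in total.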
+ need_blocks_num += self._predict_gen_len(seq) + buffer_id, allocated_num = self._allocate_buffer(need_blocks_num) + + seq.cache_buffer_id = buffer_id + seq.data.cache_buffer_id = buffer_id + self.allocated_block_counts[buffer_id] = allocated_num + self.modified_block_counts[buffer_id] = allocated_num + # predict generate content length and pre allocate the blocks + # need_blocks_num += self._predict_gen_len(seq) + + def _allocate_buffer(self, need_blocks_num: int) -> Tuple[int, int]: + if self.waiting_free_buffers: + return self._allocate_from_waiting_buffer(need_blocks_num) + else: + assert self.num_free_gpu_blocks >= need_blocks_num + buffer_id = self.gpu_allocator.allocate() + self.num_free_gpu_blocks -= need_blocks_num + return buffer_id, need_blocks_num + + """ + Allocate need_blocks_num from waiting buffer that holds freed sequences + """ + + def _allocate_from_waiting_buffer(self, + need_blocks_num: int) -> Tuple[int, int]: + buffer_id, _ = self.waiting_free_buffers.pop(0) + allocated_num = self.allocated_block_counts[buffer_id] + self.waiting_free_blocks -= allocated_num + + # If the number of blocks is not sufficient, let's allocate more blocks. + # However, I don't know whether these new blocks are related to the given req_id + if allocated_num < need_blocks_num: + # TODO: this allocation has the issue, as it can't guarantee that + # the blocks are allocated from the specified request id. + self._allocate_extra_blocks(need_blocks_num - allocated_num) + allocated_num = need_blocks_num + # If we reuse a buffer that's too long, we may need to free the memory + # that's more than we currently need (need_blocks_num) + # But now, frequent frees are an overhead, so we don't do it. + # TODO: Reduced overhead or asynchronous free + # else: + # self.num_free_gpu_blocks += (allocated_num - need_blocks_num) + # allocated_num = need_blocks_num + + return buffer_id, allocated_num + + def _allocate_extra_blocks(self, extra_blocks: int) -> None: + if self.num_free_gpu_blocks >= extra_blocks: + # It is actually deducted free_gpu_blocks. + self.num_free_gpu_blocks -= extra_blocks + else: + extra_need_blocks = extra_blocks - self.num_free_gpu_blocks + self.num_free_gpu_blocks = 0 + + self._allocate_from_waiting_buffers(extra_need_blocks) + + # free some blocks from waiting buffers to allocate + # The name is very confusing, as it is similar to _allocate_from_waiting_buffer + def _allocate_from_waiting_buffers(self, blocks_to_alloc: int) -> None: + while self.waiting_free_buffers and blocks_to_alloc > 0: + free_id, _ = self.waiting_free_buffers.pop(0) + free_blocks = self.allocated_block_counts[free_id] + self.waiting_free_blocks -= free_blocks + self.free_buffer_ids.append(free_id) + self.allocated_block_counts[free_id] = 0 + blocks_to_alloc -= free_blocks + + assert blocks_to_alloc <= 0 + self.num_free_gpu_blocks -= blocks_to_alloc + + # Invoked by _schedule_running in running phase. + def can_append_slots(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerDAttn." + + # FIXME: this is wrong for vAttention, as it requires many blocks for + # a token (unless its num_free_gpu_blocks already consider the number of layers ) + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. 
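+        # For example (made-up numbers): with 8 RUNNING sequences, 6 free GPU
+        # blocks and 3 blocks held by waiting-to-be-freed buffers, 8 <= 6 + 3
+        # holds, so the group is allowed to append.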
+ num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + num_free_gpu_blocks = self.num_free_gpu_blocks + \ + self.waiting_free_blocks + return num_seqs <= num_free_gpu_blocks + + # FIXME: there is no handling on num_lookahead_slots, which should be handled. + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int = 0, + ) -> List[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + + buffer_id = seq.cache_buffer_id + + # If the sequence is allocated, its cache_buffer_id must >= 0. + assert buffer_id >= 0 + + logical_blocks_num = seq.n_blocks + allocated_num = self.allocated_block_counts[buffer_id] + + # If we need to allocate a new physical block + if allocated_num < logical_blocks_num: + # Currently this code only supports adding one physical block + assert allocated_num == logical_blocks_num - 1 + + # Added one new block??? Why, this is confusing? + self._allocate_extra_blocks(1) + self.allocated_block_counts[buffer_id] = logical_blocks_num + self.modified_block_counts[buffer_id] = logical_blocks_num + return [] + + else: + # the last block is not full, no need to allocate a new block + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + raise NotImplementedError("Forking is not supported in BlockSpaceManagerVMM now.") + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + raise NotImplementedError("Swap-in is not supported in BlockSpaceManagerVMM now.") + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + raise NotImplementedError("Swap-in is not supported in BlockSpaceManagerVMM now.") + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + raise NotImplementedError("Swap-out is not supported in BlockSpaceManagerVMM now.") + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + raise NotImplementedError("Swap-out is not supported in BlockSpaceManagerVMM now.") + + """ + Free a sequence. We will append the seq to waiting_free_buffers. + Initially, we did this inside the memory management library. Maybe we should do it here as well. + """ + + def free(self, seq: Sequence) -> None: + # Here, we just append free seq to waiting_free_buffers. + waiting_free_id = seq.cache_buffer_id + + # If no blocks are allocated in the sequence, then this sequence may be deallocated. + if waiting_free_id not in self.allocated_block_counts or \ + self.allocated_block_counts[waiting_free_id] == 0: + # Already freed or haven't been scheduled yet. 
+ return + + # Get free_blocks of this sequence + free_blocks = self.allocated_block_counts[waiting_free_id] + self.waiting_free_buffers.append((waiting_free_id, + self.iter_counter.counter)) + self.waiting_free_blocks += free_blocks + + def reset(self) -> None: + # Free decoder block tables + self.allocated_block_counts.clear() + self.num_free_gpu_blocks = self.num_total_gpu_blocks + self.num_free_cpu_blocks = self.num_total_cpu_blocks + + self.waiting_free_buffers = [] + self.modified_block_counts = {} + self.free_buffer_ids = [] + self.gpu_allocator.reset() + + def get_block_table(self, seq: Sequence) -> List[int]: + # logger.warning("block table is not used in BlockSpaceManagerVMM now.") + return [] + + def get_num_free_gpu_blocks(self) -> int: + return self.num_free_gpu_blocks + + def get_num_free_cpu_blocks(self) -> int: + return self.num_free_cpu_blocks + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + # logger.warning("Access all blocks in seq is not supported in BlockSpaceManagerVMM now.") + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + # logger.warning("Common computed block ids is not supported in BlockSpaceManagerVMM now.") + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup) -> None: + # logger.warning("Mark blocks as computed is not supported in BlockSpaceManagerVMM now.") + pass + + def get_allocated_block_count(self, seq_id: int) -> int: + return self.allocated_block_counts[seq_id] + + def check_and_free_waiting_buffers(self, now_iter: int) -> None: + while self.waiting_free_buffers and \ + self.waiting_free_buffers[0][1] - now_iter >= self.free_latency: + free_id, _ = self.waiting_free_buffers.pop(0) + free_blocks = self.allocated_block_counts[free_id] + self.waiting_free_blocks -= free_blocks + self.num_free_gpu_blocks += free_blocks + self.free_buffer_ids.append(free_id) + self.allocated_block_counts[free_id] = 0 + + def step(self) -> Tuple[Dict[int, int], List[int]]: + # next() is a built-in function for the iterator, which will execute __next__() + iter = next(self.iter_counter) + modified_block_counts = self.modified_block_counts + free_buffer_ids = self.free_buffer_ids + + # Whether we need to invoke this before returning free_buffer_ids?? + self.check_and_free_waiting_buffers(iter) + + # step() is invoked once after _schedule() inside Scheduler::schedule(). It is invoked once for every decode or prefill + # We actually uses self.free_buffer_ids and self.modified_block_counts to track all requests + # checked by the whole _schedule(). This is a hacky solution but may work correctly. 
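+        # Illustrative sketch of how a caller could consume the return value
+        # (the real wiring lives in the scheduler / cache engine and is not
+        # shown here; `block_manager` and `allocator` are assumed handles, and
+        # it is assumed that cache buffer ids double as allocator request ids):
+        #
+        #   modified, freed = block_manager.step()
+        #   # e.g. modified == {3: 5, 7: 1}: buffer 3 now holds 5 blocks, buffer 7 holds 1
+        #   # e.g. freed == [2]: buffer 2 can be handed back to the allocator
+        #   allocator.alloc_cache_blocks([[bid, n] for bid, n in modified.items()])
+        #   allocator.release_cache_regions(freed)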
+ self.modified_block_counts = {} + self.free_buffer_ids = [] + return modified_block_counts, free_buffer_ids diff --git a/vllm/core/block_manager_vmm.py b/vllm/core/block_manager_vmm.py new file mode 100644 index 0000000000000..d1087380653a1 --- /dev/null +++ b/vllm/core/block_manager_vmm.py @@ -0,0 +1,326 @@ +from collections import deque +from typing import Dict, List, Optional, Tuple + +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device, Counter +from collections import deque + +logger = init_logger(__name__) + + +class CacheBufferAllocator: + def __init__(self, num_cache_buffers: int): + self.num_cache_buffers = num_cache_buffers + self.free_buffers = deque(range(num_cache_buffers)) + + def allocate(self) -> int: + buffer_id = self.free_buffers.popleft() + return buffer_id + + def free(self, buffer_id: int): + self.free_buffers.append(buffer_id) + + def reset(self): + self.free_buffers = deque(range(self.num_cache_buffers)) + + def get_num_free_buffers(self): + return len(self.free_buffers) + + def get_num_total_buffers(self): + return self.num_cache_buffers + + +class BlockSpaceManagerVMM(BlockSpaceManager): + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + num_cache_buffers: int = 0, + ) -> None: + + if enable_caching or (sliding_window is not None): + raise NotImplementedError("Prefix Caching or Sliding window is not supported in VMM now.") + + self.enable_caching = enable_caching + + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.num_free_gpu_blocks = num_gpu_blocks + self.num_free_cpu_blocks = num_cpu_blocks + + self.num_cache_buffers = num_cache_buffers # == self.scheduler_config.max_num_seqs + + # use to alloc cache buffer id for seq + self.gpu_allocator = CacheBufferAllocator(num_cache_buffers) + + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + # Mapping from cache buffer ID to the number of allocated blocks. + self.allocated_block_counts: Dict[int, int] = {} + self.modified_block_counts: Dict[int, int] = {} + self.waiting_free_buffers: List[Tuple[int, int]] = [] + self.waiting_free_blocks: int = 0 + self.free_buffer_ids: List[int] = [] + self.free_latency: int = 10 + self.iter_counter = Counter() + + self.init_alloc() + + def init_alloc(self) -> None: + # we init alloc one block for warp in cache_engine_vmm + self.allocated_block_counts[0] = 1 + self.num_free_gpu_blocks -= 1 + + def predict_gen_len(self, seq: Sequence) -> int: + # TODO:this function is used to predict the generated content length, + # which can used to pre allocate the memory handles + return 1 + + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + return 0 if seq is None else seq.n_blocks + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. 
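+        # Worked example of the watermark logic below (made-up numbers): with
+        # num_gpu_blocks = 1000 and watermark = 0.01, watermark_blocks = 10.
+        # A request needing 995 blocks can NEVER fit (1000 - 995 < 10); a
+        # request needing 50 blocks with 200 effectively free blocks is OK
+        # (200 - 50 >= 10) provided a cache buffer id is available, and
+        # otherwise the answer is LATER.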
+ + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = self._get_seq_num_required_blocks(seq) + num_required_blocks += self.predict_gen_len(seq) + + # If the sequence is not allocated yet, its cache_buffer_id must be -1. + assert seq.cache_buffer_id == -1 + + num_free_gpu_blocks = self.num_free_gpu_blocks + \ + self.waiting_free_blocks + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + if self.gpu_allocator.get_num_free_buffers() > 0 or self.waiting_free_buffers: + return AllocStatus.OK + else: + return AllocStatus.LATER + else: + return AllocStatus.LATER + + def allocate(self, seq_group: SequenceGroup) -> None: + if seq_group.is_encoder_decoder(): + raise NotImplementedError("Encoder-decoder is not supported in VMM now.") + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + need_blocks_num = self._get_seq_num_required_blocks(seq) + need_blocks_num += self.predict_gen_len(seq) + buffer_id, allocated_num = self._allocate_buffer(need_blocks_num) + + seq.cache_buffer_id = buffer_id + seq.data.cache_buffer_id = buffer_id + self.allocated_block_counts[buffer_id] = allocated_num + self.modified_block_counts[buffer_id] = allocated_num + + # predict generate content length and pre allocate the blocks + # need_blocks_num += self.predict_gen_len(seq) + + def _allocate_buffer(self, need_blocks_num: int) -> Tuple[int, int]: + if self.waiting_free_buffers: + return self._allocate_from_waiting_buffer(need_blocks_num) + else: + assert self.num_free_gpu_blocks >= need_blocks_num + buffer_id = self.gpu_allocator.allocate() + self.num_free_gpu_blocks -= need_blocks_num + return buffer_id, need_blocks_num + + def _allocate_from_waiting_buffer(self, + need_blocks_num: int) -> Tuple[int, int]: + buffer_id, _ = self.waiting_free_buffers.pop(0) + allocated_num = self.allocated_block_counts[buffer_id] + self.waiting_free_blocks -= allocated_num + + if allocated_num < need_blocks_num: + self._allocate_extra_blocks(need_blocks_num - allocated_num) + allocated_num = need_blocks_num + # If we reuse a buffer that's too long, we may need to free the memory + # that's more than we currently need (need_blocks_num) + # But now, frequent frees are an overhead, so we don't do it. 
+ # TODO: Reduced overhead or asynchronous free + # else: + # self.num_free_gpu_blocks += (allocated_num - need_blocks_num) + # allocated_num = need_blocks_num + + return buffer_id, allocated_num + + def _allocate_extra_blocks(self, extra_blocks: int) -> None: + if self.num_free_gpu_blocks >= extra_blocks: + self.num_free_gpu_blocks -= extra_blocks + else: + extra_need_blocks = extra_blocks - self.num_free_gpu_blocks + self.num_free_gpu_blocks = 0 + self._allocate_from_waiting_buffers(extra_need_blocks) + + # free some blocks from waiting buffers to allocate + def _allocate_from_waiting_buffers(self, blocks_to_alloc: int) -> None: + while self.waiting_free_buffers and blocks_to_alloc > 0: + free_id, _ = self.waiting_free_buffers.pop(0) + free_blocks = self.allocated_block_counts[free_id] + self.waiting_free_blocks -= free_blocks + self.free_buffer_ids.append(free_id) + self.allocated_block_counts[free_id] = 0 + blocks_to_alloc -= free_blocks + + assert blocks_to_alloc <= 0 + self.num_free_gpu_blocks -= blocks_to_alloc + + def can_append_slots(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerVMM." + + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + num_free_gpu_blocks = self.num_free_gpu_blocks + \ + self.waiting_free_blocks + return num_seqs <= num_free_gpu_blocks + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int = 0, + ) -> List[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + + buffer_id = seq.cache_buffer_id + + # If the sequence is allocated, its cache_buffer_id must >= 0. + assert buffer_id >= 0 + + logical_blocks_num = seq.n_blocks + allocated_num = self.allocated_block_counts[buffer_id] + + # If we need to allocate a new physical block + if allocated_num < logical_blocks_num: + # Currently this code only supports adding one physical block + assert allocated_num == logical_blocks_num - 1 + self._allocate_extra_blocks(1) + self.allocated_block_counts[buffer_id] = logical_blocks_num + self.modified_block_counts[buffer_id] = logical_blocks_num + return [] + + else: + # the last block is not full, no need to allocate a new block + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + raise NotImplementedError("Forking is not supported in BlockSpaceManagerVMM now.") + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + raise NotImplementedError("Swap-in is not supported in BlockSpaceManagerVMM now.") + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + raise NotImplementedError("Swap-in is not supported in BlockSpaceManagerVMM now.") + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + raise NotImplementedError("Swap-out is not supported in BlockSpaceManagerVMM now.") + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + raise NotImplementedError("Swap-out is not supported in BlockSpaceManagerVMM now.") + + def free(self, seq: Sequence) -> None: + # Here, we just append free seq to waiting_free_buffers. + waiting_free_id = seq.cache_buffer_id + if waiting_free_id not in self.allocated_block_counts or \ + self.allocated_block_counts[waiting_free_id] == 0: + # Already freed or haven't been scheduled yet. 
+ return + + free_blocks = self.allocated_block_counts[waiting_free_id] + self.waiting_free_buffers.append((waiting_free_id, + self.iter_counter.counter)) + self.waiting_free_blocks += free_blocks + + def reset(self) -> None: + # Free decoder block tables + self.allocated_block_counts.clear() + self.num_free_gpu_blocks = self.num_total_gpu_blocks + self.num_free_cpu_blocks = self.num_total_cpu_blocks + + self.waiting_free_buffers = [] + self.modified_block_counts = {} + self.free_buffer_ids = [] + self.gpu_allocator.reset() + + def get_block_table(self, seq: Sequence) -> List[int]: + # logger.warning("block table is not used in BlockSpaceManagerVMM now.") + return [] + + def get_num_free_gpu_blocks(self) -> int: + return self.num_free_gpu_blocks + + def get_num_free_cpu_blocks(self) -> int: + return self.num_free_cpu_blocks + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + # logger.warning("Access all blocks in seq is not supported in BlockSpaceManagerVMM now.") + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + # logger.warning("Common computed block ids is not supported in BlockSpaceManagerVMM now.") + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup) -> None: + # logger.warning("Mark blocks as computed is not supported in BlockSpaceManagerVMM now.") + pass + + def get_allocated_block_count(self, seq_id: int) -> int: + return self.allocated_block_counts[seq_id] + + def check_and_free_waiting_buffers(self, now_iter: int) -> None: + while self.waiting_free_buffers and \ + self.waiting_free_buffers[0][1] - now_iter >= self.free_latency: + free_id, _ = self.waiting_free_buffers.pop(0) + free_blocks = self.allocated_block_counts[free_id] + self.waiting_free_blocks -= free_blocks + self.num_free_gpu_blocks += free_blocks + self.free_buffer_ids.append(free_id) + self.allocated_block_counts[free_id] = 0 + + def step(self) -> Tuple[Dict[int, int], List[int]]: + iter = next(self.iter_counter) + modified_block_counts = self.modified_block_counts + free_buffer_ids = self.free_buffer_ids + self.check_and_free_waiting_buffers(iter) + self.modified_block_counts = {} + self.free_buffer_ids = [] + return modified_block_counts, free_buffer_ids \ No newline at end of file diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 96f8dd851b2f4..66bd940df5d38 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -41,6 +41,13 @@ def get_block_space_manager_class(version: str): EmbeddingModelBlockSpaceManager) return EmbeddingModelBlockSpaceManager + if version == "vmm": # new add for vmm + from vllm.core.block_manager_vmm import BlockSpaceManagerVMM + return BlockSpaceManagerVMM + + if version == "dattn": # new add for vmm + from vllm.core.block_manager_dattn import BlockSpaceManagerDAttn + return BlockSpaceManagerDAttn raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b737..dad9b4d977499 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -134,6 +134,11 @@ class SchedulerOutputs: running_queue_size: int preempted: int + # new add for vmm + allocated_block_counts: Dict[int, int] = field(default_factory=dict) + free_buffer_ids: List[int] = field(default_factory=list) + + def __post_init__(self): # Swap in and swap out should never happen at the same time. 
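The free()/check_and_free_waiting_buffers()/step() trio above implements delayed reclamation: a freed buffer is parked for free_latency scheduler iterations so that a new request can take it over without fresh allocations. Below is a compact, self-contained model of that bookkeeping (a sketch, not the patch code; note it computes now_iter minus the recorded iteration, whereas the hunk above writes the difference the other way round, which is worth double-checking):

from typing import Dict, List, Tuple

class DeferredFreePool:
    def __init__(self, free_latency: int = 10):
        self.free_latency = free_latency
        self.allocated_block_counts: Dict[int, int] = {}
        # (buffer_id, iteration at which it was freed)
        self.waiting_free_buffers: List[Tuple[int, int]] = []
        self.num_free_gpu_blocks = 0

    def free(self, buffer_id: int, now_iter: int) -> None:
        # Nothing is returned to the free pool yet; the buffer is only parked.
        self.waiting_free_buffers.append((buffer_id, now_iter))

    def step(self, now_iter: int) -> List[int]:
        # Reclaim buffers that have been parked for at least free_latency iterations.
        reclaimed: List[int] = []
        while (self.waiting_free_buffers and
               now_iter - self.waiting_free_buffers[0][1] >= self.free_latency):
            buffer_id, _ = self.waiting_free_buffers.pop(0)
            self.num_free_gpu_blocks += self.allocated_block_counts.pop(buffer_id, 0)
            reclaimed.append(buffer_id)
        return reclaimed

pool = DeferredFreePool(free_latency=2)
pool.allocated_block_counts = {0: 4, 1: 8}
pool.free(0, now_iter=3)
print(pool.step(now_iter=4))       # [] -> buffer 0 is still parked
print(pool.step(now_iter=5))       # [0] -> reclaimed after 2 iterations
print(pool.num_free_gpu_blocks)    # 4 blocks returned to the free pool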
assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) @@ -311,12 +316,21 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config + self.use_vmm = cache_config.use_vmm + self.use_dattn = cache_config.use_dattn + self.prefillcount = 0 + self.schedcount = 0 + self.runningcount = 0 + version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" if self.scheduler_config.embedding_mode: version = "embedding" - + if self.use_vmm: + version = "vmm" + if self.use_dattn: + version = "dattn" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) @@ -329,12 +343,27 @@ def __init__( num_cpu_blocks //= pipeline_parallel_size # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching) + if not cache_config.use_vmm and not cache_config.use_dattn: + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) + else: # vmm block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching, + num_cache_buffers=self.scheduler_config.max_num_seqs) + # self.block_manager = BlockSpaceManagerImpl( + # block_size=self.cache_config.block_size, + # num_gpu_blocks=num_gpu_blocks, + # num_cpu_blocks=num_cpu_blocks, + # sliding_window=self.cache_config.sliding_window, + # enable_caching=self.cache_config.enable_prefix_caching) # Sequence groups in the WAITING state. # Contain new prefill or preempted requests. 
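The Scheduler changes above choose the block-manager implementation from a version string, with the VMM and DAttn flags applied last. A minimal sketch of that precedence (the helper name is illustrative, not part of the patch):

def select_block_manager_version(use_v2_block_manager: bool,
                                 embedding_mode: bool,
                                 use_vmm: bool,
                                 use_dattn: bool) -> str:
    # Later assignments win, mirroring the ordering in Scheduler.__init__ above,
    # so use_dattn overrides use_vmm if both flags are somehow enabled.
    version = "v1"
    if use_v2_block_manager:
        version = "v2"
    if embedding_mode:
        version = "embedding"
    if use_vmm:
        version = "vmm"
    if use_dattn:
        version = "dattn"
    return version

print(select_block_manager_version(True, False, True, False))    # "vmm"
print(select_block_manager_version(False, False, False, True))   # "dattn"

The vmm/dattn managers are then constructed with num_cache_buffers set to max_num_seqs, since each in-flight sequence owns exactly one cache buffer.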
@@ -1151,8 +1180,10 @@ def schedule( for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) + # block_tables[seq_id] = self.block_manager.get_block_table(seq) + # self.block_manager.access_all_blocks_in_seq(seq, now) + if not self.use_vmm: + block_tables[seq_id] = self.block_manager.get_block_table(seq) if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index df07842edfa56..0020356b04d8e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -66,8 +66,20 @@ def _split_tensor_dict( metadata_list.append( (key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) + elif isinstance(value, dict): + if key == "allocated_block_counts" or key == "free_buffer_ids": + # if allocated_block_counts, no need to split_tensor_dict its values + metadata_list.append((prefix + key, value)) + else: + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%") + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) else: metadata_list.append((key, value)) + return metadata_list, tensor_list diff --git a/vllm/sequence.py b/vllm/sequence.py index 07ceccf123541..a9189f22976dd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -696,6 +696,21 @@ def multi_modal_data(self) -> "MultiModalDataDict": def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def __repr__(self) -> str: + return (f"SequenceGroupMetadata(\n" + f"-- request_id={self.request_id}, \n" + f"-- is_prompt={self.is_prompt}, \n" + f"-- seq_data={self.seq_data}, \n" + f"-- sampling_params={self.sampling_params}, \n" + f"-- block_tables={self.block_tables}, \n" + f"-- do_sample={self.do_sample}, \n" + f"-- token_chunk_size={self.token_chunk_size}, \n" + f"-- lora_request={self.lora_request}, \n" + f"-- computed_block_nums={self.computed_block_nums}, \n" + f"-- multi_modal_data={self.multi_modal_data}, \n" + f"-- encoder_seq_data={self.encoder_seq_data}, \n" + f"-- cross_block_table={self.cross_block_table})") + @property def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ @@ -1263,6 +1278,11 @@ class ExecuteModelRequest( num_steps: int = 1 # Finished request ids since last step. finished_requests_ids: List[str] = msgspec.field(default_factory=list) + + # new add for vmm and dattn + allocated_block_counts: Dict[int, int]= msgspec.field(default_factory=dict), + free_buffer_ids: List[int] = msgspec.field(default_factory=list) + # The last sampled token ids for multi step decoding. 
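The new ExecuteModelRequest fields above use msgspec default factories to carry the per-step VMM/DAttn bookkeeping. One detail worth double-checking in the hunk is the trailing comma after the allocated_block_counts field, since in Python that turns the declared default into a one-element tuple. A minimal illustration with a hypothetical struct:

from typing import Dict, List
import msgspec

class VMMStepInfo(msgspec.Struct):
    # Intended defaults: a fresh empty dict / list per instance.
    allocated_block_counts: Dict[int, int] = msgspec.field(default_factory=dict)
    free_buffer_ids: List[int] = msgspec.field(default_factory=list)

info = VMMStepInfo()
print(info.allocated_block_counts, info.free_buffer_ids)  # {} []

# With a trailing comma, i.e.
#     allocated_block_counts: Dict[int, int] = msgspec.field(default_factory=dict),
# the right-hand side becomes the tuple (msgspec.field(...),) and the field no
# longer defaults to an empty dict.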
last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback diff --git a/vllm/utils.py b/vllm/utils.py index 060b387ec7834..cc7efe0adf3e6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -138,6 +138,25 @@ torch.int64: np.int64, } +# new add for vmm (use in wrap ptr to tensor) +TORCH_DTYPE_TO_STR_DTYPE = { + torch.double: "double", + torch.float: "float", + torch.float64: "float64", + torch.float32: "float32", + torch.float16: "float16", + torch.half: "half", + torch.bfloat16: "bfloat16", + + torch.int: "int", + torch.int64: "int64", + torch.int32: "int32", + torch.int16: "int16", + torch.int8: "int8", + + torch.uint8: "uint8", +} + P = ParamSpec('P') K = TypeVar("K") T = TypeVar("T") @@ -676,6 +695,36 @@ def create_kv_caches_with_random_flash( value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches +def create_kv_caches_with_random_flash_non_page( + batch_size: int, + seq_len: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (2, batch_size, seq_len, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[0]) + value_caches.append(key_value_cache[1]) + return key_caches, value_caches + def create_kv_caches_with_random( num_blocks: int, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e08..59630af6f6026 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -65,6 +65,8 @@ def __init__( # Initialize the cache. self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) + + #print(f"NNNNNNNNNNNNN self.gpu_cache:{self.gpu_cache}") self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") def _allocate_kv_cache( @@ -115,6 +117,7 @@ def get_cache_block_size( key_cache_block = cache_config.block_size * num_heads * head_size value_cache_block = key_cache_block total = num_attention_layers * (key_cache_block + value_cache_block) + print(f"head_size:{head_size}, num_heads:{num_heads}, num_attention_layers:{num_attention_layers}, block_size: {cache_config.block_size}, key_cache_block:{key_cache_block},total:{total/1024}KB") if cache_config.cache_dtype == "auto": dtype = model_config.dtype else: diff --git a/vllm/worker/cache_engine_dattn.py b/vllm/worker/cache_engine_dattn.py new file mode 100644 index 0000000000000..058036d1ac997 --- /dev/null +++ b/vllm/worker/cache_engine_dattn.py @@ -0,0 +1,203 @@ +''' + Copyright (c) ByteDance Inc. 
+ Authors: + - Tongping Liu (tongping.liu@bytedance.com) + - https://github.com/vllm-project/vllm/pull/6102/commits +''' +"""CacheEngine class for managing the KV cache.""" +from typing import List, Dict, Tuple + +import torch + +from vllm.attention import get_attn_backend +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, DeviceConfig +from vllm.logger import init_logger +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available, TORCH_DTYPE_TO_STR_DTYPE, get_dtype_size + +from vllm import _dattn_ops as dattn + +logger = init_logger(__name__) + + +class CacheEngineDAttn: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. It also provides methods for performing KV cache operations, such + as swapping and copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ) -> None: + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + if self.device_config.device_type != "cuda": + raise RuntimeError("DATTN only support cuda device.") + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.block_bytes_size = self.cache_config.block_bytes_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + # self.num_cpu_blocks = cache_config.num_cpu_blocks + + print(f"self.block_bytes_size-{self.block_bytes_size}") + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + self.dtype_size = get_dtype_size(self.dtype) + self.max_batch_size = self.scheduler_config.max_num_seqs + self.max_seq_len = self.scheduler_config.max_model_len + + # If max_seq_len is not divisible by block_size, + # round up to the nearest value that is. 
+        if self.max_seq_len % self.block_size != 0:
+            self.max_seq_len = ((self.max_seq_len // self.block_size + 1)
+                                * self.block_size)
+            logger.warning(
+                "max_seq_len is not divisible by block_size; "
+                "rounding it up to %d", self.max_seq_len)
+
+        self.token_size = self.num_kv_heads * self.head_size
+        self.sequence_buffer_size = self.max_seq_len * self.token_size
+        self.sequence_buffer_bytes_size = self.sequence_buffer_size * self.dtype_size
+        self.cache_space_size = self.max_batch_size * self.sequence_buffer_size
+        self.cache_space_bytes_size = self.cache_space_size * self.dtype_size
+
+        assert self.cache_space_bytes_size % self.block_bytes_size == 0, \
+            "cache_space_bytes_size must be divisible by block_bytes_size"
+
+        self.cache_space_page_num = self.cache_space_bytes_size // self.block_bytes_size
+
+        logger.info("CacheEngineDAttn basic info: { block_size: %d, dtype_size: %d, head_size: %d, "
+                    "num_kv_heads: %d, max_seq_len: %d, max_batch_size: %d, num_layers: %d, "
+                    "token_size: %d, sequence_buffer_size: %d, cache_space_size: %d, "
+                    "cache_space_bytes_size: %d, cache_space_page_num: %d }",
+                    self.block_size, self.dtype_size, self.head_size,
+                    self.num_kv_heads, self.max_seq_len, self.max_batch_size, self.num_layers,
+                    self.token_size, self.sequence_buffer_size, self.cache_space_size,
+                    self.cache_space_bytes_size, self.cache_space_page_num)
+
+        self.device_cache_allocator = dattn.kvCacheAllocator(self.max_seq_len, self.num_layers,
+                                                             self.num_kv_heads, self.head_size,
+                                                             self.block_size, self.dtype_size)
+
+        # Record the number of allocated blocks in the cache space of each request.
+        self.allocated_block_counts = [0 for _ in range(self.max_batch_size)]
+
+        # Get attention backend.
+        self.attn_backend = get_attn_backend(
+            model_config.get_num_attention_heads(parallel_config),
+            self.head_size,
+            self.num_kv_heads,
+            model_config.get_sliding_window(),
+            model_config.dtype,
+            cache_config.cache_dtype,
+            self.block_size,
+        )
+
+        self.kv_cache_ptrs = self._reserve_gpu_kv_cache()
+        self.gpu_cache = self._create_fake_kv_cache()
+
+        """
+        In dAttention's design, the layer index is passed into the CUDA kernel so
+        that the kernel can locate the KV cache for that layer. Other mechanisms,
+        such as PagedAttention or vAttention, instead pass a separate kv_cache
+        tensor for each layer.
+ """ + + def _create_fake_kv_cache(self) -> List[torch.Tensor]: + fake_kv_caches = [] + + for i in range(self.num_layers): + fake_kv_caches.append(torch.tensor(i)) + + return fake_kv_caches + + def _reserve_gpu_kv_cache(self) -> List[int]: + kv_cache_ptrs = [] + + for i in range(self.max_batch_size): + kv_cache_ptrs.append(self.device_cache_allocator.reserve_cache_region(i)) + # print(f"i:{i}, virtual address:{hex(kv_cache[i])}") + + return kv_cache_ptrs + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("swap_in is not implemented for DATTN now.") + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("swap_out is not implemented for DATTN now.") + + # TODO: we need to implement the copy_blocks + def copy(self, src_to_dsts: torch.Tensor) -> None: + self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + # print(f"CacheEngineDAttn:head_size:{head_size}, num_heads:{num_heads}, num_attention_layers:{num_attention_layers}, block_size: {cache_config.block_size}, key_cache_block:{key_cache_block},total:{total/1024}KB") + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + # print(f"CacheEngineDAttn:cache_config.block_bytes_size:{dtype_size * total}") + return dtype_size * total + + """ + This function tried to allocate the physical pages for all requests + Initially, vmm patch will allocate physical pages for each request. + That is, each request will invoke from python to C++ library function directly. + + However, it is not a wise approach, as that can increase the overhead by 100X based on our experiments. + Instead, we should invoke c++ library function just once by passing an array with [req_id, new_blocks] + + Note that self.allocated_block_counts[buffer_id] will track the number of allocated blocks + in this function. Let's utilize the same mechanism at the first step. + TODO: we may change this later. To my understanding, it is better to track the number of blocks at sequence + """ + + def alloc_seqs(self, allocated_block_counts: Dict[int, int]): + to_alloc_blocks = [] + """Allocate cache handles for the given number of blocks.""" + for buffer_id, num_blocks in allocated_block_counts.items(): + allocated_blocks = self.allocated_block_counts[buffer_id] + num_blocks -= allocated_blocks + # print(f"CacheEngineDAttn: buffer_id-{buffer_id}, num_blocks:{num_blocks}") + if num_blocks > 0: + to_alloc_blocks.append([buffer_id, num_blocks]) + self.allocated_block_counts[buffer_id] += num_blocks + + # Allocate physical blocks for all requests. 
+ self.device_cache_allocator.alloc_cache_blocks(to_alloc_blocks) + + def free_seqs(self, free_buffer_ids: List[int]): + """Free cache handles for the given buffer ids.""" + for req in free_buffer_ids: + print(f"BOOWWWW free_seqs with req:{req}") + self.device_cache_allocator.release_cache_regions(free_buffer_ids) + + +def _get_dtype_size(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() diff --git a/vllm/worker/cache_engine_vmm.py b/vllm/worker/cache_engine_vmm.py new file mode 100644 index 0000000000000..66f813966ec68 --- /dev/null +++ b/vllm/worker/cache_engine_vmm.py @@ -0,0 +1,250 @@ +"""CacheEngine class for managing the KV cache.""" +from typing import List, Dict, Tuple + +import torch + +from vllm.attention import get_attn_backend +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, DeviceConfig +from vllm.logger import init_logger +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available, TORCH_DTYPE_TO_STR_DTYPE + +from vllm import _vmm_ops as vmm + +logger = init_logger(__name__) + + +class CacheEngineVMM: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. It also provides methods for performing KV cache operations, such + as swapping and copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ) -> None: + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + if self.device_config.device_type != "cuda": + raise RuntimeError("VMM only support cuda device.") + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.block_bytes_size = self.cache_config.block_bytes_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + self.num_cpu_blocks = cache_config.num_cpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + self.dtype_size = get_dtype_size(self.dtype) + self.max_batch_size = self.scheduler_config.max_num_seqs + self.max_seq_len = self.scheduler_config.max_model_len + + # If max_seq_len is not divisible by block_size, + # round up to the nearest value that is. + if self.max_seq_len % self.block_size != 0: + self.max_seq_len = ((self.max_seq_len // self.block_size + 1) + * self.block_size) + logger.warning("self.max_seq_len mod self.block_size != 0,round up max_seq_len to %d", self.max_seq_len) + + self.token_size = self.num_kv_heads * self.head_size + self.sequence_buffer_size = self.max_seq_len * self.token_size + self.sequence_buffer_bytes_size = self.sequence_buffer_size * self.dtype_size + # sequence_buffer_bytes_size is the number of bytes for each sequence + self.cache_space_size = self.max_batch_size * self.sequence_buffer_size + # Maximum memory for all batched requests together for each layer. 
+ self.cache_sapce_bytes_size = self.cache_space_size * self.dtype_size + + # block_bytes_size == 2M + assert ( + self.cache_sapce_bytes_size) % self.block_bytes_size == 0, "cache_sapce_bytes_size must be divisible by block_bytes_size" + + self.cache_space_page_num = self.cache_sapce_bytes_size // self.block_bytes_size + + logger.info("CacheEngineVMM basic info: { block_size: %d, dtype_size: %d, head_size: %d, " + "num_kv_heads: %d, max_seq_len: %d, max_batch_size: %d, num_layers: %d," + "token_size: %d, sequence_buffer_size: %d, cache_space_size: %d, " + "cache_sapce_bytes_size: %d, cache_space_page_num: %d }", + self.block_size, self.dtype_size, self.head_size, + self.num_kv_heads, self.max_seq_len, self.max_batch_size, self.num_layers, + self.token_size, self.sequence_buffer_size, self.cache_space_size, + self.cache_sapce_bytes_size, self.cache_space_page_num) + + self.device_cache_allocator = vmm.CacheAllocator() + # record the allocated handles for each buffer in a cache space + self.allocated_block_counts = [0 for _ in range(self.max_batch_size)] + + # Get attention backend. + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) + + # Initialize the cache. + self.gpu_cache_ptr = self._reserve_gpu_kv_cache() + self.gpu_cache = self._init_gpu_kv_cache_tensor() + + # TODO: Implement CPU cache and swap + # self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") + + def _reserve_gpu_kv_cache(self) -> List[List[vmm.CacheDevicePtr]]: + kv_cache_ptrs = [] + for i in range(self.num_layers): + key_ptr = vmm.CacheDevicePtr() + value_ptr = vmm.CacheDevicePtr() + + if (self.device_cache_allocator.reserve_cache_ptr(key_ptr, self.cache_space_page_num) == 0) \ + and (self.device_cache_allocator.reserve_cache_ptr(value_ptr, self.cache_space_page_num) == 0): + kv_cache_ptrs.append([key_ptr, value_ptr]) + else: + raise RuntimeError("Failed to reserve cache ptr.") + + return kv_cache_ptrs + + def _init_gpu_kv_cache_tensor(self) -> List[List[torch.Tensor]]: + kv_cache: List[List[torch.Tensor]] = [] + + # self.alloc_one_seq(0, 1) + # We have to allocate one block for each ptr, otherwise wrap to tensor will fail + # Here we allocate one block for each sequence buffer of each ptr + alloc_dict[0] = 1 + self.alloc_seqs(alloc_dict) + + for i in range(self.num_layers): + _key_cache_ptr = self.gpu_cache_ptr[i][0] + _value_cache_ptr = self.gpu_cache_ptr[i][1] + + shape = (self.max_batch_size, self.max_seq_len, self.num_kv_heads, self.head_size) + dtype = TORCH_DTYPE_TO_STR_DTYPE[self.dtype] + key_cache_tensor: torch.Tensor = vmm.wrap_cache_ptr_to_tensor(_key_cache_ptr, dtype, shape) + value_cache_tensor: torch.Tensor = vmm.wrap_cache_ptr_to_tensor(_value_cache_ptr, dtype, shape) + + kv_cache.append([key_cache_tensor, value_cache_tensor]) + + return kv_cache + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str = 'cpu', + ) -> List[torch.Tensor]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + pin_memory = is_pin_memory_available() if device == "cpu" else False + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. 
+ kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) + return kv_cache + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("swap_in is not implemented for VMM now.") + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + raise NotImplementedError("swap_out is not implemented for VMM now.") + + def copy(self, src_to_dsts: torch.Tensor) -> None: + self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + return (cache_config.block_bytes_size * + model_config.get_num_layers(parallel_config) * 2) + # single block bytes size * layer num * 2 (key and value) + # return cache_config.block_bytes_size * model_config.get_num_layers(parallel_config) * 2 + + def alloc_seqs(self, allocated_block_counts: Dict[int, int]): + """Allocate cache handles for the given number of blocks.""" + for buffer_id, num_blocks in allocated_block_counts.items(): + allocated_blocks = self.allocated_block_counts[buffer_id] + num_blocks -= allocated_blocks + start_offset = buffer_id * self.sequence_buffer_bytes_size + if num_blocks > 0: + allocated_blocks = self.allocated_block_counts[buffer_id] + offset = (start_offset + allocated_blocks * self.block_bytes_size) + self.alloc_one_seq(buffer_id, num_blocks, offset) + + # But now, frequent frees are an overhead, so we don't do it. + # TODO: Reduced overhead or asynchronous free + # elif num_blocks < 0: # release the extra blocks + # offset = (start_offset + (allocated_blocks + num_blocks) * + # self.block_bytes_size) + # self.free_one_seq(buffer_id, -num_blocks, offset) + + def alloc_one_seq(self, buffer_id: int, num_blocks: int = 1, offset: int = 0): + """Allocate cache handles for the given number of blocks.""" + for i in range(self.num_layers): + _key_cache_ptr = self.gpu_cache_ptr[i][0] + _value_cache_ptr = self.gpu_cache_ptr[i][1] + + status1 = self.device_cache_allocator.alloc_cache_ptr(_key_cache_ptr, num_blocks, offset) + status2 = self.device_cache_allocator.alloc_cache_ptr(_value_cache_ptr, num_blocks, offset) + if status1 != 0 or status2 != 0: + logger.error("VMM Alloc: buffer_id: %d, num_blocks: %d, offset: %d", + buffer_id, num_blocks, offset) + raise RuntimeError(f"Failed to allocate cache handles. status1: {status1}, status2: {status2}") + + self.allocated_block_counts[buffer_id] += num_blocks + + def free_seqs(self, free_buffer_ids: List[int]): + """Free cache handles for the given buffer ids.""" + for buffer_id in free_buffer_ids: + num_blocks = self.allocated_block_counts[buffer_id] + offset = buffer_id * self.sequence_buffer_bytes_size + self.free_one_seq(buffer_id, num_blocks, offset) + + def free_one_seq(self, buffer_id: int, num_blocks: int = 0, offset: int = 0): + """Free cache handles for the given buffer id.""" + for i in range(self.num_layers): + _key_cache_ptr = self.gpu_cache_ptr[i][0] + _value_cache_ptr = self.gpu_cache_ptr[i][1] + + status1 = self.device_cache_allocator.release_cache_ptr( + _key_cache_ptr, num_blocks, offset) + status2 = self.device_cache_allocator.release_cache_ptr( + _value_cache_ptr, num_blocks, offset) + if status1 != 0 or status2 != 0: + logger.error("VMM Free: buffer_id: %d, num_blocks: %d, offset: %d", + buffer_id, num_blocks, offset) + raise RuntimeError( + f"Failed to free cache handles. 
status1: {status1}, status2: {status2}" + ) + # logger.info("VMM Free: buffer_id: %d, num_blocks: %d, offset: %d", + # buffer_id, num_blocks, offset) + self.allocated_block_counts[buffer_id] -= num_blocks + + +def _get_dtype_size(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e8c472df8b5fc..d0ccb849e1e4d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -994,12 +994,39 @@ def __init__( self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry \ .create_input_mapper(model_config) + + + # TODO: we will soon support these features in VMM + self.use_vmm = cache_config.use_vmm + self.use_dattn = cache_config.use_dattn + if self.use_vmm: + if self.lora_config: + #TODO: + raise NotImplementedError("VMM is not supported with LoRA ") + if self.sliding_window: + #TODO: + raise NotImplementedError("VMM is not supported with sliding window") + if self.attn_backend.get_name() != "flash-attn": + raise NotImplementedError("VMM is only supported with flash-attn") + elif self.use_dattn: + if self.lora_config: + #TODO: + raise NotImplementedError("DATTN is not supported with LoRA ") + if self.sliding_window: + #TODO: + raise NotImplementedError("DATTN is not supported with sliding window") + #print(f"dAttention's current backend is {self.attn_backend.get_name()}") + self.mm_registry.init_mm_limits_per_prompt(self.model_config) # Lazy initialization self.model: nn.Module # Set after load_model # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + + + + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None set_cpu_offload_max_bytes( @@ -1010,6 +1037,19 @@ def __init__( self.sampling_metadata_cache: SamplingMetadataCache = \ SamplingMetadataCache() + # Add dAttention support + self.kv_cache_ptrs: Optional[List[int]] = None + self.num_layers: int = 0 + + # Initialize kv_cache_ptrs when dAttention is used + def init_kv_cache_attribute(self, kv_cache_ptrs: List[int], block_size: int, num_layers: int) -> None: + self.kv_cache_ptrs = kv_cache_ptrs + self.block_size = block_size + self.num_layers = num_layers + + def _get_kv_ptr(self, index: int) -> int: + return self.kv_cache_ptrs[index] + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with CudaMemoryProfiler() as m: @@ -1587,6 +1627,11 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + import sys + if kv_caches != None and kv_caches[0] != None: + print(f"GPUModelRunner with num_steps-{num_steps}, input_tokens:{len(model_input.input_tokens)}, kv_cache:{len(kv_caches)} - {kv_caches[0].size()}", file=sys.stderr) + else: + print(f"GPUModelRunner with num_steps-{num_steps}, input_tokens:{len(model_input.input_tokens)}, kv_cache:{len(kv_caches)}", file=sys.stderr) hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3851843afc960..ce1123d54b39c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -24,11 +24,15 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine +from vllm.worker.cache_engine_vmm import CacheEngineVMM +from vllm.worker.cache_engine_dattn import CacheEngineDAttn from 
vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput +import time +from vllm.logger import init_logger logger = init_logger(__name__) @@ -64,6 +68,8 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + self.use_vmm = cache_config.use_vmm + self.use_dattn = cache_config.use_dattn self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -112,7 +118,8 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine: List[CacheEngine] + self.cache_engine: List[Union[CacheEngine, CacheEngineVMM, CacheEngineDAttn]] + # self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.Tensor]]] = None self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} @@ -267,6 +274,32 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None + + if self.use_vmm: # using VMM + self.cache_engine = [ + CacheEngineVMM(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + elif self.use_dattn: # Using DAttn + #print(f"NOOOOOW, before initialization of CacheEngineDAttn!") + self.cache_engine = [ + CacheEngineDAttn(self.cache_config, self.model_config, + self.parallel_config, self.scheduler_config, + self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + # Initialize kv_cache_ptrs immediately + for ve in range(self.parallel_config.pipeline_parallel_size): + self.model_runner.init_kv_cache_attribute(self.cache_engine[ve].kv_cache_ptrs, self.cache_engine[ve].block_size, self.cache_engine[ve].num_layers) + + else:# Not using VMM or VAttn + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] self.cache_engine = [ CacheEngine(self.cache_config, self.model_config, self.parallel_config, self.device_config) @@ -313,12 +346,21 @@ def prepare_worker_input( device=self.device, dtype=torch.int64).view(-1, 2) + if self.use_vmm or self.use_dattn: + allocated_block_counts = execute_model_req.allocated_block_counts + free_buffer_ids = execute_model_req.free_buffer_ids + else: + allocated_block_counts = None + free_buffer_ids = None + return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, + allocated_block_counts=allocated_block_counts, + free_buffer_ids=free_buffer_ids, num_steps=num_steps, ) @@ -338,6 +380,14 @@ def execute_worker(self, worker_input: WorkerInput) -> None: and worker_input.blocks_to_copy.numel() > 0): self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + if (self.use_vmm or self.use_dattn) and (worker_input.free_buffer_ids is not None and len(worker_input.free_buffer_ids) > 0): + self.cache_engine[virtual_engine].free_seqs( + worker_input.free_buffer_ids) + + if (self.use_vmm or self.use_dattn) and worker_input.allocated_block_counts is not None: + 
self.cache_engine[virtual_engine].alloc_seqs(worker_input.allocated_block_counts) + + def _get_cached_seq_group_metadata( self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, @@ -431,9 +481,24 @@ def vocab_size(self) -> int: def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. """ - return CacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) + # return CacheEngine.get_cache_block_size(self.cache_config, + # self.model_config, + # self.parallel_config) + + if self.use_vmm: + return CacheEngineVMM.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + elif self.use_dattn: + return CacheEngineDAttn.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + + else: + return CacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) def init_worker_distributed_environment( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6ba4f272315ce..c92414d0bb204 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -131,6 +131,10 @@ class WorkerInput: blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None virtual_engine: int = 0 + + allocated_block_counts : Optional[Dict[int, int]] = None # new add for vmm + free_buffer_ids: Optional[List[int]] = None + num_steps: int = 1 @classmethod @@ -148,6 +152,8 @@ def from_broadcasted_tensor_dict( blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], + allocated_block_counts=tensor_dict.pop("allocated_block_counts"), # new add for vmm + free_buffer_ids=tensor_dict.pop("free_buffer_ids"), num_steps=tensor_dict.pop("num_steps"), ) @@ -162,6 +168,9 @@ def as_broadcastable_tensor_dict( "blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, + # new add for vmm and dattn + "allocated_block_counts": self.allocated_block_counts, + "free_buffer_ids": self.free_buffer_ids, "num_steps": self.num_steps, }
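Taken together, the worker-side changes above route the block manager's per-step output through WorkerInput into the cache engine, with frees applied before new allocations. A condensed sketch of that flow using stub classes (all names below are stand-ins for illustration, not vLLM APIs):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class WorkerInputStub:
    allocated_block_counts: Dict[int, int] = field(default_factory=dict)
    free_buffer_ids: List[int] = field(default_factory=list)

class CacheEngineStub:
    """Tracks per-buffer block counts the way CacheEngineVMM/CacheEngineDAttn do."""
    def __init__(self, max_batch_size: int):
        self.allocated_block_counts = [0] * max_batch_size

    def free_seqs(self, free_buffer_ids: List[int]) -> None:
        for buffer_id in free_buffer_ids:
            self.allocated_block_counts[buffer_id] = 0

    def alloc_seqs(self, allocated_block_counts: Dict[int, int]) -> None:
        # Only the delta between the target block count and the blocks already
        # mapped for a buffer triggers a new physical allocation.
        for buffer_id, target_blocks in allocated_block_counts.items():
            delta = target_blocks - self.allocated_block_counts[buffer_id]
            if delta > 0:
                self.allocated_block_counts[buffer_id] = target_blocks

def execute_worker_step(engine: CacheEngineStub, wi: WorkerInputStub) -> None:
    # Mirrors execute_worker(): release parked buffers first, then grow the rest.
    if wi.free_buffer_ids:
        engine.free_seqs(wi.free_buffer_ids)
    if wi.allocated_block_counts:
        engine.alloc_seqs(wi.allocated_block_counts)

engine = CacheEngineStub(max_batch_size=4)
execute_worker_step(engine, WorkerInputStub(allocated_block_counts={0: 2, 1: 3}))
execute_worker_step(engine, WorkerInputStub(allocated_block_counts={0: 5},
                                            free_buffer_ids=[1]))
print(engine.allocated_block_counts)  # [5, 0, 0, 0]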