|
| 1 | +# Copyright 2023-2024 SGLang Team |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | +# ============================================================================== |
| 14 | +# Copyright 2024 Bytedance Ltd. and/or its affiliates |
| 15 | +# |
| 16 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 17 | +# you may not use this file except in compliance with the License. |
| 18 | +# You may obtain a copy of the License at |
| 19 | +# |
| 20 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 21 | +# |
| 22 | +# Unless required by applicable law or agreed to in writing, software |
| 23 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 24 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 25 | +# See the License for the specific language governing permissions and |
| 26 | +# limitations under the License. |
| 27 | + |
| 28 | +import os |
| 29 | +import torch |
| 30 | +from torch.distributed.device_mesh import init_device_mesh |
| 31 | + |
| 32 | +from sglang.srt.entrypoints.verl_engine import VerlEngine |
| 33 | + |
| 34 | +from transformers import AutoTokenizer, AutoModelForCausalLM |
| 35 | +from transformers import GenerationConfig |
| 36 | + |
| 37 | +from verl.utils.torch_functional import pad_sequence_to_length |
| 38 | + |
| 39 | + |
| 40 | +def levenshtein(s1, s2): |
| 41 | + m, n = len(s1), len(s2) |
| 42 | + # Initialize matrix of zeros |
| 43 | + dp = [[0] * (n + 1) for _ in range(m + 1)] |
| 44 | + # Initialize first column and first row of the matrix |
| 45 | + for i in range(m + 1): |
| 46 | + dp[i][0] = i # Deletion from s1 to empty string |
| 47 | + for j in range(n + 1): |
| 48 | + dp[0][j] = j # Insertion to s1 from empty string |
| 49 | + # Compute the Levenshtein distance matrix |
| 50 | + for i in range(1, m + 1): |
| 51 | + for j in range(1, n + 1): |
| 52 | + cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match |
| 53 | + dp[i][j] = min( |
| 54 | + dp[i - 1][j] + 1, # Deletion |
| 55 | + dp[i][j - 1] + 1, # Insertion |
| 56 | + dp[i - 1][j - 1] + cost # Substitution |
| 57 | + ) |
| 58 | + return dp[m][n] |
| 59 | + |
| 60 | + |
| 61 | +def are_lists_similar(a, b): |
| 62 | + if len(a) != len(b): |
| 63 | + print("The lists are of different lengths.") |
| 64 | + return False |
| 65 | + |
| 66 | + total_length = 0 |
| 67 | + total_diff = 0 |
| 68 | + |
| 69 | + for s1, s2 in zip(a, b): |
| 70 | + max_len = max(len(s1), len(s2)) |
| 71 | + total_length += max_len |
| 72 | + diff = levenshtein(s1, s2) |
| 73 | + total_diff += diff |
| 74 | + print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") |
| 75 | + |
| 76 | + percentage_difference = (total_diff / total_length) * 100 |
| 77 | + print(f"Total difference: {percentage_difference:.2f}%") |
| 78 | + |
| 79 | + return percentage_difference <= 10 |
| 80 | + |
| 81 | + |
| 82 | +def initialize_global_process_group(timeout_second=36000): |
| 83 | + from datetime import timedelta |
| 84 | + |
| 85 | + import torch.distributed |
| 86 | + |
| 87 | + # NOTE MODIFIED should provide backend=None to have nccl+gloo |
| 88 | + # torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) |
| 89 | + torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second)) |
| 90 | + |
| 91 | + local_rank = int(os.environ["LOCAL_RANK"]) |
| 92 | + rank = int(os.environ["RANK"]) |
| 93 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 94 | + |
| 95 | + if torch.distributed.is_initialized(): |
| 96 | + torch.cuda.set_device(local_rank) |
| 97 | + return local_rank, rank, world_size |
| 98 | + |
| 99 | + |
| 100 | +def test_sglang_spmd(): |
| 101 | + assert torch.cuda.device_count() >= 2, 'At least 2 GPUs is required to run tp+dp tests.' |
| 102 | + initialize_global_process_group() |
| 103 | + # fill rollout config |
| 104 | + max_prompt_length = 16 |
| 105 | + max_response_length = 16 |
| 106 | + |
| 107 | + # Initialize model and token |
| 108 | + local_cache_path = '~/.cache/verl/rlhf' |
| 109 | + local_cache_path = os.path.expanduser(local_cache_path) |
| 110 | + hdfs_path = 'Qwen/Qwen2-7B-Instruct' |
| 111 | + from verl.utils.fs import copy_to_local |
| 112 | + local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) |
| 113 | + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side='left') |
| 114 | + |
| 115 | + preencode_prompts = [ |
| 116 | + "Who won the Champions League in 2019?", |
| 117 | + "The founder of Apple is", |
| 118 | + "What's your name", |
| 119 | + ] |
| 120 | + tokenizer.pad_token = tokenizer.eos_token |
| 121 | + prompts = tokenizer(preencode_prompts, return_tensors='pt', padding=True) |
| 122 | + input_ids = prompts['input_ids'] |
| 123 | + attention_mask = prompts['attention_mask'] |
| 124 | + |
| 125 | + input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) |
| 126 | + attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) |
| 127 | + |
| 128 | + actor_model = AutoModelForCausalLM.from_pretrained(local_model_path) |
| 129 | + actor_model.to(torch.bfloat16) |
| 130 | + |
| 131 | + sampling_params = dict(n=1, |
| 132 | + temperature=0, |
| 133 | + top_p=1, |
| 134 | + top_k=-1, |
| 135 | + max_new_tokens=max_response_length, |
| 136 | + presence_penalty=0.0, |
| 137 | + frequency_penalty=0.0, |
| 138 | + repetition_penalty=1.0, |
| 139 | + skip_special_tokens=True, |
| 140 | + spaces_between_special_tokens=True, |
| 141 | + ignore_eos=False) |
| 142 | + |
| 143 | + tensor_parallel_size = 4 |
| 144 | + device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) |
| 145 | + inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) |
| 146 | + |
| 147 | + for k in ["TORCHELASTIC_USE_AGENT_STORE"]: |
| 148 | + if k in os.environ: |
| 149 | + del os.environ[k] |
| 150 | + print('building sglang rollout engine') |
| 151 | + llm = VerlEngine(model_path=local_model_path, |
| 152 | + dtype="bfloat16", |
| 153 | + mem_fraction_static=0.5, |
| 154 | + device_mesh_cpu=inference_device_mesh_cpu["tp"], |
| 155 | + base_gpu_id=0, |
| 156 | + gpu_id_step=1) |
| 157 | + |
| 158 | + llm.release_memory_occupation() |
| 159 | + print("start generation") |
| 160 | + input_ids = input_ids.cuda() |
| 161 | + attention_mask = attention_mask.cuda() |
| 162 | + batch_size = input_ids.size(0) |
| 163 | + |
| 164 | + generation_config = GenerationConfig(do_sample=False) |
| 165 | + actor_model.cuda() |
| 166 | + output = actor_model.generate( |
| 167 | + input_ids=input_ids, |
| 168 | + attention_mask=attention_mask, |
| 169 | + max_new_tokens=max_response_length, |
| 170 | + # max_length=max_length, |
| 171 | + eos_token_id=tokenizer.eos_token_id, |
| 172 | + pad_token_id=tokenizer.pad_token_id, |
| 173 | + generation_config=generation_config, |
| 174 | + # renormalize_logits=True, |
| 175 | + output_scores=False, # this is potentially very large |
| 176 | + return_dict_in_generate=True, |
| 177 | + use_cache=False) # may OOM when use_cache = True |
| 178 | + seq = output.sequences |
| 179 | + response = seq[:, max_prompt_length:] |
| 180 | + |
| 181 | + hf_response_tokens = tokenizer.batch_decode(response) |
| 182 | + print(f"hf response: {hf_response_tokens}") |
| 183 | + print(f"{sampling_params=}") |
| 184 | + idx_list = [] |
| 185 | + batch_size = input_ids.shape[0] |
| 186 | + |
| 187 | + pad_token_id = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id) |
| 188 | + for i in range(batch_size): |
| 189 | + idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) |
| 190 | + |
| 191 | + outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params) |
| 192 | + sglang_response_tokens = [] |
| 193 | + |
| 194 | + for output in outputs: |
| 195 | + print(f"{output=}") |
| 196 | + generated_text = output["text"] |
| 197 | + sglang_response_tokens.append(generated_text) |
| 198 | + |
| 199 | + print(f"sglang response: {sglang_response_tokens}") |
| 200 | + assert are_lists_similar(hf_response_tokens, sglang_response_tokens), \ |
| 201 | + f"Strings differ more than 10%:\n" |
| 202 | + print("Check Pass") |
| 203 | + |
| 204 | + |
| 205 | +def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): |
| 206 | + # remove the left padding in the prompt token_id |
| 207 | + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id |
| 208 | + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] |
| 209 | + token_ids = prompt_token_ids[non_pad_index:].tolist() |
| 210 | + return token_ids |
0 commit comments