Skip to content

Commit

Permalink
Add --concurrency option
Browse files Browse the repository at this point in the history
Minor improvements to help text

Clean up bounded_parallel_map function a bit
  • Loading branch information
KerfuffleV2 committed Aug 24, 2023
1 parent 0ddeeba commit 655990c
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# Target model architecture and its tensor-name mapping, taken from the gguf package.
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]

# Default number of items processed in parallel during conversion
# (can be overridden with the --concurrency command-line option).
DEFAULT_CONCURRENCY = 8
#
# data types
#
Expand Down Expand Up @@ -717,21 +718,21 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
with factory(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
done = False
for i in range(concurrency):
for _ in range(concurrency):
try:
nexti = next(iterable)
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break
futures.append(executor.submit(func, nexti))
while not done or futures:

while futures:
result = futures.pop(0).result()
while len(futures) < concurrency:
while not done and len(futures) < concurrency:
try:
nexti = next(iterable)
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break
futures.append(executor.submit(func, nexti))
yield result

def check_vocab_size(params: Params, vocab: Vocab) -> None:
Expand Down Expand Up @@ -850,13 +851,13 @@ def do_item(item: Tuple[str, LazyTensor]) -> (DataType, NDArray):
return (lazy_tensor.data_type, tensor.ndarray)

@staticmethod
def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
    """Quantize the array to Q8_0 when its data type requires it; otherwise return it unchanged."""
    data_type, arr = item
    if data_type == DT_Q8_0:
        return quantize_array_q8_0(arr)
    return arr

@staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab)

of = OutputFile(fname_out)
Expand All @@ -873,11 +874,11 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM
of.write_tensor_info()

# tensor data
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
if ftype == GGMLFileType.MostlyQ8_0:
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
else:
ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)

start = time.time()
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
Expand Down Expand Up @@ -1073,12 +1074,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
args = parser.parse_args(args_in)

if args.dump_single:
Expand Down Expand Up @@ -1132,7 +1134,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")

OutputFile.write_all(outfile, ftype, params, model, vocab)
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
print(f"Wrote {outfile}")


Expand Down

0 comments on commit 655990c

Please sign in to comment.