
Multiprocess inference warning: ignoring nprocs #7633

Open
CutieQing opened this issue Jul 4, 2024 · 2 comments

Comments

@CutieQing

❓ Questions and Help

When I ran multiprocess inference with the Hugging Face transformers framework, I used xmp.spawn(perform_inference, args=(args,), nprocs=4), intending to run 4 processes at once. However, it reported the warning WARNING:root:Unsupported nprocs (4), ignoring... I wonder whether this is a bug or whether there is a mistake in my inference script.

My inference script is as follows:

import time

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoModel, AutoTokenizer


def perform_inference(index, args):
    device = xm.xla_device()
    print(f"tpu name: {device}")

    sentences = ["Sample-1", "Sample-2"] * args.batch_size
    print(f"sentences length: {len(sentences)}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModel.from_pretrained(args.model_name_or_path).to(device)
    model.eval()

    for i in range(20):
        if i == 19:
            # `port` is assigned elsewhere in the script (per-process profiler port)
            print(f"log port: {port}")
            xp.trace_detached(f'localhost:{port}', './profiles/', duration_ms=2000)
        with xp.StepTrace('bge_test'):
            with xp.Trace('build_graph'):
                encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
                with torch.no_grad():
                    start = time.perf_counter()
                    model_output = model(**encoded_input)
                    end = time.perf_counter()
                    sentence_embeddings = model_output[0][:, 0]
                    print("inference time:", (end - start))

    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    print("Sentence embeddings: ", sentence_embeddings)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    args = get_args()

    xmp.spawn(perform_inference, args=(args,), nprocs=4)

Detailed log:

WARNING:root:Unsupported nprocs (4), ignoring...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.528224 2908632 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.528293 2908632 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.528300 2908632 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.544289 2908627 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.544426 2908627 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.544434 2908627 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.728254 2908631 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.728326 2908631 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.728332 2908631 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.916441 2908634 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.916616 2908634 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.916625 2908634 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.409535 2908636 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.409646 2908636 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.409654 2908636 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.658751 2908630 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.658883 2908630 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.658891 2908630 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.659256 2908635 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.659285 2908633 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.659431 2908635 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.659440 2908635 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
I0000 00:00:1720080893.659455 2908633 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.659465 2908633 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
port: 52003
tpu name: xla:0
sentences length: 16384
port: 40841
tpu name: xla:0
sentences length: 16384
port: 40729
tpu name: xla:0
sentences length: 16384
port: 51387
tpu name: xla:0
sentences length: 16384
port: 53707
tpu name: xla:0
sentences length: 16384
port: 45223
tpu name: xla:0
sentences length: 16384
port: 37585
tpu name: xla:0
sentences length: 16384
port: 36559
tpu name: xla:0
sentences length: 16384
inference time: 0.034876358113251626
inference time: 0.03664895799010992
inference time: 0.026097089052200317
inference time: 0.02792046801187098
inference time: 0.02882425906136632
inference time: 0.029096698039211333
inference time: 0.02789105800911784
inference time: 0.027401939034461975
inference time: 0.014182109967805445
inference time: 0.013394199078902602
inference time: 0.013075169990770519
inference time: 0.012977780075743794
inference time: 0.01341874001082033
...

@BitPhinix
Contributor

BitPhinix commented Jul 5, 2024

This is expected. torch_xla supports either running on one device (nprocs=1) or on all devices. If you pass anything other than nprocs=1, the value is ignored and it falls back to running on all available devices.

See:

if nprocs == 1:
  return _run_singleprocess(spawn_fn)
elif nprocs is not None:
  logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
run_multiprocess(spawn_fn, start_method=start_method)
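The branching above can be sketched as a plain-Python mock (the `run_singleprocess`/`run_multiprocess` stubs below stand in for the real torch_xla internals; this only illustrates how nprocs is dispatched, not the actual implementation):

```python
import logging

def run_singleprocess(fn):
    # Stub: torch_xla would run fn in the current process on one device.
    return f"single:{fn.__name__}"

def run_multiprocess(fn):
    # Stub: torch_xla would fork one process per available device.
    return f"multi:{fn.__name__}"

def spawn(fn, nprocs=None):
    """Mock of xmp.spawn's dispatch on nprocs."""
    if nprocs == 1:
        return run_singleprocess(fn)
    elif nprocs is not None:
        # Any value other than 1 or None (even the actual device count)
        # triggers the warning and is then ignored.
        logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
    return run_multiprocess(fn)

def perform_inference():
    pass

print(spawn(perform_inference, nprocs=1))  # single-process path
print(spawn(perform_inference, nprocs=4))  # warns, then runs on all devices
print(spawn(perform_inference))            # all devices, no warning
```

So passing nprocs=None (the default) gives the same multi-device behavior as nprocs=4, just without the warning.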

@BitPhinix

Although ideally, I guess it shouldn't warn when nprocs equals the number of devices. https://pytorch.org/xla/release/1.6/index.html#torch_xla.distributed.xla_multiprocessing.spawn is a bit unclear on this.
