
Update real_accelerator.py #6845

Merged
merged 8 commits into from
Dec 14, 2024

Conversation

keiwoo
Contributor

@keiwoo keiwoo commented Dec 10, 2024

Comment out or delete the `accelerator_name = "cpu"` assignment when xpu is not detected.

When `DS_ACCELERATOR` is set, the xpu branch at lines 68 to 74 is skipped entirely. However, when `DS_ACCELERATOR` is not set, `"cpu"` is assigned to `accelerator_name` at lines 125 to 133 whenever `intel_extension_for_pytorch` can be imported but no xpu is found.

I ran into this problem yesterday and spent a whole afternoon figuring it out. `intel_extension_for_pytorch` had been installed as a dependency of another package that I do not actually use, so I had no idea it was present. I then found that `"cpu"` is assigned to `accelerator_name` directly when no xpu is found, and that this short-circuits cuda detection. Without that early assignment, `"cpu"` would only be assigned after cuda detection also fails, at lines 170 to 177.
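For context, the auto-detection branch in question can be sketched as follows (a paraphrase, not the verbatim real_accelerator.py source; the function name is mine):

```python
# Paraphrased sketch of the problematic auto-detection branch in
# real_accelerator.py (not the verbatim source). When DS_ACCELERATOR is
# unset, merely having intel_extension_for_pytorch importable decides
# the outcome, even on a machine whose real accelerator is cuda.
def detect_accelerator_sketch():
    accelerator_name = None
    try:
        import intel_extension_for_pytorch as ipex

        if ipex._C._has_xpu():
            accelerator_name = "xpu"
        else:
            # Problem: assigning "cpu" here skips the mps/cuda checks
            # further down, which only run while accelerator_name is
            # still None.
            accelerator_name = "cpu"
    except ImportError:
        pass
    return accelerator_name
```

On a cuda machine that happens to have the extension installed, this returns "cpu" before the cuda check is ever reached, which is the behavior the PR removes.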

@loadams loadams requested a review from delock December 10, 2024 16:29
@loadams
Contributor

loadams commented Dec 10, 2024

Hi @keiwoo - the goal of this file is to detect which accelerator you have, unless you set it with the DS_ACCELERATOR environment variable. These lines are only executed if this is false: ipex._C._has_xpu() - in that case, if the user does have intel_extension_for_pytorch, but no XPU listed, what accelerator do we have?

Could you clarify whether you had intel_extension_for_pytorch installed when you first ran this, and what the other package was?

Tagging @Liangliang-Ma from the XPU team as well.

@loadams loadams self-requested a review December 10, 2024 16:43
@keiwoo
Contributor Author

keiwoo commented Dec 11, 2024

Hey @loadams, thanks for your review. I fully understand the goal of this file; setting the DS_ACCELERATOR environment variable simply skips the detection logic. You asked:

if the user does have intel_extension_for_pytorch, but no XPU listed, what accelerator do we have?

I suggest we do nothing in that branch and leave accelerator_name as None, since the other accelerators are detected below. Lines 141 to 149 show this:

        if accelerator_name is None:
            try:
                import torch.mps

                # should use torch.mps.is_available() if it exists someday but this is used as proxy
                torch.mps.current_allocated_memory()
                accelerator_name = "mps"
            except (RuntimeError, ImportError) as e:
                pass

In this case we will always have torch installed, of course. But what accelerator do we have when torch.mps.current_allocated_memory() raises an error? We simply pass, and accelerator_name stays None until the code reaches line 170 to detect cuda. As you can see, if no accelerator is detected at all, line 177 finally assigns accelerator_name = "cpu", as I mentioned before.

                if torch.cuda.is_available():  #ignore-cuda
                    accelerator_name = "cuda"
                else:
                    if accel_logger is not None:
                        accel_logger.warn(
                            "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it."
                        )
                    accelerator_name = "cpu"

Hope everything is explained clearly. I will also check which package installed intel_extension_for_pytorch alongside it.

@keiwoo
Contributor Author

keiwoo commented Dec 11, 2024

I found the reason why intel_extension_for_pytorch was installed: it is a dependency for compiling bitsandbytes from source. link with highlight

@loadams
Contributor

loadams commented Dec 11, 2024

I found the reason why intel_extension_for_pytorch was installed. It is the dependency for compiling bitsandbytes from source. link with highlight

Hi @keiwoo - I see, I misunderstood and thought you were using an XPU but were having issues detecting it. Instead, you are using another accelerator, and because intel_extension_for_pytorch is installed you're getting into this part of the file when that's undesirable - is that correct?

@tjruwase
Contributor

@keiwoo, thanks for your work here. I agree with avoiding "cpu" as a fallback for a specific accelerator. Although your PR addresses the xpu case, I think the cuda case should also be removed, and the selection of "cpu" added as a catch-all for when accelerator detection fails (i.e., accelerator_name == None), around here.

What do you think?
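A minimal sketch of that shape (a hypothetical helper, not DeepSpeed's actual API): each detector only ever claims its own accelerator, and "cpu" is chosen once, at the end, when every probe has failed.

```python
# Hypothetical restructure in the spirit of the suggestion above: no
# per-accelerator "cpu" fallback; "cpu" is a single catch-all used only
# when every detection probe fails.
def pick_accelerator(detectors):
    """detectors: ordered (name, probe) pairs; probe() returns True
    when that accelerator is present and usable."""
    for name, probe in detectors:
        try:
            if probe():
                return name
        except Exception:
            pass  # a failing probe just means "not this one"
    return "cpu"  # catch-all when all detection fails

# e.g. with every probe failing, detection falls through to cpu:
# pick_accelerator([("xpu", lambda: False), ("cuda", lambda: False)])
```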

Avoid using cpu as fallback for a specific accelerator and the selection of cpu added as catch-all  when accelerator detection fails
@keiwoo
Contributor Author

keiwoo commented Dec 12, 2024

@microsoft-github-policy-service agree

Contributor Author

@keiwoo keiwoo left a comment


Agree with you, @tjruwase. That would be clearer, and I have made some revisions. How about that?

@delock
Collaborator

delock commented Dec 13, 2024

@keiwoo Thanks for this PR. Yes, I think it makes sense to select cpu only when no other accelerator can be found, and your PR makes this intention clear. Currently, accelerator selection relies on four different hints:

  1. Existence of an extension suggests using that accelerator (npu, hpu, mps)
  2. Existence of an extension plus device detection suggests using that accelerator (xpu)
  3. No extension needed, device detection alone (cuda)
  4. Use this accelerator when no other accelerator is selected (cpu)

I think eventually all accelerators may need device detection to simplify environment management in hybrid clouds; the change to cpu detection in this PR conforms with this goal.
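Hint styles 1 and 2 both begin with an importability check. A tiny illustration (the helper and the example module names are mine, not DeepSpeed code):

```python
import importlib.util

# Styles 1 and 2 both start from "is the extension importable?". Style 1
# stops there; style 2 then probes for an actual device on top of it;
# style 3 needs no extension at all; style 4 (cpu) is the last-resort
# default when nothing else matched.
def extension_present(module_name):
    # find_spec checks importability without actually importing the
    # (potentially heavy) extension module.
    return importlib.util.find_spec(module_name) is not None

# e.g. extension_present("torch_npu") for npu, or
# extension_present("intel_extension_for_pytorch") for xpu.
```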

Collaborator

@delock delock left a comment


Looks good, thanks!

@loadams loadams merged commit fc7c070 into microsoft:master Dec 14, 2024
11 checks passed
loadams added a commit that referenced this pull request Dec 17, 2024
loadams added a commit that referenced this pull request Dec 18, 2024
…or whl building) (#6886)

This fixes a bug introduced in #6845, which breaks the `no-torch`
workflow that we require in order to do releases where we do not require
torch to be in the environment when building an sdist. This adds the
same logic to the cpuaccelerator that the cudaaccelerator had where we
don't require torch to be installed to build the whl.