Merged
2 changes: 1 addition & 1 deletion setup.py
@@ -235,7 +235,7 @@ def create_dir_symlink(src, dest):
nccl_version = ".".join(str(torch.cuda.nccl.version())[:2])
else:
nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2]))
- if hasattr(torch.cuda, 'is_bf16_supported'):
+ if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_available():
bf16_support = torch.cuda.is_bf16_supported()
Contributor

The bf16_support value is only used in unit testing - perhaps we could add a try...except block around this?
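
A minimal sketch of what such a guard might look like, assuming the goal is simply that the build never crashes while probing for bf16. The helper name `detect_bf16_support` is hypothetical and not part of this PR; treating every failure as "no bf16" is likewise an assumption, not something the thread settles.

```python
def detect_bf16_support():
    """Return True only if torch reports bf16 support; never raise."""
    try:
        import torch  # torch may be absent or broken on the build machine
    except (ImportError, OSError):
        return False
    # Same hasattr guard as the diff: older torch releases lack this function.
    if not hasattr(torch.cuda, "is_bf16_supported"):
        return False
    try:
        return bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())
    except Exception:
        # Assumption: any failure to query the device counts as "no bf16".
        return False


bf16_support = detect_bf16_support()
```

This combines the `torch.cuda.is_available()` check from the diff with the suggested `try...except`, so a machine without a usable GPU falls through to `False` instead of raising.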

Contributor Author

@aphedges, Jun 28, 2022

My first commit would be sufficient to prevent the crash because it checks whether CUDA is available.

However, that doesn't solve the problem of building and running on different hardware with different bf16 support.

Contributor

@mrwyattii, Jun 29, 2022

The bf16 support check is only used in unit tests, so it shouldn't affect users who build and run on different hardware.

Contributor Author

I know it's set up that way in CI, but if a developer runs the unit tests on a different machine than the build, it could still cause problems.
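
One way to picture the build-versus-test mismatch described here is to query bf16 support on the machine that runs the tests and compare it with whatever was recorded at build time. The sketch below is illustrative only: `bf16_supported_at_runtime` and `resolve_bf16_support` are hypothetical names, and nothing in this PR says the project resolves the issue this way.

```python
import functools


@functools.lru_cache(maxsize=None)
def bf16_supported_at_runtime():
    """Query bf16 support on the machine running the tests, once per process."""
    try:
        import torch
    except (ImportError, OSError):
        return False
    cuda = torch.cuda
    # Check CUDA availability first so the bf16 query never runs without a device.
    if not (hasattr(cuda, "is_bf16_supported") and cuda.is_available()):
        return False
    return bool(cuda.is_bf16_supported())


def resolve_bf16_support(build_time_value):
    """Return (runtime_value, stale), where stale flags a build/run mismatch."""
    runtime_value = bf16_supported_at_runtime()
    return runtime_value, runtime_value != bool(build_time_value)


# Example: a build machine recorded False; see what the test machine reports.
supported, stale = resolve_bf16_support(False)
```

If `stale` is true, the value baked in at build time no longer describes the hardware running the tests, which is the situation this comment is concerned about.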

if torch_available and hasattr(torch.version, 'hip') and torch.version.hip is not None:
hip_version = ".".join(torch.version.hip.split('.')[:2])