
Add default_dtype option #537

@radka-j

Description


PyTorch creates floating-point tensors with dtype float32 by default, but we have run into issues where a method returned a float64 tensor, which then raised errors when it was combined with other (float32) tensors.
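A minimal sketch of the failure mode, assuming the float64 tensor is simply created explicitly here as a stand-in for "a method returned float64". Elementwise ops silently promote to float64, while ops such as matrix multiplication require matching dtypes and raise a `RuntimeError`:

```python
import torch

# PyTorch factory functions produce float32 unless told otherwise.
x = torch.ones(3, 2)
print(x.dtype)  # torch.float32

# Stand-in for a method that (unexpectedly) returns a float64 tensor.
w = torch.ones(2, 1, dtype=torch.float64)
print(w.dtype)  # torch.float64

# Elementwise arithmetic silently promotes the result to float64...
print((x + torch.ones(3, 2, dtype=torch.float64)).dtype)  # torch.float64

# ...but matrix multiplication requires both operands to share a dtype.
try:
    x @ w
except RuntimeError as err:
    print("RuntimeError:", err)
```

The silent promotion is arguably the worse half: the float64 value propagates until it reaches an op that rejects it, far from where it was introduced.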

As discussed in #520; see this comment by @sgreenbury:

I think this would be great to potentially have a default constant somewhere (e.g. DEFAULT_DTYPE: torch.dtype = torch.float32) that we can import to use such as here. We might also consider adding API to e.g. the base simulator/emulators so that dtype= can be provided.
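A minimal sketch of both suggestions, assuming a module-level `DEFAULT_DTYPE` constant and a hypothetical `BaseSimulator` class that accepts `dtype=`. `BaseSimulator`, `QuadraticSimulator` and `_as_tensor` are illustrative names, not the project's existing API:

```python
import torch

# Package-wide default, as suggested in the comment above.
# Where this lives (e.g. a shared constants module) is an assumption.
DEFAULT_DTYPE: torch.dtype = torch.float32


class BaseSimulator:
    """Hypothetical base class, not the project's actual API."""

    def __init__(self, dtype: torch.dtype = DEFAULT_DTYPE) -> None:
        self.dtype = dtype

    def _as_tensor(self, data) -> torch.Tensor:
        # Convert lists, arrays or tensors of any dtype to the configured dtype.
        return torch.as_tensor(data, dtype=self.dtype)


class QuadraticSimulator(BaseSimulator):
    """Toy subclass used only for illustration."""

    def forward(self, x) -> torch.Tensor:
        x = self._as_tensor(x)
        return x ** 2


sim = QuadraticSimulator()
print(sim.forward([1.0, 2.0]).dtype)  # torch.float32

sim64 = QuadraticSimulator(dtype=torch.float64)
print(sim64.forward(torch.ones(2)).dtype)  # torch.float64
```

Because Python evaluates default arguments once, at definition time, later reassigning `DEFAULT_DTYPE` would not change this default; a `dtype=None` default resolved inside `__init__` avoids that. `torch.set_default_dtype` is the global alternative, but it changes behaviour for all code in the process, not just this package.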

Metadata


Assignees: none
Labels: none
Type: No type
Projects: Status 📋 Product backlog
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
