Skip to content

Commit

Permalink
Merge pull request #17 from yaoyic/upstream-master
Browse files Browse the repository at this point in the history
[BUGFIX] Synchronize CUDA context before force copying
  • Loading branch information
peastman authored Oct 7, 2020
2 parents 27e845c + 3a3b484 commit d00a0fe
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions platforms/cuda/src/CudaTorchKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,21 @@
#include "CudaTorchKernelSources.h"
#include "openmm/internal/ContextImpl.h"
#include <map>
#include <cuda_runtime_api.h>

using namespace TorchPlugin;
using namespace OpenMM;
using namespace std;

// macro for checking the result of synchronization operation on CUDA
// copied from `openmm/platforms/cuda/src/CudaParallelKernels.cpp`
#define CHECK_RESULT(result, prefix) \
if (result != CUDA_SUCCESS) { \
std::stringstream m; \
m<<prefix<<": "<<cu.getErrorString(result)<<" ("<<result<<")"<<" at "<<__FILE__<<":"<<__LINE__; \
throw OpenMMException(m.str());\
}

CudaCalcTorchForceKernel::~CudaCalcTorchForceKernel() {
}

Expand Down Expand Up @@ -80,11 +90,17 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
vector<torch::jit::IValue> inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
// synchronizing the current context before switching to PyTorch
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context");
torch::Tensor energyTensor = module.forward(inputs).toTensor();
if (includeForces) {
energyTensor.backward();
// Note: "forceTensor" needs to be cloned due to a shared context (https://github.com/openmm/openmm-torch/issues/13)
torch::Tensor forceTensor = posTensor.grad().clone();
// make sure that all calculations on PyTorch side is properly finished before changing CUDA context or starting the `addForcesKernel` of this plugin
// cudaDeviceSynchronize(); // synchronizing the whole device is not necessary and may even cause problem
// synchronizing the current context and check the return status
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context");
cu.setAsCurrent();
void* data;
if (cu.getUseDoublePrecision()) {
Expand Down

0 comments on commit d00a0fe

Please sign in to comment.