Skip to content

Commit

Permalink
Improved coordination of CUDA contexts with PyTorch (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman authored Nov 9, 2021
1 parent bca1b1d commit 02f9935
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions platforms/cuda/src/CudaTorchKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2018-2020 Stanford University and the Authors. *
* Portions copyright (c) 2018-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
Expand All @@ -31,6 +31,7 @@

#include "CudaTorchKernels.h"
#include "CudaTorchKernelSources.h"
#include "openmm/common/ContextSelector.h"
#include "openmm/internal/ContextImpl.h"
#include <map>
#include <cuda_runtime_api.h>
Expand Down Expand Up @@ -67,7 +68,7 @@ void CudaCalcTorchForceKernel::initialize(const System& system, const TorchForce

// Initialize CUDA objects.

cu.setAsCurrent();
ContextSelector selector(cu);
map<string, string> defines;
CUmodule program = cu.createModule(CudaTorchKernelSources::torchForce, defines);
copyInputsKernel = cu.getKernel(program, "copyInputs");
Expand All @@ -86,9 +87,12 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
posData = posTensor.data_ptr<float>();
boxData = boxTensor.data_ptr<float>();
}
void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(),
&numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()};
cu.executeKernel(copyInputsKernel, inputArgs, numParticles);
{
ContextSelector selector(cu);
void* inputArgs[] = {&posData, &boxData, &cu.getPosq().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(),
&numParticles, cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer()};
cu.executeKernel(copyInputsKernel, inputArgs, numParticles);
}
vector<torch::jit::IValue> inputs = {posTensor};
if (usePeriodic)
inputs.push_back(boxTensor);
Expand All @@ -105,7 +109,6 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
// cudaDeviceSynchronize(); // synchronizing the whole device is not necessary and may even cause problems
// synchronize the current context and check the return status
CHECK_RESULT(cuCtxSynchronize(), "Error synchronizing CUDA context");
cu.setAsCurrent();
void* data;
if (cu.getUseDoublePrecision()) {
if (!(forceTensor.dtype() == torch::kFloat64))
Expand All @@ -118,8 +121,11 @@ double CudaCalcTorchForceKernel::execute(ContextImpl& context, bool includeForce
data = forceTensor.data_ptr<float>();
}
int paddedNumAtoms = cu.getPaddedNumAtoms();
void* forceArgs[] = {&data, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms};
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
{
ContextSelector selector(cu);
void* forceArgs[] = {&data, &cu.getForce().getDevicePointer(), &cu.getAtomIndexArray().getDevicePointer(), &numParticles, &paddedNumAtoms};
cu.executeKernel(addForcesKernel, forceArgs, numParticles);
}
posTensor.grad().zero_();
}
return energyTensor.item<double>();
Expand Down

0 comments on commit 02f9935

Please sign in to comment.