Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

context add generator #39475

Merged
merged 2 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/framework/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
}
}

GeneratorState Generator::GetState() {
pten::Generator::GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_;
return this->state_;
}

void Generator::SetState(const GeneratorState& state) {
void Generator::SetState(const pten::Generator::GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
Expand Down
17 changes: 6 additions & 11 deletions paddle/fluid/framework/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo>
#include <utility>

#include "paddle/pten/core/generator.h"

namespace paddle {
namespace framework {

Expand All @@ -34,14 +36,7 @@ static uint64_t GetRandomSeed() {
return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}

struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};

struct Generator {
struct Generator : public pten::Generator {
Generator() {
auto seed = GetRandomSeed();
std::seed_seq seq({seed});
Expand Down Expand Up @@ -82,9 +77,9 @@ struct Generator {
Generator(const Generator& other) = delete;

// get random state
GeneratorState GetState();
pten::Generator::GeneratorState GetState();
// set random state
void SetState(const GeneratorState&);
void SetState(const pten::Generator::GeneratorState&);
// get current seed
uint64_t GetCurrentSeed();
// random a seed and get
Expand All @@ -105,7 +100,7 @@ struct Generator {
uint64_t get_device_id() { return this->state_.device; }

private:
GeneratorState state_;
pten::Generator::GeneratorState state_;
std::shared_ptr<std::mt19937_64> engine_;
mutable std::mutex mu_;

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ cc_library(init SRCS init.cc DEPS device_context custom_kernel)
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS simple_threadpool malloc xxhash ${STREAM_CALLBACK_DEPS}
place pten_place eigen3 stringpiece cpu_helper cpu_info framework_proto ${IPU_CTX_DEPS} ${GPU_CTX_DEPS} ${NPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context)
${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS} ${MLU_CTX_DEPS} eigen3 cpu_context generator)
if(WITH_XPU)
target_link_libraries(device_context xpu_context)
endif()
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License. */
#endif
#include "glog/logging.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/profiler.h"

Expand Down Expand Up @@ -160,11 +161,14 @@ inline void EmplaceDeviceContext(
.GetAllocator(p, cuda_ctx->stream())
.get());
cuda_ctx->PartialInitWithAllocator();
dev_ctx->SetGenerator(
framework::GetDefaultCUDAGenerator(p.GetDeviceId()).get());
#endif
} else {
dev_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance()
.GetAllocator(p)
.get());
dev_ctx->SetGenerator(framework::DefaultCPUGenerator().get());
}
dev_ctx->SetHostAllocator(
memory::allocation::AllocatorFacade::Instance()
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/pybind/generator_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/generator.h"
#include <fcntl.h>

#ifdef _POSIX_C_SOURCE
Expand All @@ -31,10 +32,11 @@ namespace paddle {
namespace pybind {
void BindGenerator(py::module* m_ptr) {
auto& m = *m_ptr;
py::class_<framework::GeneratorState,
std::shared_ptr<framework::GeneratorState>>(m, "GeneratorState")
py::class_<pten::Generator::GeneratorState,
std::shared_ptr<pten::Generator::GeneratorState>>(m,
"GeneratorState")
.def("current_seed",
[](std::shared_ptr<framework::GeneratorState>& self) {
[](std::shared_ptr<pten::Generator::GeneratorState>& self) {
return self->current_seed;
});
py::class_<std::mt19937_64>(m, "mt19937_64", "");
Expand Down
21 changes: 21 additions & 0 deletions paddle/pten/core/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,27 @@ struct DeviceContext::Impl {
return static_cast<T*>(HostAlloc(tensor, dtype, requested_size));
}

void SetGenerator(Generator* gen) {
PADDLE_ENFORCE_NOT_NULL(
gen,
pten::errors::InvalidArgument(
"Required generator shall not be nullptr, but received nullptr."));
generator_ = gen;
}

Generator* GetGenerator() const {
PADDLE_ENFORCE_NOT_NULL(
generator_,
pten::errors::InvalidArgument("Required generator_ shall not be "
"nullptr, but received nullptr."));
return generator_;
}

private:
const Allocator* device_allocator_{nullptr};
const Allocator* host_allocator_{nullptr};
const Allocator* zero_allocator_{nullptr};
Generator* generator_{nullptr};
};

DeviceContext::DeviceContext() { impl_ = std::make_unique<Impl>(); }
Expand Down Expand Up @@ -201,4 +218,8 @@ DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(::paddle::experimental::complex128)

#undef DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION

void DeviceContext::SetGenerator(Generator* gen) { impl_->SetGenerator(gen); }

Generator* DeviceContext::GetGenerator() const { return impl_->GetGenerator(); }

} // namespace pten
20 changes: 15 additions & 5 deletions paddle/pten/core/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ limitations under the License. */

#include <memory>

// TODO(wilber): Do we need to use place in pten kernel?
#include "paddle/pten/common/place.h"

#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/allocator.h"
#include "paddle/pten/core/generator.h"

namespace pten {
class TensorBase;
Expand Down Expand Up @@ -112,13 +111,24 @@ class DeviceContext {
template <typename T>
T* HostAlloc(TensorBase* tensor, size_t requested_size = 0) const;

// TODO(wilber): Just for the convenience of migrating the code, it will be
// modified or removed later.
virtual const Place& GetPlace() const = 0;
// TODO(wilber): The fluid framework uses wait() in many places, how to delete
// this API interface.
virtual void Wait() const {}

/**
* @brief Set the generator for special op.
*
* @param Generator
*/
void SetGenerator(Generator*);
/**
* @brief Get the generator object.
*
* @return Generator
*/
Generator* GetGenerator() const;

private:
struct Impl;
std::unique_ptr<Impl> impl_;
Expand Down
61 changes: 61 additions & 0 deletions paddle/pten/core/generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <memory>
#include <random>

namespace pten {

class Generator {
public:
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};

virtual ~Generator() = default;

// get random state
virtual GeneratorState GetState() = 0;
// set random state
virtual void SetState(const GeneratorState&) = 0;
// get current seed
virtual uint64_t GetCurrentSeed() = 0;
// random a seed and get
virtual uint64_t Seed() = 0;
// set seed
virtual void SetCurrentSeed(uint64_t seed) = 0;
// get cpu engine
virtual std::shared_ptr<std::mt19937_64> GetCPUEngine() = 0;
// set cpu engine
virtual void SetCPUEngine(std::shared_ptr<std::mt19937_64>) = 0;
virtual uint64_t Random64() = 0;
virtual std::pair<uint64_t, uint64_t> IncrementOffset(
uint64_t increament_offset) = 0;

// NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with
// old seed, and it should be removed after all random-related operators
// and unittests upgrades to use generator.
virtual void SetIsInitPy(bool) = 0;
virtual bool GetIsInitPy() const = 0;

virtual uint64_t get_device_id() = 0;
};

} // namespace pten