Training demo (pytorch#5445)
Summary:
Makes the XOR model training demo runnable in OSS. A follow-up documentation PR will cover training and how to run this demo.

I'm sure my CMakeLists.txt changes have issues, so if anyone sees ways to improve them, please let me know.

The only hack I had to do was in the optimizer, which was calling an ET op directly. I don't think we have enabled this in OSS yet, so I will follow up with larryliu0820 when he's back and open an issue in the meantime.
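
For context, the workaround looks like this (a condensed sketch of the sgd.cpp change shown further down; the direct portable-op call is replaced with a float-only helper local to the optimizer):

// Before: dispatching directly to the ET portable op (not linkable like this in OSS yet):
//   torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
// After: a float-only helper local to the optimizer (Tensor is exec_aten::Tensor):
void add_out_hack(const Tensor& a, const Tensor& b, const double alpha, Tensor& out) {
  auto a_ptr = a.const_data_ptr<float>();
  auto b_ptr = b.const_data_ptr<float>();
  auto out_ptr = out.mutable_data_ptr<float>();
  for (size_t i = 0; i < a.numel(); ++i) {
    out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;  // out = a + alpha * b
  }
}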

Repro of the demo: see the Test Plan below.

Pull Request resolved: pytorch#5445

Test Plan:
python3 extension/training/examples/XOR/export_model.py --outdir /tmp/xor

rm -rf cmake-out

mkdir cmake-out

cmake \
    -DCMAKE_INSTALL_PREFIX=cmake-out \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TRAINING=ON \
    -DEXECUTORCH_ENABLE_LOGGING=ON \
    -DPYTHON_EXECUTABLE=python \
    -Bcmake-out .

cmake --build cmake-out -j9 --target install --config Release

./cmake-out/extension/training/train_xor --model_path=/tmp/xor/xor.pte

Reviewed By: dvorjackz

Differential Revision: D62905840

Pulled By: JacobSzwejbka

fbshipit-source-id: 622e68637ee7a0bb1b323e777d60e9516be115cd
JacobSzwejbka authored and facebook-github-bot committed Sep 18, 2024
1 parent 53c1a5f commit 26c736e
Showing 12 changed files with 198 additions and 84 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -183,6 +183,8 @@ option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL "Build the Runner Util extension"

option(EXECUTORCH_BUILD_EXTENSION_TENSOR "Build the Tensor extension" OFF)

option(EXECUTORCH_BUILD_EXTENSION_TRAINING "Build the training extension" OFF)

option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)

option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
@@ -636,6 +638,10 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
endif()

if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
endif()

if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
endif()
3 changes: 3 additions & 0 deletions build/Utils.cmake
@@ -68,6 +68,9 @@ function(executorch_print_configuration_summary)
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TENSOR : "
"${EXECUTORCH_BUILD_EXTENSION_TENSOR}"
)
message(STATUS " EXECUTORCH_BUILD_EXTENSION_TRAINING : "
"${EXECUTORCH_BUILD_EXTENSION_TRAINING}"
)
message(
STATUS
" EXECUTORCH_BUILD_FLATC : ${EXECUTORCH_BUILD_FLATC}"
28 changes: 28 additions & 0 deletions build/cmake_deps.toml
@@ -210,6 +210,34 @@ deps = [
"executorch",
"executorch_no_prim_ops",
]

[targets.extension_training]
buck_targets = [
"//extension/training/module:training_module",
"//extension/training/optimizer:sgd",
]
filters = [
".cpp$",
]
deps = [
"executorch_no_prim_ops",
]

[targets.train_xor]
buck_targets = [
"//extension/training/examples/XOR:train_xor",
]
filters = [
".cpp$",
]
excludes = [
"^codegen",
]
deps = [
"executorch",
"executorch_no_prim_ops",
"portable_kernels",
]
# ---------------------------------- extension end ----------------------------------
# ---------------------------------- binary start ----------------------------------

1 change: 1 addition & 0 deletions build/executorch-config.cmake
@@ -48,6 +48,7 @@ set(lib_list
extension_runner_util
extension_tensor
extension_threadpool
extension_training
xnnpack_backend
XNNPACK
cpuinfo
49 changes: 49 additions & 0 deletions extension/training/CMakeLists.txt
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Please keep this file formatted by running:
# ~~~
# cmake-format -i CMakeLists.txt
# ~~~

cmake_minimum_required(VERSION 3.19)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
endif()

list(TRANSFORM _extension_training__srcs PREPEND "${EXECUTORCH_ROOT}/")

add_library(extension_training ${_extension_training__srcs})
target_include_directories(
extension_training PUBLIC ${_common_include_directories}
)

target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
target_compile_options(extension_training PUBLIC ${_common_compile_options})
target_link_libraries(extension_training executorch_no_prim_ops
extension_data_loader extension_module extension_tensor)


list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_executable(train_xor ${_train_xor__srcs})
target_include_directories(
train_xor PUBLIC ${_common_include_directories}
)
target_link_libraries(
train_xor gflags executorch_no_prim_ops portable_ops_lib extension_tensor
extension_training program_schema
)
target_compile_options(train_xor PUBLIC ${_common_compile_options})

# Install libraries
install(
TARGETS extension_training
DESTINATION lib
INCLUDES
DESTINATION ${_common_include_directories}
)
Empty file added extension/training/__init__.py
32 changes: 0 additions & 32 deletions extension/training/examples/XOR/TARGETS
@@ -1,40 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
name = "model",
srcs = ["model.py"],
visibility = [], # Private
deps = [
"//caffe2:torch",
],
)

python_library(
name = "export_model_lib",
srcs = ["export_model_lib.py"],
visibility = [],
deps = [
":model",
"//caffe2:torch",
"//executorch/exir:lib",
],
)

python_binary(
name = "export_model",
main_function = ".export_model.main",
main_src = "export_model.py",
deps = [
":export_model_lib",
"//caffe2:torch",
],
)
29 changes: 27 additions & 2 deletions extension/training/examples/XOR/export_model.py
@@ -8,9 +8,14 @@

import argparse

import os

import torch
from executorch.exir import to_edge

from .export_model_lib import export_model
from executorch.extension.training.examples.XOR.model import Net, TrainingNet
from torch.export._trace import _export
from torch.export.experimental import _export_forward_backward


def main() -> None:
@@ -26,7 +31,27 @@ def main() -> None:
help="Path to the directory to write xor.pte files to",
)
args = parser.parse_args()
export_model(args.outdir)

net = TrainingNet(Net())
x = torch.randn(1, 2)

# Captures the forward graph. The graph will look similar to the model definition now.
# We will move to export_for_training soon, which is the API planned to be supported in the long term.
ep = _export(net, (x, torch.ones(1, dtype=torch.int64)), pre_dispatch=True)
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
ep = _export_forward_backward(ep)
# Lower the graph to edge dialect.
ep = to_edge(ep)
# Lower the graph to executorch.
ep = ep.to_executorch()

# Write out the .pte file.
os.makedirs(args.outdir, exist_ok=True)
outfile = os.path.join(args.outdir, "xor.pte")
with open(outfile, "wb") as fp:
fp.write(
ep.buffer,
)


if __name__ == "__main__":
28 changes: 28 additions & 0 deletions extension/training/examples/XOR/targets.bzl
@@ -21,3 +21,31 @@ def define_common_targets():
external_deps = ["gflags"],
define_static_target = True,
)

runtime.python_library(
name = "model",
srcs = ["model.py"],
visibility = [], # Private
deps = [
"//caffe2:torch",
],
)

runtime.python_library(
name = "export_model_lib",
srcs = ["export_model_lib.py", "export_model.py"],
visibility = [],
deps = [
":model",
"//caffe2:torch",
"//executorch/exir:lib",
],
)

runtime.python_binary(
name = "export_model",
main_module = "executorch.extension.training.examples.XOR.export_model",
deps = [
":export_model_lib",
],
)
8 changes: 4 additions & 4 deletions extension/training/examples/XOR/train.cpp
@@ -54,16 +54,16 @@ int main(int argc, char** argv) {
data_set;
data_set.push_back( // XOR(1, 1) = 0
{executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),
executorch::extension::make_tensor_ptr<long>({1}, {0})});
executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
data_set.push_back( // XOR(0, 0) = 0
{executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),
executorch::extension::make_tensor_ptr<long>({1}, {0})});
executorch::extension::make_tensor_ptr<int64_t>({1}, {0})});
data_set.push_back( // XOR(1, 0) = 1
{executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),
executorch::extension::make_tensor_ptr<long>({1}, {1})});
executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});
data_set.push_back( // XOR(0, 1) = 1
{executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),
executorch::extension::make_tensor_ptr<long>({1}, {1})});
executorch::extension::make_tensor_ptr<int64_t>({1}, {1})});

// Create optimizer.
// Get the params and names
67 changes: 37 additions & 30 deletions extension/training/optimizer/sgd.cpp
@@ -7,21 +7,49 @@
*/

#include <executorch/extension/training/optimizer/sgd.h>
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

using exec_aten::Tensor;
using exec_aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;

namespace executorch {
namespace extension {
namespace training {
namespace optimizer {

namespace {
void add_out_hack(
const Tensor& a,
const Tensor& b,
const double alpha,
Tensor& out) {
auto a_ptr = a.const_data_ptr<float>();
auto b_ptr = b.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
for (size_t i = 0; i < a.numel(); ++i) {
out_ptr[i] = a_ptr[i] + b_ptr[i] * alpha;
}
}

void mul_out_hack(const Tensor& a, const double alpha, Tensor& out) {
auto a_ptr = a.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
for (size_t i = 0; i < a.numel(); ++i) {
out_ptr[i] = a_ptr[i] * alpha;
}
}

void clone_out_hack(const Tensor& a, Tensor& out) {
auto a_ptr = a.const_data_ptr<float>();
auto out_ptr = out.mutable_data_ptr<float>();
for (size_t i = 0; i < a.numel(); ++i) {
out_ptr[i] = a_ptr[i];
}
}
} // namespace

bool SGDParamGroup::has_options() const {
return options_ != nullptr;
}
@@ -55,7 +83,6 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {

Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
named_gradients) {
KernelRuntimeContext context;
for (auto& group : param_groups_) {
auto& options = static_cast<SGDOptions&>(group.options());
auto weight_decay = options.weight_decay();
@@ -73,10 +100,7 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
auto p = param_iter->second;
if (weight_decay != 0) {
// uses weight_decay specified and adds it to the gradient
torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
add_out_hack(d_p, p, weight_decay, d_p);
}
if (momentum != 0) {
Tensor buf(nullptr);
@@ -100,11 +124,7 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
const_cast<TensorImpl::DimOrderType*>(d_p.dim_order().data()));
buf = Tensor(buf_impl);
#endif
torch::executor::aten::clone_outf(
context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
clone_out_hack(d_p, buf);

// save the state of the momentum buffer to be reused in later
// epochs
@@ -115,31 +135,18 @@ Error SGD::step(const std::map<exec_aten::string_view, exec_aten::Tensor>&
.momentum_buffer();

// update the momentum buffer and apply dampening
torch::executor::aten::mul_outf(context, buf, momentum, buf);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
torch::executor::aten::add_outf(
context, buf, d_p, 1 - dampening, buf);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
mul_out_hack(buf, momentum, buf);
add_out_hack(buf, d_p, 1 - dampening, buf);
}
if (nesterov) {
// apply nesterov momentum
torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
add_out_hack(d_p, buf, momentum, d_p);
} else {
d_p = buf;
}
}
// update the parameter using the gradient and learning rate
torch::executor::aten::add_outf(context, p, d_p, -1 * options.lr(), p);
if (context.failure_state() != Error::Ok) {
return context.failure_state();
}
add_out_hack(p, d_p, -1 * options.lr(), p);
}
}
}
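
For reference, the hand-rolled helpers above implement the standard SGD update. With learning rate lr, momentum \mu, dampening \tau, weight decay \lambda, gradient g, parameter p, and momentum buffer b, the step the code computes is (a sketch of what add_out_hack / mul_out_hack / clone_out_hack perform on float tensors):

\begin{aligned}
g &\leftarrow g + \lambda\, p && \text{weight decay, if } \lambda \neq 0 \\
b &\leftarrow \mu\, b + (1 - \tau)\, g && \text{momentum buffer, if } \mu \neq 0 \\
g &\leftarrow \begin{cases} g + \mu\, b & \text{nesterov} \\ b & \text{otherwise} \end{cases} \\
p &\leftarrow p - \mathrm{lr} \cdot g && \text{parameter update}
\end{aligned}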
