diff --git a/CMakeLists.txt b/CMakeLists.txt index 6bdcda2f19c..de941663a88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,14 +248,15 @@ cmake_dependent_option( "NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF ) -if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR) +if(EXECUTORCH_BUILD_EXTENSION_TRAINING) set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) + set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) + set(EXECUTORCH_BUILD_EXTENSION_MODULE ON) + set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) endif() -if(EXECUTORCH_BUILD_EXTENSION_TRAINING) - set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON) +if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR) set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) - set(EXECUTORCH_BUILD_EXTENSION_MODULE ON) endif() if(EXECUTORCH_BUILD_EXTENSION_MODULE) diff --git a/extension/training/CMakeLists.txt b/extension/training/CMakeLists.txt index e50bb3c71eb..97e75955837 100644 --- a/extension/training/CMakeLists.txt +++ b/extension/training/CMakeLists.txt @@ -26,7 +26,7 @@ target_include_directories( target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..) 
target_compile_options(extension_training PUBLIC ${_common_compile_options}) target_link_libraries(extension_training executorch_core - extension_data_loader extension_module extension_tensor) + extension_data_loader extension_module extension_tensor extension_flat_tensor) list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/") diff --git a/extension/training/examples/XOR/export_model.py b/extension/training/examples/XOR/export_model.py index bfbe0ce2138..98e04f09a2f 100644 --- a/extension/training/examples/XOR/export_model.py +++ b/extension/training/examples/XOR/export_model.py @@ -11,14 +11,14 @@ import os import torch -from executorch.exir import to_edge +from executorch.exir import ExecutorchBackendConfig, to_edge from executorch.extension.training.examples.XOR.model import Net, TrainingNet from torch.export import export from torch.export.experimental import _export_forward_backward -def _export_model(): +def _export_model(external_mutable_weights: bool = False): net = TrainingNet(Net()) x = torch.randn(1, 2) @@ -30,7 +30,11 @@ def _export_model(): # Lower the graph to edge dialect. ep = to_edge(ep) # Lower the graph to executorch. - ep = ep.to_executorch() + ep = ep.to_executorch( + config=ExecutorchBackendConfig( + external_mutable_weights=external_mutable_weights + ) + ) return ep @@ -44,19 +48,27 @@ def main() -> None: "--outdir", type=str, required=True, - help="Path to the directory to write xor.pte files to", + help="Path to the directory to write xor.pte and xor.ptd files to", + ) + parser.add_argument( + "--external", + action="store_true", + help="Export the model with external weights", ) args = parser.parse_args() - ep = _export_model() + ep = _export_model(args.external) # Write out the .pte file. 
os.makedirs(args.outdir, exist_ok=True) outfile = os.path.join(args.outdir, "xor.pte") with open(outfile, "wb") as fp: - fp.write( - ep.buffer, - ) + ep.write_to_file(fp) + + if args.external: + # current infra doesn't easily allow renaming this file, so just hackily do it here. + ep._tensor_data["xor"] = ep._tensor_data.pop("_default_external_constant") + ep.write_tensor_data_to_file(args.outdir) if __name__ == "__main__": diff --git a/extension/training/examples/XOR/train.cpp b/extension/training/examples/XOR/train.cpp index 746daebbf1b..af1c37a6a50 100644 --- a/extension/training/examples/XOR/train.cpp +++ b/extension/training/examples/XOR/train.cpp @@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions; using executorch::runtime::Error; using executorch::runtime::Result; DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format."); +DEFINE_string(ptd_path, "", "Model weights serialized in flatbuffer format."); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - if (argc != 1) { + if (argc == 0) { + ET_LOG(Error, "Please provide a model path."); + return 1; + } else if (argc > 2) { std::string msg = "Extra commandline args: "; - for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) { + for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */; + i < argc; + i++) { msg += argv[i]; } ET_LOG(Error, "%s", msg.c_str()); @@ -46,7 +52,21 @@ int main(int argc, char** argv) { auto loader = std::make_unique( std::move(loader_res.get())); - auto mod = executorch::extension::training::TrainingModule(std::move(loader)); + std::unique_ptr ptd_loader = nullptr; + if (!FLAGS_ptd_path.empty()) { + executorch::runtime::Result + ptd_loader_res = + executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str()); + if (ptd_loader_res.error() != Error::Ok) { + ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str()); + return 1; + } + ptd_loader = std::make_unique(
std::move(ptd_loader_res.get())); + } + + auto mod = executorch::extension::training::TrainingModule( + std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader)); // Create full data set of input and labels. std::vector(param_res.error())); return 1; } @@ -112,5 +135,6 @@ int main(int argc, char** argv) { std::string(param.first.data()), param.second}); } - executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16); + executorch::extension::flat_tensor::save_ptd( + "trained_xor.ptd", param_map, 16); }