Changes from all commits (107 commits)
6adb290
temporary weight adjust index
reyna-abhyankar Aug 25, 2024
61697c2
Loss function
reyna-abhyankar Aug 27, 2024
b56c046
Add cuda test for loss function
reyna-abhyankar Aug 27, 2024
f75a3d4
Format
reyna-abhyankar Aug 27, 2024
f74711f
Refactor and build optimizer kernels, op
reyna-abhyankar Aug 27, 2024
40c6252
Finish optimizer local backing
reyna-abhyankar Aug 27, 2024
ad9b9ea
Format
reyna-abhyankar Aug 27, 2024
1ddfade
E2E update test
reyna-abhyankar Aug 27, 2024
dde9496
Format
reyna-abhyankar Aug 27, 2024
59635d8
Small fixes
reyna-abhyankar Sep 11, 2024
103ef07
Format
reyna-abhyankar Sep 11, 2024
f48f9ff
Fix test and small issues
reyna-abhyankar Sep 18, 2024
189c9c8
Format
reyna-abhyankar Sep 18, 2024
d93f464
Merge branch 'repo-refactor' into local-e2e-training
reyna-abhyankar Oct 1, 2024
b5647c8
Pass tests after merge
reyna-abhyankar Oct 1, 2024
f5ff91e
Fix input/weight differentiation
reyna-abhyankar Oct 1, 2024
7470e71
Fix signature to use unified rep
reyna-abhyankar Oct 1, 2024
deece1b
Fix model training instance abstraction
reyna-abhyankar Oct 1, 2024
1d3cc94
Change subcase test name
reyna-abhyankar Oct 1, 2024
3cf5d08
Quick fixes
reyna-abhyankar Oct 16, 2024
79ef4c9
Refactor training backing and instance
reyna-abhyankar Oct 22, 2024
a73b1c3
Expose op folders publicly
reyna-abhyankar Nov 13, 2024
c6fed29
Add tensor type, operate over reduced tensor
reyna-abhyankar Nov 13, 2024
0cdfb1a
Fixes
reyna-abhyankar Jan 7, 2025
9d252b3
Remove tensor lower
reyna-abhyankar Jan 15, 2025
895c117
Add tensor and task lowering scheme
reyna-abhyankar Jan 17, 2025
66d61eb
feat: add realm-backend subdir
chenzhuofu Jan 21, 2025
8d0cfec
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 21, 2025
411017d
Build local exec
reyna-abhyankar Jan 22, 2025
759abdd
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Jan 22, 2025
bcd1408
chore: duplicate some files from local-execution
chenzhuofu Jan 22, 2025
5e11568
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
1c55cf7
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Jan 28, 2025
b9144ad
chore: update legion
chenzhuofu Jan 30, 2025
66647a2
feat: add legion related code
chenzhuofu Jan 30, 2025
0128abb
Disaggregate local backend
reyna-abhyankar Feb 1, 2025
277f8c2
Update task binding interface and cost estimator
reyna-abhyankar Feb 1, 2025
377c6aa
Merge master into local execution
reyna-abhyankar Feb 4, 2025
6f689a4
feat: add Future wrapper for func result
chenzhuofu Feb 5, 2025
fe2bc21
feat: add realm-backend draft impl
chenzhuofu Feb 5, 2025
8efaec7
Build
reyna-abhyankar Feb 6, 2025
1dc1398
Format
reyna-abhyankar Feb 6, 2025
17ad5c8
Split task spec files
reyna-abhyankar Feb 6, 2025
639c2c1
Delete outdated sim environment file
reyna-abhyankar Feb 6, 2025
c408ebb
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 8, 2025
a697044
Finish API
reyna-abhyankar Feb 13, 2025
187a8d5
Add tests for allocated and unallocated
reyna-abhyankar Feb 13, 2025
a0f8113
Fix nonnegative
reyna-abhyankar Feb 13, 2025
b1eab94
Format
reyna-abhyankar Feb 13, 2025
b532c50
Pass allocated-unallocated tests
reyna-abhyankar Feb 13, 2025
f28e5c2
Update task registry tests
reyna-abhyankar Feb 13, 2025
7887183
Merge branch 'local-e2e-training' of github.com:reyna-abhyankar/FlexF…
chenzhuofu Feb 16, 2025
9c16d76
feat: initial implementation of realm-backend
chenzhuofu Feb 19, 2025
89752fa
Move local tensor backing to dtgen
reyna-abhyankar Feb 22, 2025
aef8ad5
Remove lowered tensor source
reyna-abhyankar Feb 22, 2025
f0a4285
Loss and update tests
reyna-abhyankar Feb 24, 2025
9047edc
Merge master
reyna-abhyankar Feb 24, 2025
350babf
Passing tests after merge issues
reyna-abhyankar Feb 24, 2025
aef7c6e
Pass gpu tests
reyna-abhyankar Feb 25, 2025
6c84fb3
chore: fix typo
chenzhuofu Feb 26, 2025
d6aa7ad
chore: update realm allocator impl
chenzhuofu Feb 27, 2025
419cca8
chore: eliminate std::optional<float>
chenzhuofu Mar 3, 2025
2c0b573
feat: buildable realm-backend
chenzhuofu Mar 5, 2025
ebe06cf
Merge commit 'aef8ad58196f7b7f724fc7f0a1a65af24ee12acd' of github.com…
chenzhuofu Mar 5, 2025
062825e
chore: Move realm tensor backing to dtgen
chenzhuofu Mar 5, 2025
d82fa2a
Merge commit '350babf3584c3d99e76e4dc0f72a658aa0222afc' of github.com…
chenzhuofu Mar 5, 2025
7c53bb3
chore: minor
chenzhuofu Mar 5, 2025
403ec78
Merge commit 'aef7c6e3c3087f15b4c90792148f170da84f6f7c' of github.com…
chenzhuofu Mar 5, 2025
bf57d1d
chore: remove deprecated file
chenzhuofu Mar 5, 2025
3a0d4e8
feat: add a unit test for realm backend
chenzhuofu Mar 12, 2025
fa3f917
fix: DeviceSpecificState error
chenzhuofu Mar 16, 2025
b6c163e
Merge branch 'master' of github.com:flexflow/flexflow-train into real…
chenzhuofu Mar 18, 2025
b55aed7
fix: realm task id should start from `Processor::TASK_ID_FIRST_AVAILA…
chenzhuofu Mar 19, 2025
a921775
fix: RealmTrainingBacking initialization
chenzhuofu Mar 19, 2025
a708496
fix: bugs with DeviceSpecificDeviceStates...
chenzhuofu Mar 19, 2025
6e9c9af
tests: pass test_update
chenzhuofu Mar 19, 2025
7b1f653
chore: minor
chenzhuofu Mar 19, 2025
75747a8
Merge branch 'master' into local-e2e-training
reyna-abhyankar Apr 3, 2025
9e3246e
Merge branch 'master' into local-e2e-training
reyna-abhyankar Apr 30, 2025
64a82b3
Add e2e test
reyna-abhyankar Apr 30, 2025
ffd96e2
Format
reyna-abhyankar Apr 30, 2025
2f75451
Pass cost estimator test
reyna-abhyankar Apr 30, 2025
2746e14
Add nccl fix and host accessor access
reyna-abhyankar May 5, 2025
633253d
Merge remote-tracking branch 'origin/master' into local-e2e-training
lockshaw May 8, 2025
31df722
Move operators into task-spec
lockshaw May 8, 2025
292c61c
Sync changes with Reyna
lockshaw May 15, 2025
d1ffea9
Fix typo in task-spec
lockshaw May 15, 2025
7e45215
Add positive_int and tensor reductions/comparisons
lockshaw May 21, 2025
a266a79
Format
lockshaw May 21, 2025
cebd06c
Merge branch 'master' into local-e2e-training
reyna-abhyankar Apr 30, 2025
ea1a6df
Add tests for positive_int
lockshaw May 23, 2025
9d4f90b
test: realm backend add e2e test
chenzhuofu May 26, 2025
f3e2a27
tweak: minor
chenzhuofu May 27, 2025
ba85fe4
Pass cost estimator test
reyna-abhyankar Apr 30, 2025
ed0a164
feat: fix e2e test
chenzhuofu May 28, 2025
7755b94
fix: TaskArgumentAccessor has an share_ptr object, which need to be h…
chenzhuofu May 28, 2025
335ac6d
Expose test kernels, fill weights
reyna-abhyankar Jun 11, 2025
dbbb574
Expose test utils
reyna-abhyankar Jun 11, 2025
a4c1ea4
Remove prints
reyna-abhyankar Jun 17, 2025
b2de407
tweak: merge from local-e2e-training
chenzhuofu Jun 18, 2025
346f986
tweak: minor
chenzhuofu Jun 18, 2025
d5a57ba
feat: test e2e for realm-backend
chenzhuofu Jun 18, 2025
32971ef
tweak: minor
chenzhuofu Jun 19, 2025
a1a8c14
tweak: minor
chenzhuofu Jun 19, 2025
4e3fb7d
feat: reconstruct realm backend
chenzhuofu Aug 4, 2025
426206e
fix: e2e test for realm backend
chenzhuofu Aug 6, 2025
b73feb4
tweak: minor
chenzhuofu Aug 7, 2025
8 changes: 4 additions & 4 deletions .flake/pkgs/legion.nix
@@ -18,13 +18,13 @@ in

stdenv.mkDerivation rec {
pname = "legion_flexflow";
version = "2024-03-13";
version = "2025-01-21";

src = fetchFromGitLab {
owner = "StanfordLegion";
repo = "legion";
rev = "24e8c452341dea41427e0ce61e154d61715e6835";
sha256 = "sha256-NjCSjphOIew/V24i74I6DModSGcWKLeiSIjts3cFtx4=";
rev = "0c5a181e59c07e3af1091a2007378ff9355047fa";
sha256 = "sha256-oapo7klN17gmRsmaSsrpup4YJ0dtHxiKFtwz8jyPqzU=";
fetchSubmodules = true;
};

@@ -33,7 +33,7 @@ stdenv.mkDerivation rec {
];

cmakeFlags = [
"-DLegion_USE_Python=1"
"-DLegion_USE_Python=0"
"-DLegion_BUILD_BINDINGS=1"
"-DLegion_USE_CUDA=1"
"-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}"
7 changes: 7 additions & 0 deletions .proj.toml
@@ -70,6 +70,13 @@ has-cpu-only-benchmarks = false
has-cuda-tests = true
has-cuda-benchmarks = false

[targets.realm-backend]
type = "lib"
has-cpu-only-tests = true
has-cpu-only-benchmarks = false
has-cuda-tests = true
has-cuda-benchmarks = false

[targets.models]
type = "lib"
has-cpu-only-tests = true
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(runtime)
add_subdirectory(op-attrs)
add_subdirectory(kernels)
add_subdirectory(local-execution)
add_subdirectory(realm-backend)
add_subdirectory(task-spec)
add_subdirectory(utils)
add_subdirectory(ffi)
12 changes: 1 addition & 11 deletions lib/local-execution/include/local-execution/local_args_backing.h
@@ -24,24 +24,14 @@ std::optional<DeviceSpecificDeviceStates>
std::unordered_map<slot_id_t, ConcreteArgSpec>
construct_arg_slots_backing(TaskBinding const &, RuntimeArgConfig const &);

std::optional<DeviceSpecificDeviceStates>
create_per_device_op_state(LocalTaskRegistry const &,
LocalTensorBacking const &,
RuntimeArgConfig const &,
Allocator &,
TrainingLayerPlusContext const &);

TaskArgumentAccessor get_task_arg_accessor(LocalTensorBacking const &,
RuntimeArgConfig const &,
TaskInvocation const &,
Allocator &);

LocalArgsBacking make_local_args_backing_for_computation_graph(
LocalTaskRegistry const &,
TrainingComputationGraph const &,
RuntimeArgConfig const &,
LocalTensorBacking const &,
Allocator &);
std::unordered_map<layer_guid_t, std::optional<DeviceSpecificDeviceStates>> const &);

} // namespace FlexFlow

lib/local-execution/include/local-execution/local_training_backing.h
@@ -6,6 +6,7 @@
#include "pcg/optimizer_attrs.dtg.h"
#include "task-spec/training_computation_graph.dtg.h"
#include "task-spec/training_tensor_guid_t.dtg.h"
#include "utils/containers/generate_map.h"
#include "utils/units/milliseconds_t.h"

namespace FlexFlow {
@@ -18,6 +19,13 @@ LocalTrainingBacking make_local_training_backing_for_computation_graph(
RuntimeArgConfig const &runtime_arg_config,
OptimizerAttrs const &optimizer_attrs);

std::optional<DeviceSpecificDeviceStates>
create_per_device_op_state(LocalTaskRegistry const &,
LocalTensorBacking const &,
RuntimeArgConfig const &,
Allocator &,
TrainingLayerPlusContext const &);

std::optional<milliseconds_t> execute_forward(LocalTaskRegistry const &,
LocalTensorBacking const &,
LocalArgsBacking const &,
51 changes: 2 additions & 49 deletions lib/local-execution/src/local-execution/local_args_backing.cc
@@ -35,38 +35,6 @@ std::unordered_map<slot_id_t, ConcreteArgSpec>
;
}

std::optional<DeviceSpecificDeviceStates>
create_per_device_op_state(LocalTaskRegistry const &local_task_registry,
LocalTensorBacking const &tensor_backing,
RuntimeArgConfig const &runtime_arg_config,
Allocator &allocator,
TrainingLayerPlusContext const &training_layer) {
std::optional maybe_registered_task = try_get_registered_task(
local_task_registry, training_layer.layer_guid, OpTaskType::INIT);

ASSERT(maybe_registered_task.has_value());

registered_task_t registered_task = maybe_registered_task.value();
if (registered_task.is_noop_task()) {
return std::nullopt;
}

TaskInvocation invocation = lower_to_task_invocation(
/*op_task_invocation=*/get_init_op_task_invocation(
training_layer.layer_attrs.op_attrs),
/*training_layer=*/training_layer,
/*device_specific_device_states=*/std::nullopt);

TaskArgumentAccessor accessor = get_task_arg_accessor(
tensor_backing, runtime_arg_config, invocation, allocator);
TaskSignatureAndImpl task_sig_impl =
local_task_registry.task_mapping.at(invocation.task_id);
auto fn =
task_sig_impl.impl_function.get<InitOpTaskImplFunction>().function_ptr;
DeviceSpecificDeviceStates device_state = fn(accessor);
return device_state;
}

TaskArgumentAccessor
get_task_arg_accessor(LocalTensorBacking const &local_tensor_backing,
RuntimeArgConfig const &runtime_arg_config,
@@ -82,24 +50,9 @@ TaskArgumentAccessor
}

LocalArgsBacking make_local_args_backing_for_computation_graph(
LocalTaskRegistry const &task_registry,
TrainingComputationGraph const &training_computation_graph,
RuntimeArgConfig const &runtime_arg_config,
LocalTensorBacking const &local_tensor_backing,
Allocator &allocator) {
std::unordered_map<layer_guid_t, std::optional<DeviceSpecificDeviceStates>>
per_device_op_states = generate_map(
topological_ordering(training_computation_graph.computation_graph),
[&](layer_guid_t const &layer_guid) {
return create_per_device_op_state(
task_registry,
local_tensor_backing,
runtime_arg_config,
allocator,
get_training_layer_plus_context(training_computation_graph,
layer_guid));
});

std::unordered_map<layer_guid_t, std::optional<DeviceSpecificDeviceStates>> const &
per_device_op_states) {
return LocalArgsBacking{
runtime_arg_config,
per_device_op_states,
52 changes: 47 additions & 5 deletions lib/local-execution/src/local-execution/local_training_backing.cc
@@ -39,12 +39,22 @@ LocalTrainingBacking make_local_training_backing_for_computation_graph(
preallocated,
allocator);

std::unordered_map<layer_guid_t, std::optional<DeviceSpecificDeviceStates>>
per_device_op_states = generate_map(
topological_ordering(training_computation_graph.computation_graph),
[&](layer_guid_t const &layer_guid) {
return create_per_device_op_state(
local_task_registry,
local_tensor_backing,
runtime_arg_config,
allocator,
get_training_layer_plus_context(training_computation_graph,
layer_guid));
});

LocalArgsBacking local_args_backing =
make_local_args_backing_for_computation_graph(local_task_registry,
training_computation_graph,
runtime_arg_config,
local_tensor_backing,
allocator);
make_local_args_backing_for_computation_graph(runtime_arg_config,
per_device_op_states);

return LocalTrainingBacking{
/*computation_graph=*/training_computation_graph,
@@ -54,6 +64,38 @@ LocalTrainingBacking make_local_training_backing_for_computation_graph(
};
}

std::optional<DeviceSpecificDeviceStates>
create_per_device_op_state(LocalTaskRegistry const &local_task_registry,
LocalTensorBacking const &tensor_backing,
RuntimeArgConfig const &runtime_arg_config,
Allocator &allocator,
TrainingLayerPlusContext const &training_layer) {
std::optional maybe_registered_task = try_get_registered_task(
local_task_registry, training_layer.layer_guid, OpTaskType::INIT);

ASSERT(maybe_registered_task.has_value());

registered_task_t registered_task = maybe_registered_task.value();
if (registered_task.is_noop_task()) {
return std::nullopt;
}

TaskInvocation invocation = lower_to_task_invocation(
/*op_task_invocation=*/get_init_op_task_invocation(
training_layer.layer_attrs.op_attrs),
/*training_layer=*/training_layer,
/*device_specific_device_states=*/std::nullopt);

TaskArgumentAccessor accessor = get_task_arg_accessor(
tensor_backing, runtime_arg_config, invocation, allocator);
TaskSignatureAndImpl task_sig_impl =
local_task_registry.task_mapping.at(invocation.task_id);
auto fn =
task_sig_impl.impl_function.get<InitOpTaskImplFunction>().function_ptr;
DeviceSpecificDeviceStates device_state = fn(accessor);
return device_state;
}

std::optional<milliseconds_t>
execute_forward(LocalTaskRegistry const &local_task_registry,
LocalTensorBacking const &local_tensor_backing,
21 changes: 21 additions & 0 deletions lib/realm-backend/CMakeLists.txt
@@ -0,0 +1,21 @@
ff_add_library(
NAME
realm-backend
SRC_PATTERNS
src/*.cc
PUBLIC_INCLUDE
include/
PRIVATE_INCLUDE
src/
DEPS
op-attrs
utils
kernels
compiler
local-execution
pcg
spdlog
legion
)

add_subdirectory(test)
13 changes: 13 additions & 0 deletions lib/realm-backend/include/realm-backend/driver.h
@@ -0,0 +1,13 @@
#ifndef _FLEXFLOW_REALM_BACKEND_DRIVER_H
#define _FLEXFLOW_REALM_BACKEND_DRIVER_H

#include "realm.h"
#include "realm/cmdline.h"
#include "task-spec/op_task_invocation.h"

Realm::Processor::TaskFuncID get_realm_task_id(FlexFlow::task_id_t task_id);

void top_level_task(const void *args, size_t arglen, const void *userdata,
size_t userlen, Realm::Processor p);

#endif
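
For orientation, a minimal driver sketch showing how top_level_task might be registered and launched with the Realm runtime, assuming the standard Realm bootstrap sequence; the TOP_LEVEL_TASK_ID constant and main() below are illustrative and not part of this PR.

#include "realm-backend/driver.h"

using namespace Realm;

// Illustrative task id; the PR only requires ids to start at TASK_ID_FIRST_AVAILABLE.
enum { TOP_LEVEL_TASK_ID = Processor::TASK_ID_FIRST_AVAILABLE };

int main(int argc, char **argv) {
  Runtime rt;
  rt.init(&argc, &argv);
  // Register the entry task, pick a CPU processor, and spawn the task once.
  rt.register_task(TOP_LEVEL_TASK_ID, top_level_task);
  Processor p = Machine::ProcessorQuery(Machine::get_machine())
                    .only_kind(Processor::LOC_PROC)
                    .first();
  Event e = rt.collective_spawn(p, TOP_LEVEL_TASK_ID, nullptr, 0);
  // Ask the runtime to shut down once the top-level task has finished.
  rt.shutdown(e);
  return rt.wait_for_shutdown();
}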
31 changes: 31 additions & 0 deletions lib/realm-backend/include/realm-backend/model_training_instance.h
@@ -0,0 +1,31 @@
#ifndef _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H
#define _FLEXFLOW_LOCAL_EXECUTION_MODEL_TRAINING_INSTANCE_H

#include "realm-backend/realm_training_backing.h"
#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h"
#include "pcg/tensor_guid_t.dtg.h"
#include "task-spec/loss_tensor_guid_t.dtg.h"

namespace FlexFlow {

struct ModelTrainingInstance {
ModelTrainingInstance(RealmRuntimeState &,
LocalTrainingBacking const &,
LossAttrs const &,
OptimizerAttrs const &);

RealmRuntimeState &runtime_state;
LocalTrainingBacking training_backing;
LossAttrs loss_attrs;
OptimizerAttrs optimizer_attrs;

public:
std::unordered_map<layer_guid_t, std::optional<milliseconds_t>> forward();
std::unordered_map<layer_guid_t, std::optional<milliseconds_t>> backward();
void update();
GenericTensorAccessorR get_loss_tensor_accessor() const;
};

} // namespace FlexFlow

#endif
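
A hedged usage sketch of this interface, assuming the Realm runtime state, training backing, and loss/optimizer attributes have already been built elsewhere (e.g. by the driver); the run_training_epochs helper and epoch count are illustrative only.

#include "realm-backend/model_training_instance.h"

namespace FlexFlow {

void run_training_epochs(RealmRuntimeState &runtime_state,
                         LocalTrainingBacking const &backing,
                         LossAttrs const &loss_attrs,
                         OptimizerAttrs const &optimizer_attrs,
                         int num_epochs) {
  ModelTrainingInstance instance{runtime_state, backing, loss_attrs, optimizer_attrs};
  for (int epoch = 0; epoch < num_epochs; epoch++) {
    // Per-layer elapsed times; layers without a timed kernel report std::nullopt.
    auto forward_times = instance.forward();
    auto backward_times = instance.backward();
    instance.update();
    (void)forward_times;
    (void)backward_times;
  }
  // Read back the loss tensor after the final epoch, e.g. for logging.
  GenericTensorAccessorR loss = instance.get_loss_tensor_accessor();
  (void)loss;
}

} // namespace FlexFlow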
34 changes: 34 additions & 0 deletions lib/realm-backend/include/realm-backend/realm_allocator.h
@@ -0,0 +1,34 @@
#ifndef _FLEXFLOW_REALM_BACKEND_REALM_ALLOCATOR_H
#define _FLEXFLOW_REALM_BACKEND_REALM_ALLOCATOR_H

#include "realm-backend/driver.h"
#include "realm.h"
#include "kernels/allocation.h"
#include <realm/event.h>

namespace FlexFlow {

struct RealmAllocatorImpl : public IAllocator {
RealmAllocatorImpl() = delete;
RealmAllocatorImpl(RealmAllocatorImpl const &) = delete;
RealmAllocatorImpl(RealmAllocatorImpl &&) = delete;
RealmAllocatorImpl(Realm::Processor);
~RealmAllocatorImpl() = default;

void *allocate(size_t) override;
void deallocate(void *) override;

DeviceType get_allocation_device_type() const override;

private:
std::unordered_map<void *, Realm::RegionInstance> ptrs;
Realm::Processor proc;
Realm::Memory mem;
std::vector<size_t> field_sizes = {sizeof(char)};
};

Allocator create_realm_memory_allocator(Realm::Processor);

} // namespace FlexFlow

#endif
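
A brief usage sketch, assuming FlexFlow's Allocator handle forwards allocate/deallocate to the IAllocator implementation above (each allocation backed by a Realm RegionInstance in memory reachable from the given processor); the scratch_buffer_example helper and buffer size are illustrative.

#include "realm-backend/realm_allocator.h"

namespace FlexFlow {

void scratch_buffer_example(Realm::Processor proc) {
  Allocator allocator = create_realm_memory_allocator(proc);
  // Allocate a 4 MiB scratch buffer in memory visible to `proc`.
  void *scratch = allocator.allocate(4 * 1024 * 1024);
  // ... hand `scratch` to a kernel running on `proc` ...
  allocator.deallocate(scratch);
}

} // namespace FlexFlow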
64 changes: 64 additions & 0 deletions lib/realm-backend/include/realm-backend/realm_training_backing.h
@@ -0,0 +1,64 @@
#ifndef _FLEXFLOW_REALM_BACKEND_REALM_TRAINING_BACKING_H
#define _FLEXFLOW_REALM_BACKEND_REALM_TRAINING_BACKING_H

#include "local-execution/local_training_backing.dtg.h"
#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h"
#include "pcg/optimizer_attrs.dtg.h"
#include "task-spec/training_computation_graph.dtg.h"
#include "task-spec/training_tensor_guid_t.dtg.h"
#include "utils/containers/generate_map.h"
#include "utils/units/milliseconds_t.h"
#include "realm-backend/driver.h"
#include "realm-backend/realm_allocator.h"
#include "realm-backend/task_wrapper.h"

namespace FlexFlow {

struct RealmRuntimeState {
Realm::Processor master_proc;
Realm::Event master_event;
Realm::Memory master_mem;
std::vector<Realm::Processor> worker_procs;
std::vector<Realm::Event> worker_events;
std::vector<Allocator> allocators;
};

LocalTrainingBacking make_realm_training_backing_for_computation_graph(
RealmRuntimeState &runtime_state,
std::unordered_map<training_tensor_guid_t, GenericTensorAccessorW> const
&preallocated_tensors,
TrainingComputationGraph const &training_computation_graph,
RuntimeArgConfig const &runtime_arg_config,
OptimizerAttrs const &optimizer_attrs);

void register_tasks_for_realm(LocalTaskRegistry const &, RealmRuntimeState &);

std::optional<DeviceSpecificDeviceStates>
create_per_device_op_state(LocalTaskRegistry const &,
LocalTensorBacking const &,
RuntimeArgConfig const &,
RealmRuntimeState &,
TrainingLayerPlusContext const &);

Future<std::optional<milliseconds_t>> execute_forward(LocalTaskRegistry const &,
LocalTensorBacking const &,
LocalArgsBacking const &,
TrainingLayerPlusContext const &,
RealmRuntimeState &);

Future<std::optional<milliseconds_t>> execute_backward(LocalTaskRegistry const &,
LocalTensorBacking const &,
LocalArgsBacking const &,
TrainingLayerPlusContext const &,
RealmRuntimeState &);

Future<void> compute_loss(LocalTrainingBacking const &, LossAttrs const &, RealmRuntimeState &);

Future<void> execute_update(LocalTrainingBacking const &,
layer_guid_t const &,
OptimizerAttrs const &,
RealmRuntimeState &);

} // namespace FlexFlow

#endif
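
A hedged sketch of how a caller (e.g. the top-level task) might wire this header together. Whether make_realm_training_backing_for_computation_graph registers tasks internally is not shown in this diff, so the sketch assumes the caller invokes register_tasks_for_realm itself; the init_realm_training helper is illustrative.

#include "realm-backend/realm_training_backing.h"

namespace FlexFlow {

LocalTrainingBacking init_realm_training(
    RealmRuntimeState &runtime_state,
    LocalTaskRegistry const &task_registry,
    std::unordered_map<training_tensor_guid_t, GenericTensorAccessorW> const
        &preallocated_tensors,
    TrainingComputationGraph const &training_cg,
    RuntimeArgConfig const &runtime_arg_config,
    OptimizerAttrs const &optimizer_attrs) {
  // Realm task bodies must be known to the runtime before they can be spawned
  // on worker processors.
  register_tasks_for_realm(task_registry, runtime_state);
  return make_realm_training_backing_for_computation_graph(runtime_state,
                                                           preallocated_tensors,
                                                           training_cg,
                                                           runtime_arg_config,
                                                           optimizer_attrs);
}

} // namespace FlexFlow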