This repository was archived by the owner on Nov 15, 2021. It is now read-only.
Merged
Changes from all commits
Commits
79 commits
351753a
[rllib] Remove dependency on TensorFlow (#4764)
ericl May 11, 2019
004440f
Dynamic Custom Resources - create and delete resources (#3742)
romilbhardwaj May 11, 2019
f3b8b90
Update tutorial link in doc (#4777)
May 12, 2019
69352e3
[rllib] Implement learn_on_batch() in torch policy graph
ericl May 13, 2019
62c949b
Fix `ray stop` by killing raylet before plasma (#4778)
jovany-wang May 13, 2019
1622fc2
Fatal check if object store dies (#4763)
stephanie-wang May 13, 2019
c5161a2
[rllib] fix clip by value issue as TF upgraded (#4697)
joneswong May 13, 2019
67af103
Merge with ray master
stefanpantic May 14, 2019
3bbafc7
[autoscaler] Fix submit (#4782)
richardliaw May 15, 2019
d6bf680
Merge with master
stefanpantic May 15, 2019
cb1a195
Queue tasks in the raylet in between async callbacks (#4766)
stephanie-wang May 15, 2019
643f62d
[Java][Bazel] Refine auto-generated pom files (#4780)
raulchen May 16, 2019
1490a98
Bump version to 0.7.0 (#4791)
devin-petersohn May 16, 2019
98dd033
[JAVA] setDefaultUncaughtExceptionHandler to log uncaught exception i…
May 16, 2019
9f2645d
[tune] Fix CLI test (#4801)
richardliaw May 16, 2019
ffd596d
Fix pom file generation (#4800)
raulchen May 17, 2019
7d5ef6d
[rllib] Support continuous action distributions in IMPALA/APPO (#4771)
ericl May 17, 2019
3807fb5
[rllib] TensorFlow 2 compatibility (#4802)
ericl May 17, 2019
84cf474
Change tagline in documentation and README. (#4807)
pcmoritz May 17, 2019
ffe61fc
[tune] Support non-arg submit (#4803)
richardliaw May 17, 2019
88b45a5
[autoscaler] rsync cluster (#4785)
richardliaw May 17, 2019
e20855c
[tune] Remove extra parsing functionality (#4804)
richardliaw May 17, 2019
dcd6d49
Fix Java worker log dir (#4781)
jovany-wang May 17, 2019
1ef9c07
[tune] Initial track integration (#4362)
noahgolmant May 17, 2019
6cb5b90
[rllib] [RFC] Dynamic definition of loss functions and modularization…
ericl May 18, 2019
04294d9
Merge remote-tracking branch 'remotes/main/master' into unstable
pimpke May 20, 2019
02583a8
[rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy…
ericl May 20, 2019
081708b
[Java] Dynamic resource API in Java (#4824)
jovany-wang May 21, 2019
ac47d03
Merge with master
stefanpantic May 21, 2019
5391b61
Add default values for Wgym flags
stefanpantic May 21, 2019
87bb2e5
Fix import
stefanpantic May 21, 2019
259cdfa
Fix issue when starting `raylet_monitor` (#4829)
jovany-wang May 22, 2019
1a39fee
Refactor ID Serial 1: Separate ObjectID and TaskID from UniqueID (#4776)
guoyuhong May 22, 2019
2015085
Fix bug in which actor classes are not exported multiple times. (#4838)
robertnishihara May 23, 2019
ba6c595
Bump Ray master version to 0.8.0.dev0 (#4845)
devin-petersohn May 24, 2019
4e281ba
Add section to bump version of master branch and cleanup release docs…
devin-petersohn May 24, 2019
71f95e1
Fix import
stefanpantic May 24, 2019
be1850f
Merge branch 'unstable' of github.com:wingman-ai/ray into unstable
stefanpantic May 24, 2019
49fe894
Export remote functions when first used and also fix bug in which rem…
robertnishihara May 24, 2019
a7d01ab
Update wheel versions in documentation to 0.8.0.dev0 and 0.7.0. (#4847)
devin-petersohn May 24, 2019
0ce0ecb
[tune] Later expansion of local_dir (#4806)
richardliaw May 25, 2019
7237ea7
[rllib] [RFC] Deprecate Python 2 / RLlib (#4832)
ericl May 25, 2019
ea8d7b4
Fix a typo in kubernetes yaml (#4872)
ikedaosushi May 26, 2019
6703519
Move global state API out of global_state object. (#4857)
robertnishihara May 26, 2019
7a78e1e
Install bazel in autoscaler development configs. (#4874)
robertnishihara May 26, 2019
574e1c7
[tune] Fix up Ax Search and Examples (#4851)
richardliaw May 27, 2019
a45c61e
[rllib] Update concepts docs and add "Building Policies in Torch/Tens…
ericl May 27, 2019
d7be5a5
[rllib] Fix error getting kl when simple_optimizer: True in multi-age…
ericl May 28, 2019
fa0892f
Replace ReturnIds with NumReturns in TaskInfo to reduce the size (#4854)
guoyuhong May 28, 2019
64a01b2
Update deps commits of opencensus to support building with bzl 0.25.x…
jovany-wang May 28, 2019
0bcc589
Merge with master
stefanpantic May 28, 2019
64eb7b3
Upgrade arrow to latest master (#4858)
pcmoritz May 28, 2019
acee89b
[tune] Auto-init Ray + default SearchAlg (#4815)
richardliaw May 29, 2019
a218a14
Bump version from 0.8.0.dev0 to 0.7.1. (#4890)
robertnishihara May 29, 2019
2dd0beb
[rllib] Allow access to batches prior to postprocessing (#4871)
ericl May 30, 2019
3f4d37c
[rllib] Fix Multidiscrete support (#4869)
ericl May 30, 2019
b7c284a
Refactor redis callback handling (#4841)
jovany-wang May 30, 2019
2912a7c
Initial high-level code structure of CoreWorker. (#4875)
raulchen May 30, 2019
4e0be8b
Drop duplicated string format (#4897)
suquark May 30, 2019
1f0809e
Refactor ID Serial 2: change all ID functions to `CamelCase` (#4896)
May 31, 2019
0066d7c
Hotfix for change of from_random to FromRandom (#4909)
May 31, 2019
1c073e9
[rllib] Fix documentation on custom policies (#4910)
ericl Jun 1, 2019
9aa1cd6
[rllib] Allow Torch policies access to full action input dict in extr…
ericl Jun 1, 2019
88bab5d
[tune] Pretty print params json in logger.py (#4903)
hartikainen Jun 1, 2019
c2ade07
[sgd] Distributed Training via PyTorch (#4797)
pschafhalter Jun 2, 2019
665d081
[rllib] Rough port of DQN to build_tf_policy() pattern (#4823)
ericl Jun 2, 2019
d86ee8c
fetching objects in parallel in _get_arguments_for_execution (#4775)
ajgokhale Jun 2, 2019
99eae05
[tune] Disallow setting resources_per_trial when it is already config…
ericl Jun 2, 2019
7501ee5
[rllib] Rename PolicyEvaluator => RolloutWorker (#4820)
ericl Jun 2, 2019
084b221
Fix local cluster yaml (#4918)
richardliaw Jun 3, 2019
89722ff
[tune] Directional metrics for components (#4120) (#4915)
hershg Jun 3, 2019
b674c4a
[Core Worker] implement ObjectInterface and add test framework (#4899)
zhijunfu Jun 3, 2019
c2253d2
[tune] Make PBT Quantile fraction configurable (#4912)
timonbimon Jun 4, 2019
d106283
Better organize ray_common module (#4898)
raulchen Jun 5, 2019
649af18
Merge branches 'master' and 'unstable' of github.com:wingman-ai/ray i…
stefanpantic Jun 5, 2019
d7680ab
Merge with ray master
stefanpantic Jun 5, 2019
ffaae1c
Fix error
stefanpantic Jun 5, 2019
b2581c4
Merge branch 'master' of github.com:wingman-ai/ray into unstable
stefanpantic Jun 5, 2019
82b3972
Fix compute actions return value
stefanpantic Jun 6, 2019
68 changes: 52 additions & 16 deletions BUILD.bazel
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ cc_library(
"src/ray/raylet/mock_gcs_client.cc",
"src/ray/raylet/monitor_main.cc",
"src/ray/raylet/*_test.cc",
"src/ray/raylet/main.cc",
],
),
hdrs = glob([
@@ -105,6 +106,39 @@ cc_library(
],
)

cc_library(
name = "core_worker_lib",
srcs = glob(
[
"src/ray/core_worker/*.cc",
],
exclude = [
"src/ray/core_worker/*_test.cc",
],
),
hdrs = glob([
"src/ray/core_worker/*.h",
]),
copts = COPTS,
deps = [
":ray_common",
":ray_util",
":raylet_lib",
],
)

# This test is run by src/ray/test/run_core_worker_tests.sh
cc_binary(
name = "core_worker_test",
srcs = ["src/ray/core_worker/core_worker_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
":gcs",
"@com_google_googletest//:gtest_main",
],
)

cc_test(
name = "lineage_cache_test",
srcs = ["src/ray/raylet/lineage_cache_test.cc"],
@@ -247,16 +281,13 @@ cc_library(
name = "ray_util",
srcs = glob(
[
"src/ray/*.cc",
"src/ray/util/*.cc",
],
exclude = [
"src/ray/util/logging_test.cc",
"src/ray/util/signal_test.cc",
"src/ray/util/*_test.cc",
],
),
hdrs = glob([
"src/ray/*.h",
"src/ray/util/*.h",
]),
copts = COPTS,
@@ -272,23 +303,28 @@ cc_library(

cc_library(
name = "ray_common",
srcs = [
"src/ray/common/client_connection.cc",
"src/ray/common/common_protocol.cc",
],
hdrs = [
"src/ray/common/client_connection.h",
"src/ray/common/common_protocol.h",
],
srcs = glob(
[
"src/ray/common/*.cc",
],
exclude = [
"src/ray/common/*_test.cc",
],
),
hdrs = glob(
[
"src/ray/common/*.h",
],
),
copts = COPTS,
includes = [
"src/ray/gcs/format",
],
deps = [
":gcs_fbs",
":node_manager_fbs",
":ray_util",
"@boost//:asio",
"@plasma//:plasma_client",
],
)

@@ -432,7 +468,7 @@ cc_binary(
srcs = [
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h",
"src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc",
"src/ray/id.h",
"src/ray/common/id.h",
"src/ray/raylet/raylet_client.h",
"src/ray/util/logging.h",
"@bazel_tools//tools/jdk:jni_header",
@@ -637,8 +673,8 @@ genrule(
cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ &&
for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done &&
mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ &&
for f in $(locations //:python_node_manager_fbs); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
for f in $(locations //:python_node_manager_fbs); do
cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/;
done &&
echo $$WORK_DIR > $@
""",
6 changes: 6 additions & 0 deletions bazel/BUILD.plasma
@@ -25,11 +25,13 @@ cc_library(
name = "arrow",
srcs = [
"cpp/src/arrow/buffer.cc",
"cpp/src/arrow/io/interfaces.cc",
"cpp/src/arrow/memory_pool.cc",
"cpp/src/arrow/status.cc",
"cpp/src/arrow/util/io-util.cc",
"cpp/src/arrow/util/logging.cc",
"cpp/src/arrow/util/memory.cc",
"cpp/src/arrow/util/string_builder.cc",
"cpp/src/arrow/util/thread-pool.cc",
],
hdrs = [
@@ -42,6 +44,7 @@ cc_library(
"cpp/src/arrow/util/logging.h",
"cpp/src/arrow/util/macros.h",
"cpp/src/arrow/util/memory.h",
"cpp/src/arrow/util/stl.h",
"cpp/src/arrow/util/string_builder.h",
"cpp/src/arrow/util/string_view.h",
"cpp/src/arrow/util/thread-pool.h",
@@ -53,6 +56,9 @@ cc_library(
"cpp/src/arrow/vendored/xxhash/xxhash.h",
],
strip_include_prefix = "cpp/src",
deps = [
"@boost//:filesystem",
],
)

cc_library(
30 changes: 15 additions & 15 deletions bazel/ray_deps_setup.bzl
@@ -3,9 +3,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

def ray_deps_setup():
RULES_JVM_EXTERNAL_TAG = "1.2"

RULES_JVM_EXTERNAL_SHA = "e5c68b87f750309a79f59c2b69ead5c3221ffa54ff9496306937bfa1c9c8c86b"

http_archive(
name = "rules_jvm_external",
sha256 = RULES_JVM_EXTERNAL_SHA,
@@ -18,72 +18,72 @@ def ray_deps_setup():
strip_prefix = "bazel-common-f1115e0f777f08c3cdb115526c4e663005bec69b",
url = "https://github.com/google/bazel-common/archive/f1115e0f777f08c3cdb115526c4e663005bec69b.zip",
)

BAZEL_SKYLIB_TAG = "0.6.0"

http_archive(
name = "bazel_skylib",
strip_prefix = "bazel-skylib-%s" % BAZEL_SKYLIB_TAG,
url = "https://github.com/bazelbuild/bazel-skylib/archive/%s.tar.gz" % BAZEL_SKYLIB_TAG,
)

git_repository(
name = "com_github_checkstyle_java",
commit = "85f37871ca03b9d3fee63c69c8107f167e24e77b",
remote = "https://github.com/ruifangChen/checkstyle_java",
)

git_repository(
name = "com_github_nelhage_rules_boost",
commit = "5171b9724fbb39c5fdad37b9ca9b544e8858d8ac",
remote = "https://github.com/ray-project/rules_boost",
)

git_repository(
name = "com_github_google_flatbuffers",
commit = "63d51afd1196336a7d1f56a988091ef05deb1c62",
remote = "https://github.com/google/flatbuffers.git",
)

git_repository(
name = "com_google_googletest",
commit = "3306848f697568aacf4bcca330f6bdd5ce671899",
remote = "https://github.com/google/googletest",
)

git_repository(
name = "com_github_gflags_gflags",
remote = "https://github.com/gflags/gflags.git",
tag = "v2.2.2",
)

new_git_repository(
name = "com_github_google_glog",
build_file = "@//bazel:BUILD.glog",
commit = "5c576f78c49b28d89b23fbb1fc80f54c879ec02e",
remote = "https://github.com/google/glog",
)

new_git_repository(
name = "plasma",
build_file = "@//bazel:BUILD.plasma",
commit = "d00497b38be84fd77c40cbf77f3422f2a81c44f9",
commit = "9fcc12fc094b85ec2e3e9798bae5c8151d14df5e",
remote = "https://github.com/apache/arrow",
)

new_git_repository(
name = "cython",
build_file = "@//bazel:BUILD.cython",
commit = "49414dbc7ddc2ca2979d6dbe1e44714b10d72e7e",
remote = "https://github.com/cython/cython",
)

http_archive(
name = "io_opencensus_cpp",
strip_prefix = "opencensus-cpp-3aa11f20dd610cb8d2f7c62e58d1e69196aadf11",
urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/3aa11f20dd610cb8d2f7c62e58d1e69196aadf11.zip"],
)

# OpenCensus depends on Abseil so we have to explicitly pull it in.
# This is how diamond dependencies are prevented.
git_repository(
@@ -96,7 +96,7 @@ def ray_deps_setup():
http_archive(
name = "com_github_jupp0r_prometheus_cpp",
strip_prefix = "prometheus-cpp-master",

# TODO(qwang): We should use the repository of `jupp0r` here when this PR
# `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged.
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"],
4 changes: 2 additions & 2 deletions build.sh
@@ -101,8 +101,8 @@ pushd "$BUILD_DIR"
# generated from https://github.com/ray-project/arrow-build from
# the commit listed in the command.
$PYTHON_EXECUTABLE -m pip install \
--target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.12.0.RAY \
--find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/ca1fa51f0901f5a4298f0e4faea00f24e5dd7bb7/index.html
--target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.14.0.RAY \
--find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/9f35817b35f9d0614a736a497d70de2cf07fed52/index.html
export PYTHON_BIN_PATH="$PYTHON_EXECUTABLE"

if [ "$RAY_BUILD_JAVA" == "YES" ]; then
23 changes: 1 addition & 22 deletions ci/jenkins_tests/run_multi_node_tests.sh
@@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \
######################## SGD TESTS #################################

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps --tune
python -m pytest /ray/python/ray/experimental/sgd/tests
13 changes: 11 additions & 2 deletions ci/jenkins_tests/run_rllib_tests.sh
@@ -302,7 +302,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_checkpoint_restore.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_policy_evaluator.py
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_rollout_worker.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_nested_spaces.py
@@ -390,7 +390,16 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2
3 changes: 2 additions & 1 deletion ci/long_running_tests/workloads/pbt.py
@@ -37,7 +37,8 @@

pbt = PopulationBasedTraining(
time_attr="training_iteration",
reward_attr="episode_reward_mean",
metric="episode_reward_mean",
mode="max",
perturbation_interval=10,
hyperparam_mutations={
"lr": [0.1, 0.01, 0.001, 0.0001],
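The ``pbt.py`` change above tracks a Tune API migration: the single ``reward_attr`` argument was split into ``metric`` (which result field to compare) and ``mode`` (``"max"`` or ``"min"``). A hypothetical shim, purely for illustration (``translate_reward_attr`` is not a real Tune helper), makes the mapping explicit:

```python
def translate_reward_attr(reward_attr, mode="max"):
    """Map the legacy single-argument form to the newer metric/mode pair.

    Under the old API, reward_attr named the result field to compare and
    maximization was implied; the new API names the field via metric and
    makes the optimization direction explicit via mode.
    """
    return {"metric": reward_attr, "mode": mode}

# The legacy call PopulationBasedTraining(..., reward_attr="episode_reward_mean")
# becomes PopulationBasedTraining(..., **translate_reward_attr("episode_reward_mean")).
kwargs = translate_reward_attr("episode_reward_mean")
print(kwargs)  # {'metric': 'episode_reward_mean', 'mode': 'max'}
```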
4 changes: 4 additions & 0 deletions doc/source/conf.py
@@ -53,6 +53,10 @@
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"torch",
"torch.distributed",
"torch.nn",
"torch.utils.data",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
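The ``conf.py`` change above extends Sphinx's mock-import list so the docs build can introspect PyTorch-dependent modules without PyTorch installed. A minimal standalone sketch of the same pattern (module names taken from the diff; the ``layer`` usage line is illustrative):

```python
import sys
from unittest import mock

# Replace heavy optional dependencies with mock modules before anything
# imports them; autodoc can then import code that uses torch even when
# PyTorch is absent from the docs-build environment.
MOCK_MODULES = [
    "torch",
    "torch.distributed",
    "torch.nn",
    "torch.utils.data",
]
for mod_name in MOCK_MODULES:
    sys.modules[mod_name] = mock.Mock()

# Importing the mocked package now succeeds; attribute access and calls
# simply yield further Mock objects instead of real PyTorch layers.
import torch.nn
layer = torch.nn.Linear(4, 2)  # a Mock standing in for a real layer
```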
48 changes: 48 additions & 0 deletions doc/source/distributed_training.rst
@@ -0,0 +1,48 @@
Distributed Training (Experimental)
===================================


Ray includes abstractions for distributed model training that integrate with
deep learning frameworks, such as PyTorch.

Ray Train is built on top of the Ray task and actor abstractions to provide
seamless integration into existing Ray applications.

PyTorch Interface
-----------------

To use Ray Train with PyTorch, pass model and data creator functions to the
``ray.experimental.sgd.pytorch.PyTorchTrainer`` class.
To drive the distributed training, ``trainer.train()`` can be called
repeatedly.

.. code-block:: python

model_creator = lambda config: YourPyTorchModel()
data_creator = lambda config: (YourTrainingSet(), YourValidationSet())

trainer = PyTorchTrainer(
model_creator,
data_creator,
optimizer_creator=utils.sgd_mse_optimizer,
config={"lr": 1e-4},
num_replicas=2,
resources_per_replica=Resources(num_gpus=1),
batch_size=16,
backend="auto")

for i in range(NUM_EPOCHS):
trainer.train()

Under the hood, Ray Train will create *replicas* of your model
(controlled by ``num_replicas``) which are each managed by a worker.
Multiple devices (e.g. GPUs) can be managed by each replica (controlled by ``resources_per_replica``),
which allows training of large models across multiple GPUs.
The ``PyTorchTrainer`` class coordinates the distributed computation and training to improve the model.
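The creator-function indirection matters because each replica constructs its own model and datasets locally on its worker, rather than receiving live objects serialized from the driver. A framework-free sketch of that contract (``build_replica`` and the stand-in creators are illustrative, not Ray APIs):

```python
def build_replica(model_creator, data_creator, config):
    # Each replica calls the creators locally with the shared config,
    # mirroring how the trainer defers construction to its workers.
    model = model_creator(config)
    train_set, validation_set = data_creator(config)
    return model, train_set, validation_set

model, train_set, validation_set = build_replica(
    lambda config: {"name": "toy_model", "lr": config["lr"]},  # stand-in model
    lambda config: (list(range(8)), list(range(2))),           # stand-in datasets
    {"lr": 1e-4},
)
```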

The full documentation for ``PyTorchTrainer`` is as follows:

.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer
:members:

.. automethod:: __init__