-
Notifications
You must be signed in to change notification settings - Fork 34
/
__init__.py
81 lines (76 loc) · 3.01 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
    AdaGradGraftingConfig,
    AdamGraftingConfig,
    CommunicationDType,
    DDPShampooConfig,
    DefaultEigenvalueCorrectedShampooConfig,
    DefaultShampooConfig,
    DefaultSOAPConfig,
    DistributedConfig,
    EigenvalueCorrectedShampooPreconditionerConfig,
    FSDPShampooConfig,
    FullyShardShampooConfig,
    GraftingConfig,
    HSDPShampooConfig,
    PreconditionerConfig,
    RMSpropGraftingConfig,
    SGDGraftingConfig,
    ShampooPreconditionerConfig,
    ShampooPT2CompileConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import compile_fsdp_parameter_metadata
from matrix_functions_types import (
    CoupledHigherOrderConfig,
    CoupledNewtonConfig,
    DefaultEigenConfig,
    DefaultEighEigenvectorConfig,
    EigenConfig,
    EigenvectorConfig,
    EighEigenvectorConfig,
    MatrixFunctionConfig,
    QRConfig,
    RootInvConfig,
)
# Public API re-exported by this package. Every name listed here must also be
# bound at module level (i.e. imported above); otherwise
# `from <this package> import *` raises AttributeError on the missing name.
__all__ = [
    "DistributedShampoo",
    # `grafting_config` options.
    "GraftingConfig",  # Abstract base class.
    "SGDGraftingConfig",
    "AdaGradGraftingConfig",
    "RMSpropGraftingConfig",
    "AdamGraftingConfig",
    # PT2 compile.
    "ShampooPT2CompileConfig",
    # `distributed_config` options.
    "DistributedConfig",  # Abstract base class.
    "DDPShampooConfig",
    "FSDPShampooConfig",
    "FullyShardShampooConfig",
    "HSDPShampooConfig",
    # `preconditioner_config` options.
    "PreconditionerConfig",  # Abstract base class.
    "ShampooPreconditionerConfig",  # Based on `PreconditionerConfig`.
    "DefaultShampooConfig",  # Default `ShampooPreconditionerConfig` using `EigenConfig`.
    "EigenvalueCorrectedShampooPreconditionerConfig",  # Based on `PreconditionerConfig`.
    "DefaultEigenvalueCorrectedShampooConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigenvectorConfig`.
    "DefaultSOAPConfig",  # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QRConfig`.
    # Matrix function configs.
    "MatrixFunctionConfig",  # Abstract base class.
    "RootInvConfig",  # Abstract base class (based on `MatrixFunctionConfig`).
    "EigenConfig",  # Based on `RootInvConfig`.
    "DefaultEigenConfig",  # Default `RootInvConfig` using `EigenConfig`.
    "CoupledNewtonConfig",  # Based on `RootInvConfig`.
    "CoupledHigherOrderConfig",  # Based on `RootInvConfig`.
    "EigenvectorConfig",  # Abstract base class (based on `MatrixFunctionConfig`).
    "EighEigenvectorConfig",  # Based on `EigenvectorConfig`.
    "DefaultEighEigenvectorConfig",  # Default `EigenvectorConfig` using `EighEigenvectorConfig`.
    "QRConfig",  # Based on `EigenvectorConfig`.
    # Other utilities.
    "compile_fsdp_parameter_metadata",  # For `FSDPShampooConfig` and `HSDPShampooConfig`.
    "CommunicationDType",  # For `DDPShampooConfig` and `HSDPShampooConfig`.
]