Skip to content

Commit

Permalink
split framework TARGETS to individual files (pytorch#475)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#475

Another try after D47418445 reverted the previous change.
1. split the framework targets, copied from D47418445
2. change the imports to directly use name in submodule.
3. run autodeps through arc lint.

Reviewed By: gunchu

Differential Revision: D47638416

fbshipit-source-id: 82b0bd027907d2579b2b049dab84d5fee2d07b2c
  • Loading branch information
Fan Zhang(DevX) authored and facebook-github-bot committed Jul 24, 2023
1 parent 5506f61 commit fe9ea1c
Show file tree
Hide file tree
Showing 17 changed files with 35 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/source/framework/unit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Here is an example of a unit which extends TrainUnit, EvalUnit, and PredictUnit.

.. code-block:: python
from torchtnt.framework import TrainUnit, EvalUnit, PredictUnit
from torchtnt.framework.unit import TrainUnit, EvalUnit, PredictUnit
Batch = Tuple[torch.tensor, torch.tensor]
Expand Down
5 changes: 3 additions & 2 deletions examples/auto_unit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from torch.distributed import launcher as pet
from torch.utils.data.dataset import Dataset, TensorDataset
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import AutoUnit, fit, State
from torchtnt.framework.state import EntryPoint
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.fit import fit
from torchtnt.framework.state import EntryPoint, State
from torchtnt.utils import init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger

Expand Down
4 changes: 3 additions & 1 deletion examples/mingpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
GPTConfig,
OptimizerConfig,
)
from torchtnt.framework import AutoUnit, fit, State
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State
from torchtnt.utils import init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger

Expand Down
4 changes: 3 additions & 1 deletion examples/mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torcheval.metrics import MulticlassAccuracy
from torchtnt.framework import AutoUnit, fit, State
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State
from torchtnt.utils import init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger
from torchvision import datasets, transforms
Expand Down
4 changes: 3 additions & 1 deletion examples/torchdata_train_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import State, train, TrainUnit
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils import copy_data_to_device, init_from_env, seed, TLRScheduler

from torchtnt.utils.loggers import TensorBoardLogger
Expand Down
6 changes: 4 additions & 2 deletions examples/torchrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter

from torchtnt.framework import EvalUnit, fit, State, TrainUnit
from torchtnt.framework.callbacks import TQDMProgressBar
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State

from torchtnt.framework.unit import EvalUnit, TrainUnit
from torchtnt.utils import (
get_process_group_backend_from_device,
init_from_env,
Expand Down
4 changes: 3 additions & 1 deletion examples/train_unit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import torch.nn as nn
from torch.utils.data.dataset import Dataset, TensorDataset
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import State, train, TrainUnit
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils import copy_data_to_device, init_from_env, seed, TLRScheduler

from torchtnt.utils.loggers import TensorBoardLogger
Expand Down
4 changes: 2 additions & 2 deletions tests/framework/callbacks/test_module_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from unittest.mock import MagicMock

import torch

from torchtnt.framework import AutoUnit
from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader

from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.callbacks.module_summary import ModuleSummary
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.utils.version import is_torch_version_geq_1_13
Expand Down
2 changes: 1 addition & 1 deletion tests/framework/callbacks/test_torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import torch
from torch.distributed import launcher
from torch.optim.lr_scheduler import ExponentialLR
from torchtnt.framework import AutoUnit

from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.callbacks import Lambda, TorchSnapshotSaver
from torchtnt.framework.state import State
from torchtnt.framework.train import train
Expand Down
3 changes: 1 addition & 2 deletions torchtnt/framework/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Dataset, IterableDataset, TensorDataset
from torchtnt.framework import PhaseState
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.framework.unit import EvalUnit, PredictUnit, TrainUnit

Batch = Tuple[torch.Tensor, torch.Tensor]
Expand Down
4 changes: 3 additions & 1 deletion torchtnt/framework/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class Callback:
.. code-block:: python
from torchtnt.framework import Callback, State, TEvalUnit, TPredictUnit, TTrainUnit
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
class PrintingCallback(Callback):
def on_train_start(self, state: State, unit: TTrainUnit) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/lambda_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class Lambda(Callback):
Examples::
from torchtnt.framework import evaluate
from torchtnt.framework.callbacks import Lambda
from torchtnt.framework.evaluate import evaluate
dataloader = MyDataLoader()
unit = MyUnit()
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate(
.. code-block:: python
from torchtnt.framework import evaluate
from torchtnt.framework.evaluate import evaluate
eval_unit = MyEvalUnit(module=..., optimizer=..., lr_scheduler=...)
eval_dataloader = torch.utils.data.DataLoader(...)
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def fit(
.. code-block:: python
from torchtnt.framework import fit
from torchtnt.framework.fit import fit
fit_unit = MyFitUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def predict(
.. code-block:: python
from torchtnt.framework import predict
from torchtnt.framework.predict import predict
predict_unit = MyPredictUnit(module=..., optimizer=..., lr_scheduler=...)
predict_dataloader = torch.utils.data.DataLoader(...)
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def train(
.. code-block:: python
from torchtnt.framework import train
from torchtnt.framework.train import train
train_unit = MyTrainUnit(module=..., optimizer=..., lr_scheduler=...)
train_dataloader = torch.utils.data.DataLoader(...)
Expand Down
6 changes: 3 additions & 3 deletions torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class TrainUnit(AppStateMixin, _OnExceptionMixin, Generic[TTrainData], ABC):
.. code-block:: python
from torchtnt.framework import TrainUnit
from torchtnt.framework.unit import TrainUnit
Batch = Tuple[torch.tensor, torch.tensor]
# specify type of the data in each batch of the dataloader to allow for typechecking
Expand Down Expand Up @@ -286,7 +286,7 @@ class EvalUnit(AppStateMixin, _OnExceptionMixin, Generic[TEvalData], ABC):
.. code-block:: python
from torchtnt.framework import EvalUnit
from torchtnt.framework.unit import EvalUnit
Batch = Tuple[torch.tensor, torch.tensor]
# specify type of the data in each batch of the dataloader to allow for typechecking
Expand Down Expand Up @@ -373,7 +373,7 @@ class PredictUnit(
.. code-block:: python
from torchtnt.framework import PredictUnit
from torchtnt.framework.unit import PredictUnit
Batch = Tuple[torch.tensor, torch.tensor]
# specify type of the data in each batch of the dataloader to allow for typechecking
Expand Down

0 comments on commit fe9ea1c

Please sign in to comment.