Skip to content

Commit bf0289e

Browse files
committed
new api
1 parent 917051c commit bf0289e

File tree

3 files changed

+2
-17
lines changed

3 files changed

+2
-17
lines changed

src/forge/actors/generator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,10 @@
5050
)
5151

5252
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
53-
<<<<<<< HEAD:src/forge/actors/generator.py
5453
from forge.data_models.completion import Completion
5554
from forge.data_models.prompt import to_prompt
5655
from forge.env import TORCHSTORE_USE_RDMA
5756
from forge.interfaces import Policy as GeneratorInterface
58-
from forge.data.sharding import VLLMSharding
59-
from forge.data_models.completion import Completion
60-
from forge.data_models.prompt import to_prompt
6157
from forge.observability.metrics import record_metric, Reduce
6258
from forge.observability.perf_tracker import Tracer
6359
from forge.types import ProcessConfig

src/forge/actors/trainer.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,6 @@
1818
import torch.distributed.checkpoint as dcp
1919
import torchstore as ts
2020

21-
from forge.actors._torchstore_utils import (
22-
DcpHandle,
23-
get_dcp_whole_state_dict_key,
24-
get_param_key,
25-
)
26-
27-
from forge.controller import ForgeActor
28-
from forge.data.utils import batch_to_device
29-
from forge.observability.metrics import record_metric, Reduce
30-
from forge.observability.perf_tracker import Tracer
31-
3221
from monarch.actor import current_rank, current_size, endpoint
3322
from torch import Tensor
3423
from torch.distributed.checkpoint._nested_dict import flatten_state_dict

tests/integration_tests/test_policy_update.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ async def test_sanity_check(self, _setup_and_teardown):
264264
for _, e in errs.items():
265265
assert not e, f"Validation failed with exception: {e}"
266266

267-
await policy.update_weights.fanout(policy_version=v1)
267+
await policy.update_weights.fanout(version=v1)
268268
all_errs = await policy._test_validate_model_params.fanout(
269269
_test_validate_params_all_zeros
270270
)
@@ -273,7 +273,7 @@ async def test_sanity_check(self, _setup_and_teardown):
273273
assert not e, f"Validation failed with exception: {e}"
274274

275275
# Reloading v0, getting back original weights
276-
await policy.update_weights.fanout(policy_version=v0)
276+
await policy.update_weights.fanout(version=v0)
277277
all_errs = await policy._test_validate_model_params.fanout(
278278
_test_validate_params_unchanged
279279
)

0 commit comments

Comments
 (0)