Skip to content

Commit 290e506

Browse files
committed
simplify
1 parent a03ee6f commit 290e506

File tree

2 files changed

+20
-26
lines changed

2 files changed

+20
-26
lines changed

src/forge/actors/policy.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,6 @@
1717
import torch
1818
import torch.distributed.checkpoint as dcp
1919
import torchstore as ts
20-
21-
from forge.actors._torchstore_utils import (
22-
extract_param_name,
23-
get_dcp_whole_state_dict_key,
24-
get_param_key,
25-
get_param_prefix,
26-
load_tensor_from_dcp,
27-
)
28-
29-
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
30-
from forge.data.sharding import VLLMSharding
31-
from forge.data_models.completion import Completion
32-
from forge.data_models.prompt import to_prompt
33-
from forge.interfaces import Policy as PolicyInterface
34-
from forge.observability.metrics import record_metric, Reduce
35-
from forge.observability.perf_tracker import Tracer
36-
from forge.types import ProcessConfig
3720
from monarch.actor import current_rank, endpoint, ProcMesh
3821
from torchstore.state_dict_utils import DELIM
3922
from vllm.config import VllmConfig
@@ -58,6 +41,23 @@
5841
from vllm.v1.structured_output import StructuredOutputManager
5942
from vllm.worker.worker_base import WorkerWrapperBase
6043

44+
from forge.actors._torchstore_utils import (
45+
extract_param_name,
46+
get_dcp_whole_state_dict_key,
47+
get_param_key,
48+
get_param_prefix,
49+
load_tensor_from_dcp,
50+
)
51+
52+
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
53+
from forge.data.sharding import VLLMSharding
54+
from forge.data_models.completion import Completion
55+
from forge.data_models.prompt import to_prompt
56+
from forge.interfaces import Policy as PolicyInterface
57+
from forge.observability.metrics import record_metric, Reduce
58+
from forge.observability.perf_tracker import Tracer
59+
from forge.types import ProcessConfig
60+
6161
logger = logging.getLogger(__name__)
6262
logger.setLevel(logging.INFO)
6363

tests/integration_tests/test_policy_update.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import asyncio
2424
import logging
2525
import shutil
26-
import sys
2726
from pathlib import Path
2827

2928
import torch
@@ -66,14 +65,9 @@ async def zero_out_model_states(self):
6665

6766
def load_config(config_path: str) -> DictConfig:
6867
"""Load configuration from YAML file."""
69-
try:
70-
cfg = OmegaConf.load(config_path)
71-
except Exception as e:
72-
raise RuntimeError(f"Failed to load config file {config_path}: {e}")
73-
68+
cfg = OmegaConf.load(config_path)
7469
if not isinstance(cfg, DictConfig):
7570
raise TypeError(f"Expected DictConfig, got {type(cfg)}")
76-
7771
cfg = resolve_hf_hub_paths(cfg)
7872
return cfg
7973

@@ -277,9 +271,9 @@ async def run_weight_sync_sanity_check(policy: Policy, rl_trainer: MockRLTrainer
277271
RESET = "\033[0m"
278272

279273
success_message = f"""
280-
{GREEN}{'='*60}
274+
{GREEN}{'=' * 60}
281275
✅ Weight sharding sanity check passed! ✅
282-
{'='*60}{RESET}
276+
{'=' * 60}{RESET}
283277
"""
284278

285279
print(success_message)

0 commit comments

Comments
 (0)