Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/utils/monitor_swanlab_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Simple smoke test for SwanlabMonitor.

Run:
python cradle.py

What it does:
- Ensures SWANLAB_API_KEY is read from environment (sets a dummy if missing).
- Initializes SwanlabMonitor with minimal args.
- Logs a small metric and closes the run.

Notes:
- If `swanlab` is not installed, this script will print a helpful message and exit.
- The dummy API key is used only to exercise the login path; real authentication isn't required for this smoke test.
"""

import os
import sys


def main() -> int:
# Defer imports to keep error handling simple
try:
from trinity.utils.monitor import SwanlabMonitor
except Exception as e:
print("Failed to import SwanlabMonitor:", e)
return 1

# Ensure an env-based key path is exercised (uses dummy if not provided)
env_keys = ["SWANLAB_API_KEY", "SWANLAB_APIKEY", "SWANLAB_KEY", "SWANLAB_TOKEN"]
if not any(os.getenv(k) for k in env_keys):
os.environ["SWANLAB_API_KEY"] = "dummy_key_for_smoke_test"
print("Set SWANLAB_API_KEY to a dummy value to test env-based login path.")

# Try creating the monitor; if swanlab isn't installed, __init__ will assert
try:
mon = SwanlabMonitor(
project="trinity-smoke",
group="cradle",
name="swanlab-env",
role="tester",
config=None,
)
except AssertionError as e:
print("SwanLab not available or not installed:", e)
print("Install swanlab to run this smoke test: pip install swanlab")
return 0
except Exception as e:
print("Unexpected error constructing SwanlabMonitor:", e)
return 1

# Log a minimal metric to verify basic flow
try:
mon.log({"smoke/metric": 1.0}, step=1)
print("Logged a test metric via SwanlabMonitor.")
except Exception as e:
print("Error during logging:", e)
try:
mon.close()
except Exception:
pass
return 1

# Close cleanly
try:
mon.close()
print("SwanlabMonitor closed successfully.")
except Exception as e:
print("Error closing monitor:", e)
return 1

print("Smoke test completed.")
return 0


if __name__ == "__main__":
sys.exit(main())
129 changes: 129 additions & 0 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
import mlflow
except ImportError:
mlflow = None

try:
import swanlab
except ImportError:
swanlab = None

from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
Expand Down Expand Up @@ -224,3 +230,126 @@ def default_args(cls) -> Dict:
"username": None,
"password": None,
}


@MONITOR.register_module("swanlab")
class SwanlabMonitor(Monitor):
"""Monitor with SwanLab.
This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.
"""

def __init__(
self, project: str, group: str, name: str, role: str, config: Config = None
) -> None:
assert (
swanlab is not None
), "swanlab is not installed. Please install it to use SwanlabMonitor."

monitor_args = (
(config.monitor.monitor_args or {})
if config and getattr(config, "monitor", None)
else {}
)

# read api key from environment variable or monitor_args
api_key = monitor_args.get("api_key") or os.environ.get("SWANLAB_API_KEY")
if api_key:
try:
swanlab.login(api_key=api_key, save=True)
except Exception as e:
# Best-effort login; continue to init which may still work if already logged in
get_logger(__name__).warning(
f"Swanlab login failed, but continuing initialization: {e}"
)

# Compose tags (ensure list and include role/group markers)
tags = monitor_args.get("tags") or []
if isinstance(tags, tuple):
tags = list(tags)
if role and role not in tags:
tags.append(role)
if group and group not in tags:
tags.append(group)

# Determine experiment name
exp_name = monitor_args.get("experiment_name") or f"{name}_{role}"

# Prepare init kwargs, passing only non-None values to respect library defaults
init_kwargs = {
"project": project,
"experiment_name": exp_name,
"description": monitor_args.get("description"),
"tags": tags or None,
"logdir": monitor_args.get("logdir"),
"mode": monitor_args.get("mode") or "cloud",
"settings": monitor_args.get("settings"),
"id": monitor_args.get("id"),
"resume": monitor_args.get("resume"),
"reinit": monitor_args.get("reinit"),
}
# Strip None values to avoid overriding swanlab defaults
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}

# Convert config to a plain dict for SwanLab config logging
cfg_dict = None
if config is not None:
if hasattr(config, "flatten"):
try:
cfg_dict = config.flatten()
except Exception:
# Fallback: try to cast to dict if possible
try:
cfg_dict = dict(config)
except Exception:
cfg_dict = None
else:
try:
cfg_dict = dict(config)
except Exception:
cfg_dict = None
Comment on lines +295 to +309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for converting the config object to a dictionary is quite complex with nested try-except blocks that silently ignore failures. This can be simplified and made more informative by logging a warning if the config conversion fails.

        if config is not None:
            try:
                if hasattr(config, "flatten"):
                    cfg_dict = config.flatten()
                else:
                    cfg_dict = dict(config)
            except Exception as e:
                get_logger(__name__).warning(f"Could not convert config to a dictionary for SwanLab: {e}")

if cfg_dict is not None:
init_kwargs["config"] = cfg_dict

self.logger = swanlab.init(**init_kwargs)
self.console_logger = get_logger(__name__, in_ray_actor=True)

def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
# Convert pandas DataFrame to SwanLab ECharts Table
headers: List[str] = list(experiences_table.columns)
# Ensure rows are native Python types
rows: List[List[object]] = experiences_table.astype(object).values.tolist()
try:
tbl = swanlab.echarts.Table()
tbl.add(headers, rows)
swanlab.log({table_name: tbl}, step=step)
except Exception:
# Fallback: log as CSV string if echarts table is unavailable
csv_str = experiences_table.to_csv(index=False)
swanlab.log({table_name: csv_str}, step=step)
Comment on lines +325 to +328
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fallback to logging a table as a CSV string is a good defensive measure. However, it happens silently. It would be beneficial to log a warning when this fallback is triggered, so the user is aware that the richer ECharts table could not be created and can investigate the root cause.

        except Exception as e:
            self.console_logger.warning(
                f"Failed to log table '{table_name}' as SwanLab ECharts Table, falling back to CSV. Error: {e}"
            )
            # Fallback: log as CSV string if echarts table is unavailable
            csv_str = experiences_table.to_csv(index=False)
            swanlab.log({table_name: csv_str}, step=step)


def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
# SwanLab doesn't use commit flag; keep signature for compatibility
swanlab.log(data, step=step)
self.console_logger.info(f"Step {step}: {data}")

def close(self) -> None:
try:
# Prefer run.finish() if available
if hasattr(self, "logger") and hasattr(self.logger, "finish"):
self.logger.finish()
elif swanlab:
# Fallback to global finish
swanlab.finish()
except Exception as e:
logger = getattr(self, "console_logger", get_logger(__name__))
logger.warning(f"Error closing Swanlab monitor: {e}")

@classmethod
def default_args(cls) -> Dict:
"""Return default arguments for the monitor."""
return {
"api_key": None,
"mode": "cloud",
"logdir": None,
}