diff --git a/core/testcontainers/compose/compose.py b/core/testcontainers/compose/compose.py index 61961ce0..bf628486 100644 --- a/core/testcontainers/compose/compose.py +++ b/core/testcontainers/compose/compose.py @@ -11,6 +11,8 @@ from types import TracebackType from typing import Any, Callable, Literal, Optional, TypeVar, Union, cast +from docker import DockerClient +from docker.models.containers import Container from testcontainers.core.exceptions import ContainerIsNotRunning, NoSuchPortExposed from testcontainers.core.waiting_utils import WaitStrategy @@ -137,9 +139,9 @@ def get_logs(self) -> tuple[bytes, bytes]: stdout, stderr = self._docker_compose.get_logs(self.Service) return stdout.encode(), stderr.encode() - def get_wrapped_container(self) -> "ComposeContainer": + def get_wrapped_container(self) -> Container: """Get the underlying container object for compatibility.""" - return self + return self._docker_compose._get_docker_client().containers.get(self.ID) def reload(self) -> None: """Reload container information for compatibility with wait strategies.""" @@ -214,7 +216,9 @@ class DockerCompose: services: Optional[list[str]] = None docker_command_path: Optional[str] = None profiles: Optional[list[str]] = None + docker_client_kw: Optional[dict[str, Any]] = None _wait_strategies: Optional[dict[str, Any]] = field(default=None, init=False, repr=False) + _docker_client: Optional[DockerClient] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: if isinstance(self.compose_file_name, str): @@ -259,6 +263,13 @@ def compose_command_property(self) -> list[str]: docker_compose_cmd += ["--env-file", env_file] return docker_compose_cmd + def _get_docker_client(self) -> DockerClient: + dc = self._docker_client + if dc is None: + dc = DockerClient(**(self.docker_client_kw or {})) + self._docker_client = dc + return dc + def waiting_for(self, strategies: dict[str, WaitStrategy]) -> "DockerCompose": """ Set wait strategies for specific services. diff --git a/core/testcontainers/core/waiting_utils.py b/core/testcontainers/core/waiting_utils.py index 9942854a..55f6efa4 100644 --- a/core/testcontainers/core/waiting_utils.py +++ b/core/testcontainers/core/waiting_utils.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Optional, Protocol, TypeVar, Union, cast import wrapt +from docker.models.containers import Container from typing_extensions import Self from testcontainers.core.config import testcontainers_config @@ -52,7 +53,7 @@ def get_exposed_port(self, port: int) -> int: """Get the exposed port mapping for the given internal port.""" ... - def get_wrapped_container(self) -> Any: + def get_wrapped_container(self) -> Container: """Get the underlying container object.""" ...