Skip to content

Commit

Permalink
feat: Add a cloneable protocol for Reasoning Engine.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638428830
  • Loading branch information
Yeesian Ng authored and Copybara-Service committed May 29, 2024
1 parent 3b83ba9 commit 8960a80
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ def test_set_up(self, vertexai_init_mock):
agent.set_up()
assert agent._runnable is not None

def test_clone(self, vertexai_init_mock):
agent = reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
prompt=self.prompt,
output_parser=self.output_parser,
)
agent.set_up()
assert agent._runnable is not None
agent_clone = agent.clone()
assert agent._runnable is not None
assert agent_clone._runnable is None
agent_clone.set_up()
assert agent_clone._runnable is not None

def test_query(self, langchain_dump_mock):
agent = reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/vertex_langchain/test_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def query(self, unused_arbitrary_string_name: str) -> str:
"""Runs the engine."""
return unused_arbitrary_string_name.upper()

def clone(self):
return self


_TEST_RETRY = base._DEFAULT_RETRY
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())
Expand Down
21 changes: 18 additions & 3 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

Expand Down Expand Up @@ -390,6 +387,24 @@ def set_up(self):
runnable_kwargs=self._runnable_kwargs,
)

def clone(self) -> "LangchainAgent":
"""Returns a clone of the LangchainAgent."""
import copy

return LangchainAgent(
model=self._model_name,
prompt=copy.deepcopy(self._prompt),
tools=copy.deepcopy(self._tools),
output_parser=copy.deepcopy(self._output_parser),
chat_history=copy.deepcopy(self._chat_history),
model_kwargs=copy.deepcopy(self._model_kwargs),
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
agent_executor_kwargs=copy.deepcopy(self._agent_executor_kwargs),
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
model_builder=self._model_builder,
runnable_builder=self._runnable_builder,
)

def query(
self,
*,
Expand Down
12 changes: 12 additions & 0 deletions vertexai/reasoning_engines/_reasoning_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def query(self, **kwargs):
"""Runs the Reasoning Engine to serve the user query."""


@typing.runtime_checkable
class Cloneable(Protocol):
"""Protocol for Reasoning Engine applications that can be cloned."""

@abc.abstractmethod
def clone(self):
"""Return a clone of the object."""


class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, Queryable):
"""Represents a Vertex AI Reasoning Engine resource."""

Expand Down Expand Up @@ -214,6 +223,9 @@ def create(
"Invalid query signature. This might be due to a missing "
"`self` argument in the reasoning_engine.query method."
) from err
if isinstance(reasoning_engine, Cloneable):
# Avoid undeployable ReasoningChain states.
reasoning_engine = reasoning_engine.clone()
if isinstance(requirements, str):
try:
_LOGGER.info(f"Reading requirements from {requirements=}")
Expand Down

0 comments on commit 8960a80

Please sign in to comment.