Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: mypy coverage enhancement #6250

Merged
merged 16 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ ignore_errors = true
[mypy-semantic_kernel.planners.*]
ignore_errors = true

[mypy-semantic_kernel.prompt_template.*]
ignore_errors = true

[mypy-semantic_kernel.reliability.*]
ignore_errors = true

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Jinja2PromptTemplate(PromptTemplateBase):
Jinja2TemplateSyntaxError: If there is a syntax error in the Jinja2 template.
"""

_env: ImmutableSandboxedEnvironment = PrivateAttr()
_env: ImmutableSandboxedEnvironment | None = PrivateAttr()

@field_validator("prompt_template_config")
@classmethod
Expand Down Expand Up @@ -95,6 +95,7 @@ async def render(self, kernel: "Kernel", arguments: Optional["KernelArguments"]
}
)
try:
assert self.prompt_template_config.template is not None
template = self._env.from_string(self.prompt_template_config.template, globals=helpers)
return template.render(**arguments)
except TemplateError as exc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from semantic_kernel.prompt_template.input_variable import InputVariable
from semantic_kernel.prompt_template.prompt_template_base import PromptTemplateBase
from semantic_kernel.template_engine.blocks.block import Block
from semantic_kernel.template_engine.blocks.block_types import BlockTypes
from semantic_kernel.template_engine.blocks.code_block import CodeBlock
from semantic_kernel.template_engine.blocks.named_arg_block import NamedArgBlock
from semantic_kernel.template_engine.blocks.var_block import VarBlock
from semantic_kernel.template_engine.template_tokenizer import TemplateTokenizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -39,25 +41,27 @@ def model_post_init(self, __context: Any) -> None:

# Enumerate every block in the template, adding any variables that are referenced.
for block in self._blocks:
if block.type == BlockTypes.VARIABLE:
if isinstance(block, VarBlock):
# Add all variables from variable blocks, e.g. "{{$a}}".
self._add_if_missing(block.name, seen)
continue
if block.type == BlockTypes.CODE:
if isinstance(block, CodeBlock):
for sub_block in block.tokens:
if sub_block.type == BlockTypes.VARIABLE:
if isinstance(sub_block, VarBlock):
# Add all variables from code blocks, e.g. "{{p.bar $b}}".
self._add_if_missing(sub_block.name, seen)
continue
if sub_block.type == BlockTypes.NAMED_ARG and sub_block.variable:
if isinstance(sub_block, NamedArgBlock) and sub_block.variable:
# Add all variables from named arguments, e.g. "{{p.bar b = $b}}".
# represents a named argument for a function call.
# For example, in the template {{ MyPlugin.MyFunction var1=$boo }}, var1=$boo
# is a named arg block.
self._add_if_missing(sub_block.variable.name, seen)

def _add_if_missing(self, variable_name: str, seen: Optional[set] = None):
def _add_if_missing(self, variable_name: str | None, seen: Optional[set] = None):
# Convert variable_name to lower case to handle case-insensitivity
if not seen:
seen = set()
if variable_name and variable_name.lower() not in seen:
seen.add(variable_name.lower())
self.prompt_template_config.input_variables.append(InputVariable(name=variable_name))
Expand Down Expand Up @@ -141,7 +145,7 @@ def render_variables(

rendered_blocks: List[Block] = []
for block in blocks:
if block.type == BlockTypes.VARIABLE:
if isinstance(block, VarBlock):
rendered_blocks.append(TextBlock.from_text(block.render(kernel, arguments)))
continue
rendered_blocks.append(block)
Expand All @@ -164,7 +168,7 @@ async def render_code(self, blocks: List[Block], kernel: "Kernel", arguments: "K

rendered_blocks: List[Block] = []
for block in blocks:
if block.type == BlockTypes.CODE:
if isinstance(block, CodeBlock):
rendered_blocks.append(TextBlock.from_text(await block.render_code(kernel, arguments)))
continue
rendered_blocks.append(block)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import asyncio
import logging
from typing import TYPE_CHECKING, Callable, Literal
from typing import TYPE_CHECKING, Callable

import nest_asyncio

from semantic_kernel.prompt_template.const import HANDLEBARS_TEMPLATE_FORMAT_NAME
from semantic_kernel.prompt_template.const import (
HANDLEBARS_TEMPLATE_FORMAT_NAME,
JINJA2_TEMPLATE_FORMAT_NAME,
TEMPLATE_FORMAT_TYPES,
)

if TYPE_CHECKING:
from semantic_kernel.functions.kernel_arguments import KernelArguments
Expand All @@ -21,9 +25,12 @@ def create_template_helper_from_function(
function: "KernelFunction",
kernel: "Kernel",
base_arguments: "KernelArguments",
template_format: Literal["handlebars", "jinja2"],
template_format: TEMPLATE_FORMAT_TYPES,
) -> Callable:
"""Create a helper function for both the Handlebars and Jinja2 templating engines from a kernel function."""
if template_format not in [JINJA2_TEMPLATE_FORMAT_NAME, HANDLEBARS_TEMPLATE_FORMAT_NAME]:
raise ValueError(f"Invalid template format: {template_format}")

if not getattr(asyncio, "_nest_patched", False):
nest_asyncio.apply()

Expand Down
Loading