Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/ml/azure-ai-ml",
"Tag": "python/ml/azure-ai-ml_2ff0425fbf"
"Tag": "python/ml/azure-ai-ml_baea630318"
}
4 changes: 4 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ class _FlowComponentOverridableSchema(metaclass=PatchedSchemaMeta):
class FlowSchema(YamlFileSchema, _ComponentMetadataSchema, _FlowComponentOverridableSchema):
"""Schema for flow.dag.yaml file."""

environment_variables = fields.Dict(
fields.Str(),
fields.Str(),
)
additional_includes = fields.List(LocalPathField())


Expand Down
11 changes: 8 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def __init__(
self._column_mapping = column_mapping or {}
self._variant = variant
self._connections = connections or {}
self._environment_variables = environment_variables or {}

self._inputs = FlowComponentInputDict()
self._outputs = FlowComponentOutputDict()
Expand All @@ -266,8 +265,14 @@ def __init__(
# file existence has been checked in _get_flow_definition
# we don't need to rebase additional_includes as we have updated base_path
with open(Path(self.base_path, self._flow), "r", encoding="utf-8") as f:
flow_content = f.read()
additional_includes = yaml.safe_load(flow_content).get("additional_includes", None)
flow_content = yaml.safe_load(f.read())
additional_includes = flow_content.get("additional_includes", None)
# environment variables in run.yaml have higher priority than those in flow.dag.yaml
self._environment_variables = flow_content.get("environment_variables", {})
self._environment_variables.update(environment_variables or {})
else:
self._environment_variables = environment_variables or {}

self._additional_includes = additional_includes or []

# unlike other Component, code is a private property in FlowComponent and
Expand Down
4 changes: 2 additions & 2 deletions sdk/ml/azure-ai-ml/scripts/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def run_simple(
stdout = None
json_log_file_path = log_file_path.with_suffix(log_file_path.suffix + ".log")
else:
stdout = open(log_file_path.with_suffix(log_file_path.suffix + ".txt"), "wb", encoding="utf-8")
stdout = open(log_file_path.with_suffix(log_file_path.suffix + ".txt"), "w", encoding="utf-8")
json_log_file_path = None

with update_dot_env_file(
Expand All @@ -134,7 +134,7 @@ def run_simple(
+ tmp_extra_params,
cwd=working_dir,
stdout=stdout,
check=True,
check=False,
)
if log_in_json:
# append temp json file to the final log file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ def test_component_load_from_dag(self):
"is_deterministic": True,
"code": "/subscriptions/xxx/resourceGroups/xxx/workspaces/xxx/codes/xxx/versions/1",
"flow_file_name": "flow.dag.yaml",
"environment_variables": {
"AZURE_OPENAI_API_BASE": "${my_connection.api_base}",
"AZURE_OPENAI_API_KEY": "${my_connection.api_key}",
"AZURE_OPENAI_API_TYPE": "azure",
"AZURE_OPENAI_API_VERSION": "2023-03-15-preview",
},
},
"description": "test load component from flow",
"is_anonymous": False,
"is_archived": False,
"properties": {
"client_component_hash": "b503491e-be3a-de50-0413-30c8c8abb43a",
"client_component_hash": "19278001-3d52-0e43-dc43-4082128d8243",
},
"tags": {},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2059,7 +2059,8 @@ def test_pipeline_job_with_flow(
# constructed based on response of code pending upload requests, and those requests have been normalized
# in playback mode and mixed up.
pipeline_job = load_job(source=test_path, params_override=[{"name": randstr("name")}])
assert client.jobs.validate(pipeline_job).passed
validation_result = client.jobs.validate(pipeline_job)
assert validation_result.passed, validation_result

created_pipeline_job = assert_job_cancel(pipeline_job, client)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ nodes:
max_tokens: "120"
environment:
python_requirements_txt: requirements.txt
environment_variables:
AZURE_OPENAI_API_TYPE: azure
AZURE_OPENAI_API_VERSION: 2023-03-15-preview
AZURE_OPENAI_API_KEY: ${my_connection.api_key}
AZURE_OPENAI_API_BASE: ${my_connection.api_base}
38 changes: 25 additions & 13 deletions sdk/ml/azure-ai-ml/tests/test_configs/flows/basic/hello.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import openai
from dotenv import load_dotenv
from openai.version import VERSION as OPENAI_VERSION
from promptflow import tool

# The inputs section will change based on the arguments of the tool function, after you save the code
Expand All @@ -13,6 +13,27 @@ def to_bool(value) -> bool:
return str(value).lower() == "true"


def get_client():
if OPENAI_VERSION.startswith("0."):
raise Exception(
"Please upgrade your OpenAI package to version >= 1.0.0 or using the command: pip install --upgrade openai."
)
api_key = os.environ["AZURE_OPENAI_API_KEY"]
conn = dict(
api_key=os.environ["AZURE_OPENAI_API_KEY"],
)
if api_key.startswith("sk-"):
from openai import OpenAI as Client
else:
from openai import AzureOpenAI as Client

conn.update(
azure_endpoint=os.environ["AZURE_OPENAI_API_BASE"],
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-07-01-preview"),
)
return Client(**conn)


@tool
def my_python_tool(
prompt: str,
Expand All @@ -38,21 +59,14 @@ def my_python_tool(
load_dotenv()

if "AZURE_OPENAI_API_KEY" not in os.environ:
raise Exception("Please sepecify environment variables: AZURE_OPENAI_API_KEY")

conn = dict(
api_key=os.environ["AZURE_OPENAI_API_KEY"],
api_base=os.environ["AZURE_OPENAI_API_BASE"],
api_type=os.environ.get("AZURE_OPENAI_API_TYPE", "azure"),
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-03-15-preview"),
)
raise Exception("Please specify environment variables: AZURE_OPENAI_API_KEY")

# TODO: remove below type conversion after client can pass json rather than string.
echo = to_bool(echo)

response = openai.Completion.create(
response = get_client().completions.create(
prompt=prompt,
engine=deployment_name,
model=deployment_name,
# empty string suffix should be treated as None.
suffix=suffix if suffix else None,
max_tokens=int(max_tokens),
Expand All @@ -69,8 +83,6 @@ def my_python_tool(
# Logit bias must be a dict if we passed it to openai api.
logit_bias=logit_bias if logit_bias else {},
user=user,
request_timeout=30,
**conn,
)

# get first element because prompt is single.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ environment_variables:
# environment variables from connection
AZURE_OPENAI_API_KEY: ${azure_open_ai_connection.api_key}
AZURE_OPENAI_API_BASE: ${azure_open_ai_connection.api_base}
AZURE_OPENAI_API_TYPE: azure
AZURE_OPENAI_API_VERSION: 2023-03-15-preview
connections:
llm:
connection: azure_open_ai_connection
Expand Down