Skip to content

Commit

Permalink
improve token handling, rename api key name, write tests for Azure Op…
Browse files Browse the repository at this point in the history
…enAI - Fixes #218 and #187
  • Loading branch information
JamesHWade committed Aug 21, 2024
1 parent de3a692 commit b10917d
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 26 deletions.
2 changes: 1 addition & 1 deletion R/api_skeletons.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ new_gptstudio_request_skeleton_google <- function(

new_gptstudio_request_skeleton_azure_openai <- function(
url = "user provided with environmental variables",
api_key = Sys.getenv("AZURE_OPENAI_KEY"),
api_key = Sys.getenv("AZURE_OPENAI_API_KEY"),
model = "gpt-4o-mini",
prompt = "What is a ggplot?",
history = list(
Expand Down
2 changes: 1 addition & 1 deletion R/gptstudio-sitrep.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ gptstudio_sitrep <- function(verbose = TRUE) {
cli::cli_h3("Checking Azure OpenAI API connection")
check_api_connection_azure_openai(
service = "Azure OpenAI",
api_key = Sys.getenv("AZURE_OPENAI_KEY")
api_key = Sys.getenv("AZURE_OPENAI_API_KEY")
)
cli::cli_h3("Checking Perplexity API connection")
check_api_connection_perplexity(
Expand Down
58 changes: 38 additions & 20 deletions R/service-azure_openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @param deployment_name a character string for the deployment name. It will
#' default to the Azure OpenAI deployment name from environment variables if
#' not specified.
#' @param token a character string for the API key. It will default to the Azure
#' @param api_key a character string for the API key. It will default to the Azure
#' OpenAI API key from your environment variables if not specified.
#' @param api_version a character string for the API version. It will default to
#' the Azure OpenAI API version from your environment variables if not
Expand All @@ -27,15 +27,15 @@ create_completion_azure_openai <-
task = Sys.getenv("AZURE_OPENAI_TASK"),
base_url = Sys.getenv("AZURE_OPENAI_ENDPOINT"),
deployment_name = Sys.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
token = Sys.getenv("AZURE_OPENAI_KEY"),
api_key = Sys.getenv("AZURE_OPENAI_API_KEY"),
api_version = Sys.getenv("AZURE_OPENAI_API_VERSION")) {
request_body <- list(list(role = "user", content = prompt))
query_api_azure_openai(
task,
request_body,
base_url,
deployment_name,
token,
api_key,
api_version
)
}
Expand All @@ -44,41 +44,47 @@ request_base_azure_openai <-
function(task = Sys.getenv("AZURE_OPENAI_TASK"),
base_url = Sys.getenv("AZURE_OPENAI_ENDPOINT"),
deployment_name = Sys.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
token = Sys.getenv("AZURE_OPENAI_KEY"),
api_key = Sys.getenv("AZURE_OPENAI_API_KEY"),
api_version = Sys.getenv("AZURE_OPENAI_API_VERSION"),
use_token = Sys.getenv("AZURE_OPENAI_USE_TOKEN")) {
response <-
request(base_url) %>%
req_url_path_append("openai/deployments") %>%
req_url_path_append(deployment_name) %>%
req_url_path_append(task) %>%
req_url_query("api-version" = api_version) %>%
req_headers(
"api-key" = token,
"Content-Type" = "application/json"
)
req_url_query("api-version" = api_version)

if (is_true(as.logical(use_token))) {
token <- retrieve_azure_token()
response %>% req_auth_bearer_token(token = token)
response %>%
req_headers(
"api-key" = api_key,
"Content-Type" = "application/json"
) %>%
req_auth_bearer_token(token = token)
} else {
response
response %>%
req_headers(
"api-key" = api_key,
"Content-Type" = "application/json"
)
}

}

query_api_azure_openai <-
function(task = Sys.getenv("AZURE_OPENAI_TASK"),
request_body,
base_url = Sys.getenv("AZURE_OPENAI_ENDPOINT"),
deployment_name = Sys.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
token = Sys.getenv("AZURE_OPENAI_KEY"),
api_key = Sys.getenv("AZURE_OPENAI_API_KEY"),
api_version = Sys.getenv("AZURE_OPENAI_API_VERSION")) {
response <-
request_base_azure_openai(
task,
base_url,
deployment_name,
token,
api_key,
api_version
) %>%
req_body_json(list(messages = request_body)) %>%
Expand All @@ -104,12 +110,24 @@ query_api_azure_openai <-

retrieve_azure_token <- function() {
rlang::check_installed("AzureRMR")
token <- AzureRMR::create_azure_login(
tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET"),
host = "https://cognitiveservices.azure.com/",
scopes = ".default"
)

token <- tryCatch({
AzureRMR::get_azure_login(
tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
scopes = ".default"
)
}, error = function(e) NULL)

if (is.null(token)) {
token <- AzureRMR::create_azure_login(
tenant = Sys.getenv("AZURE_OPENAI_TENANT_ID"),
app = Sys.getenv("AZURE_OPENAI_CLIENT_ID"),
password = Sys.getenv("AZURE_OPENAI_CLIENT_SECRET"),
host = "https://cognitiveservices.azure.com/",
scopes = ".default"
)
}

invisible(token$token$credentials$access_token)
}
2 changes: 1 addition & 1 deletion man/create_completion_azure_openai.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

217 changes: 217 additions & 0 deletions tests/testthat/test-service-azure_openai.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
test_that("create_completion_azure_openai formats request correctly", {
mock_query_api <- function(task, request_body, base_url, deployment_name,
api_key, api_version) {
list(choices = list(list(message = list(content = "Mocked response"))))
}

withr::with_envvar(
new = c(
AZURE_OPENAI_TASK = "env_task",
AZURE_OPENAI_ENDPOINT = "https://env.openai.azure.com",
AZURE_OPENAI_DEPLOYMENT_NAME = "env_deployment",
AZURE_OPENAI_API_KEY = "env_token",
AZURE_OPENAI_API_VERSION = "env_version"
),
{
local_mocked_bindings(
query_api_azure_openai = mock_query_api
)

result <- create_completion_azure_openai("Test prompt")

expect_type(result, "list")
expect_equal(result$choices[[1]]$message$content, "Mocked response")
}
)
})

test_that("request_base_azure_openai constructs correct request", {
mock_request <- function(url) {
structure(list(url = url, headers = list()), class = "httr2_request")
}

mock_req_url_path_append <- function(req, path) {
req$url <- paste0(req$url, "/", path)
req
}

mock_req_url_query <- function(req, ...) {
req$url <- paste0(req$url, "?api-version=test_version")
req
}

mock_req_headers <- function(req, ...) {
req$headers <- list("api-key" = "test_token",
"Content-Type" = "application/json")
req
}

withr::with_envvar(
new = c(AZURE_OPENAI_USE_TOKEN = "false"),
{
local_mocked_bindings(
request = mock_request,
req_url_path_append = mock_req_url_path_append,
req_url_query = mock_req_url_query,
req_headers = mock_req_headers
)

result <- request_base_azure_openai(
task = "test_task",
base_url = "https://test.openai.azure.com",
deployment_name = "test_deployment",
api_key = "test_token",
api_version = "test_version"
)

expect_equal(result$url, "https://test.openai.azure.com/openai/deployments/test_deployment/test_task?api-version=test_version")

Check warning on line 67 in tests/testthat/test-service-azure_openai.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-service-azure_openai.R,line=67,col=101,[line_length_linter] Lines should not be more than 100 characters. This line is 133 characters.
expect_equal(result$headers, list("api-key" = "test_token",
"Content-Type" = "application/json"))
}
)
})

test_that("query_api_azure_openai handles successful response", {
mock_request_base <- function(...) {
structure(list(url = "https://test.openai.azure.com", headers = list()),
class = "httr2_request")
}

mock_req_perform <- function(req) {
structure(list(status_code = 200, body = '{"result": "success"}'),
class = "httr2_response")
}

mock_resp_body_json <- function(resp) list(result = "success")

local_mocked_bindings(
request_base_azure_openai = mock_request_base,
req_body_json = function(req, body) req,
req_retry = function(req, max_tries) req,
req_error = function(req, is_error) req,
req_perform = mock_req_perform,
resp_is_error = function(resp) FALSE,
resp_body_json = mock_resp_body_json
)

result <- query_api_azure_openai(
task = "test_task",
request_body = list(list(role = "user", content = "Test prompt")),
base_url = "https://test.openai.azure.com",
deployment_name = "test_deployment",
api_key = "test_token",
api_version = "test_version"
)

expect_type(result, "list")
expect_equal(result$result, "success")
})

test_that("query_api_azure_openai handles error response", {
mock_request_base <- function(...) {
structure(list(url = "https://test.openai.azure.com", headers = list()),
class = "httr2_request")
}

mock_req_perform <- function(req) {
structure(list(status_code = 400, body = '{"error": "Bad Request"}'),
class = "httr2_response")
}

local_mocked_bindings(
request_base_azure_openai = mock_request_base,
req_body_json = function(req, body) req,
req_retry = function(req, max_tries) req,
req_error = function(req, is_error) req,
req_perform = mock_req_perform,
resp_is_error = function(resp) TRUE,
resp_status = function(resp) 400,
resp_status_desc = function(resp) "Bad Request"
)

expect_error(
query_api_azure_openai(
task = "test_task",
request_body = list(list(role = "user", content = "Test prompt")),
base_url = "https://test.openai.azure.com",
deployment_name = "test_deployment",
api_key = "test_token",
api_version = "test_version"
),
"Azure OpenAI API request failed. Error 400 - Bad Request"
)
})

# Test token retrieval --------------------------------------------------------

test_that("retrieve_azure_token successfully gets existing token", {
local_mocked_bindings(
get_azure_login = function(...) list(token = list(credentials = list(access_token = "existing_token"))),

Check warning on line 149 in tests/testthat/test-service-azure_openai.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-service-azure_openai.R,line=149,col=101,[line_length_linter] Lines should not be more than 100 characters. This line is 108 characters.
create_azure_login = function(...) stop("Should not be called"),
.package = "AzureRMR"
)

token <- retrieve_azure_token()

expect_equal(token, "existing_token")
})

test_that("retrieve_azure_token creates new token when get_azure_login fails", {
local_mocked_bindings(
get_azure_login = function(...) stop("Error"),
create_azure_login = function(...) list(token = list(credentials = list(access_token = "new_token"))),

Check warning on line 162 in tests/testthat/test-service-azure_openai.R

View workflow job for this annotation

GitHub Actions / lint

file=tests/testthat/test-service-azure_openai.R,line=162,col=101,[line_length_linter] Lines should not be more than 100 characters. This line is 106 characters.
.package = "AzureRMR"
)

token <- retrieve_azure_token()

expect_equal(token, "new_token")
})

test_that("retrieve_azure_token uses correct environment variables", {
mock_get_azure_login <- function(tenant, app, scopes) {
expect_equal(tenant, "test_tenant")
expect_equal(app, "test_client")
expect_equal(scopes, ".default")
stop("Error")
}

mock_create_azure_login <- function(tenant, app, password, host, scopes) {
expect_equal(tenant, "test_tenant")
expect_equal(app, "test_client")
expect_equal(password, "test_secret")
expect_equal(host, "https://cognitiveservices.azure.com/")
expect_equal(scopes, ".default")
list(token = list(credentials = list(access_token = "new_token")))
}

local_mocked_bindings(
get_azure_login = mock_get_azure_login,
create_azure_login = mock_create_azure_login,
.package = "AzureRMR"
)

withr::local_envvar(
AZURE_OPENAI_TENANT_ID = "test_tenant",
AZURE_OPENAI_CLIENT_ID = "test_client",
AZURE_OPENAI_CLIENT_SECRET = "test_secret"
)

expect_no_error(retrieve_azure_token())
})

test_that("retrieve_azure_token checks for AzureRMR installation", {
mock_check_installed <- function(pkg) {
expect_equal(pkg, "AzureRMR")
}

local_mocked_bindings(
check_installed = mock_check_installed,
.package = "rlang"
)

expect_no_error(tryCatch(
retrieve_azure_token(),
error = function(e) {}
))
})
6 changes: 3 additions & 3 deletions vignettes/azure.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ To configure gptstudio to work using Azure OpenAI service, you need to provide s
- AZURE_OPENAI_TASK
- AZURE_OPENAI_ENDPOINT
- AZURE_OPENAI_DEPLOYMENT_NAME
- AZURE_OPENAI_KEY
- AZURE_OPENAI_API_KEY
- AZURE_OPENAI_API_VERSION
- AZURE_OPENAI_USE_TOKEN

Expand All @@ -31,15 +31,15 @@ Here's how you can add these details to your .Renviron file:
2. Add environment variable details: Add a new line for each variable you need to set in the following format: VARIABLE_NAME="YOUR_VALUE". Replace VARIABLE_NAME with the name of the environment variable and YOUR_VALUE with the actual value that you want to set. For example, to set the API key you would have a line like this:

```bash
AZURE_OPENAI_KEY="your_actual_key_goes_here"
AZURE_OPENAI_API_KEY="your_actual_key_goes_here"
```
You need to do this for each of the environment variables expected by the function. Your .Renviron file should look something like this:

```bash
AZURE_OPENAI_TASK="your_task_code"
AZURE_OPENAI_ENDPOINT="your_endpoint_url"
AZURE_OPENAI_DEPLOYMENT_NAME="your_deployment_name"
AZURE_OPENAI_KEY="your_api_key"
AZURE_OPENAI_API_KEY="your_api_key"
AZURE_OPENAI_API_VERSION="your_api_version"
AZURE_OPENAI_USE_TOKEN=FALSE
```
Expand Down

0 comments on commit b10917d

Please sign in to comment.