Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 56 additions & 6 deletions R/chat-structured.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,63 @@
extract_data <- function(turn, type, convert = TRUE, needs_wrapper = FALSE) {
extract_data <- function(
turn,
type,
convert = TRUE,
needs_wrapper = FALSE,
prompt_index = NULL
) {
is_json <- map_lgl(turn@contents, S7_inherits, ContentJson)
n <- sum(is_json)
if (n != 1) {
cli::cli_abort("Data extraction failed: {n} data results recieved.")
if (n == 0) {
cli::cli_abort("Data extraction failed: 0 data results received.")
} else if (n == 1) {
# Normal case - exactly 1 JSON object
json <- turn@contents[[which(is_json)]]
out <- json@value
} else if (n == 2) {
# Check if the two JSON objects are identical (duplicate case)
json_indices <- which(is_json)
json1 <- turn@contents[[json_indices[1]]]
json2 <- turn@contents[[json_indices[2]]]
val1 <- json1@value
val2 <- json2@value
if (identical(val1, val2)) {
# Duplicate case - use the first one
index_msg <- if (!is.null(prompt_index)) {
paste0(" (prompt ", prompt_index, ")")
} else {
""
}
warning(
"Found duplicate JSON responses, using the first one",
index_msg,
".",
call. = FALSE,
immediate. = TRUE
)
out <- val1
} else {
# Different JSON objects - use the last one (likely the final response)
index_msg <- if (!is.null(prompt_index)) {
paste0(" (prompt ", prompt_index, ")")
} else {
""
}
warning(
"Found multiple different JSON responses, using the last one",
index_msg,
".",
call. = FALSE,
immediate. = TRUE
)
out <- val2
}
} else {
# More than 2 JSON objects - this is unexpected
cli::cli_abort(
"Data extraction failed: {n} data results received. Expected 1 or 2."
)
}

json <- turn@contents[[which(is_json)]]
out <- json@value

if (needs_wrapper) {
out <- out$wrapper
type <- type@properties[[1]]
Expand Down
5 changes: 3 additions & 2 deletions R/parallel-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,13 @@ multi_convert <- function(
) {
needs_wrapper <- type_needs_wrapper(type, provider)

rows <- map(turns, \(turn) {
rows <- imap(turns, \(turn, idx) {
extract_data(
turn = turn,
type = wrap_type_if_needed(type, needs_wrapper),
convert = FALSE,
needs_wrapper = needs_wrapper
needs_wrapper = needs_wrapper,
prompt_index = idx
)
})

Expand Down
8 changes: 0 additions & 8 deletions tests/testthat/_snaps/chat-structured.md

This file was deleted.

92 changes: 91 additions & 1 deletion tests/testthat/test-chat-structured.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

test_that("useful error if no ContentJson", {
turn <- Turn("assistant", list(ContentText("Hello")))
expect_snapshot(extract_data(turn), error = TRUE)
expect_error(
extract_data(turn),
"Data extraction failed: 0 data results received.",
fixed = TRUE
)
})

test_that("can extract data from ContentJson", {
Expand All @@ -24,6 +28,92 @@ test_that("can extract data when wrapper is used", {
expect_equal(extract_data(turn, type, needs_wrapper = TRUE), list(x = 1))
})

test_that("handles duplicate identical JSON responses", {
# This test covers the Bedrock duplicate JSON issue
json_data <- list(name = "John", age = 25)
turn <- Turn(
"assistant",
list(
ContentJson(json_data),
ContentJson(json_data) # Identical duplicate
)
)
type <- type_object(name = type_string(), age = type_integer())
# Should warn about duplicates and use the first one
expect_warning(
result <- extract_data(turn, type),
"Found duplicate JSON responses, using the first one"
)
expect_equal(result, list(name = "John", age = 25))
})

test_that("handles duplicate identical JSON responses with prompt index", {
# Test that prompt index is included in warning message
json_data <- list(score = 42)
turn <- Turn(
"assistant",
list(
ContentJson(json_data),
ContentJson(json_data)
)
)
type <- type_object(score = type_integer())
expect_warning(
extract_data(turn, type, prompt_index = 3),
"Found duplicate JSON responses, using the first one \\(prompt 3\\)"
)
})

test_that("handles different JSON responses", {
# This test covers the case where two different JSON objects are returned
turn <- Turn(
"assistant",
list(
ContentJson(list(name = "John", age = 25)),
ContentJson(list(name = "Jane", age = 30)) # Different data
)
)
type <- type_object(name = type_string(), age = type_integer())
# Should warn about multiple responses and use the last one
expect_warning(
result <- extract_data(turn, type),
"Found multiple different JSON responses, using the last one"
)
expect_equal(result, list(name = "Jane", age = 30))
})

test_that("handles different JSON responses with prompt index", {
turn <- Turn(
"assistant",
list(
ContentJson(list(value = 1)),
ContentJson(list(value = 2))
)
)
type <- type_object(value = type_integer())
expect_warning(
extract_data(turn, type, prompt_index = 5),
"Found multiple different JSON responses, using the last one \\(prompt 5\\)"
)
})

test_that("errors on more than 2 JSON responses", {
# Should error if there are more than 2 JSON objects
turn <- Turn(
"assistant",
list(
ContentJson(list(x = 1)),
ContentJson(list(x = 2)),
ContentJson(list(x = 3))
)
)
type <- type_object(x = type_integer())
expect_error(
extract_data(turn, type),
"Data extraction failed: 3 data results received. Expected 1 or 2."
)
})

# Type coercion ---------------------------------------------------------------

test_that("optional base types (scalars) stay as NULL", {
Expand Down
Loading