Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ package com.owncloud.android.lib.resources.assistant.v2

import com.owncloud.android.AbstractIT
import com.owncloud.android.lib.resources.assistant.v2.model.Shape
import com.owncloud.android.lib.resources.assistant.v2.model.TaskInputShape
import com.owncloud.android.lib.resources.assistant.v2.model.TaskOutputShape
import com.owncloud.android.lib.resources.assistant.v2.model.TaskTypeData
import com.owncloud.android.lib.resources.status.NextcloudVersion
import junit.framework.TestCase.assertEquals
Expand All @@ -31,17 +29,17 @@ class AssistantV2Tests : AbstractIT() {
"Free text to text prompt",
"Runs an arbitrary prompt through a language model that returns a reply",
inputShape =
TaskInputShape(
input =
mapOf(
"input" to
Shape(
"Prompt",
"Describe a task that you want the assistant to do or ask a question",
"Text"
)
),
outputShape =
TaskOutputShape(
output =
mapOf(
"output" to
Shape(
"Generated reply",
"The generated text from the assistant",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data class TaskTypes(

data class TaskType(
val id: String?,
val name: String?,
val name: String,
val description: String?
)

Expand All @@ -26,7 +26,7 @@ fun TaskTypes.toV2(): List<TaskTypeData> =
id = taskType.id,
name = taskType.name,
description = taskType.description,
inputShape = null,
outputShape = null
inputShape = emptyMap(),
outputShape = emptyMap()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,32 @@ import com.owncloud.android.lib.resources.assistant.v2.model.TaskTypeData
import com.owncloud.android.lib.resources.assistant.v2.model.TaskTypes
import org.apache.commons.httpclient.HttpStatus

/**
* Returns a list of supported task types.
*
* Example JSON representation of one task type:
* ```
* {
* "id": "core:text2text",
* "name": "Free text to text prompt",
* "description": "Runs an arbitrary prompt through a language model that returns a reply",
* "inputShape": {
* "input": {
* "name": "Prompt",
* "description": "Describe a task that you want the assistant to do or ask a question",
* "type": "Text"
* }
* },
* "outputShape": {
* "output": {
* "name": "Generated reply",
* "description": "The generated text from the assistant",
* "type": "Text"
* }
* }
* }
* ```
*/
class GetTaskTypesRemoteOperationV2 : OCSRemoteOperation<List<TaskTypeData>>() {
private val supportedTaskType = "Text"

Expand All @@ -36,21 +62,18 @@ class GetTaskTypesRemoteOperationV2 : OCSRemoteOperation<List<TaskTypeData>>() {
getServerResponse(
getMethod,
object : TypeToken<ServerResponse<TaskTypes>>() {}
)?.ocs?.data?.types
)

val taskTypeList =
response?.map { (key, value) ->
value.copy(id = value.id ?: key)
}

val supportedTaskTypeList =
taskTypeList?.filter { taskType ->
taskType.inputShape?.input?.type == supportedTaskType &&
taskType.outputShape?.output?.type == supportedTaskType
}
response
?.ocs
?.data
?.types
?.map { (key, value) -> value.copy(id = value.id ?: key) }
?.filter { taskType -> isSingleTextInputOutput(taskType) }

result = RemoteOperationResult(true, getMethod)
result.setResultData(supportedTaskTypeList)
result.resultData = taskTypeList
} else {
result = RemoteOperationResult(false, getMethod)
}
Expand All @@ -67,6 +90,21 @@ class GetTaskTypesRemoteOperationV2 : OCSRemoteOperation<List<TaskTypeData>>() {
return result
}

private fun isSingleTextInputOutput(taskType: TaskTypeData): Boolean {
val inputShape = taskType.inputShape
val outputShape = taskType.outputShape

val hasOneTextInput =
inputShape.size == 1 &&
inputShape.values.first().type == supportedTaskType

val hasOneTextOutput =
outputShape.size == 1 &&
outputShape.values.first().type == supportedTaskType

return hasOneTextInput && hasOneTextOutput
}

companion object {
private val TAG = GetTaskTypesRemoteOperationV2::class.java.simpleName
private const val DIRECT_ENDPOINT = "/ocs/v2.php/taskprocessing/tasktypes"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,10 @@ data class TaskTypes(

data class TaskTypeData(
val id: String?,
val name: String?,
val name: String,
val description: String?,
val inputShape: TaskInputShape?,
val outputShape: TaskOutputShape?
)

data class TaskInputShape(
val input: Shape?
)

data class TaskOutputShape(
val output: Shape?
val inputShape: Map<String, Shape>,
val outputShape: Map<String, Shape>
)

data class Shape(
Expand Down
Loading