Skip to content

Commit 03e4570

Browse files
authored
Add automatic fallback to polling API for replicate requests (#104)
1 parent 2dac6b8 commit 03e4570

14 files changed

+1164
-811
lines changed

README.md

+174-79
Large diffs are not rendered by default.

Sources/AIProxy/AIProxy.swift

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ let aiproxyLogger = Logger(
1313
public struct AIProxy {
1414

1515
/// The current sdk version
16-
public static let sdkVersion = "0.69.0"
16+
public static let sdkVersion = "0.70.0"
1717

1818
/// - Parameters:
1919
/// - partialKey: Your partial key is displayed in the AIProxy dashboard when you submit your provider's key.
@@ -379,13 +379,15 @@ public struct AIProxy {
379379
/// - Parameters:
380380
/// - unprotectedAPIKey: Your Replicate API key
381381
/// - Returns: An instance of ReplicateService configured and ready to make requests
382+
#if false
382383
public static func replicateDirectService(
383384
unprotectedAPIKey: String
384385
) -> ReplicateService {
385386
return ReplicateDirectService(
386387
unprotectedAPIKey: unprotectedAPIKey
387388
)
388389
}
390+
#endif
389391

390392
/// AIProxy's ElevenLabs service
391393
///

Sources/AIProxy/Replicate/ReplicateDirectService.swift

+149-117
Large diffs are not rendered by default.

Sources/AIProxy/Replicate/ReplicateError.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public enum ReplicateError: LocalizedError {
2121
case .missingModelURL:
2222
return "The replicate model does not contain a URL"
2323
case .reachedRetryLimit:
24-
return "Reached replicate polling retry limit"
24+
return "Reached secondsToWait without the prediction completing"
2525
}
2626
}
2727
}

Sources/AIProxy/Replicate/ReplicateFileUploadRequestBody.swift

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,23 @@
77

88
import Foundation
99

10-
struct ReplicateFileUploadRequestBody: MultipartFormEncodable {
10+
internal struct ReplicateFileUploadRequestBody: MultipartFormEncodable {
1111

12-
let fileData: Data
12+
/// The binary contents of the file
13+
let contents: Data
14+
15+
/// The file mime type
16+
let contentType: String
17+
18+
/// The name of the file. I believe this does not get preserved on replicate's CDN. Can it be removed?
1319
let fileName: String
1420

1521
var formFields: [FormField] {
1622
return [
1723
.fileField(
1824
name: "content",
19-
content: self.fileData,
20-
contentType: "application/zip",
25+
content: self.contents,
26+
contentType: self.contentType,
2127
filename: self.fileName
2228
)
2329
]

Sources/AIProxy/Replicate/ReplicatePredictionRequestBody.swift

+11-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,21 @@
77

88
import Foundation
99

10-
/// The request body for creating a Replicate prediction:
11-
/// https://replicate.com/docs/reference/http#create-a-prediction
10+
/// The request body for creating a Replicate prediction.
11+
///
12+
/// This type is used for both community models and official models.
13+
/// When using with an official model, the `version` property can remain `nil`.
14+
///
15+
/// Community model reference: https://replicate.com/docs/reference/http#predictions.create
16+
/// Official model reference: https://replicate.com/docs/reference/http#models.predictions.create
1217
public struct ReplicatePredictionRequestBody: Encodable {
18+
1319
/// The replicate input schema, for example ReplicateSDXLInputSchema
20+
/// The input schema depends on what model you are running. To see the available inputs, click the "API" tab on the model you are running or get the model version and look at its `openapi_schema` property. For example, `stability-ai/sdxl` takes `prompt` as an input.
1421
public let input: Encodable
1522

16-
/// The version of the model to run
23+
/// You do not need to set this field if you are using an official model.
24+
/// For community models, set it to the ID of the model version that you want to run.
1725
public let version: String?
1826

1927
public init(
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,62 @@
11
//
2-
// ReplicatePredictionResponseBody.swift
2+
// ReplicatePrediction.swift
33
//
44
//
55
// Created by Lou Zell on 8/25/24.
66
//
77

88
import Foundation
99

10-
/// Response body for a Replicate prediction.
11-
/// This format is used for both the "create a predition" and "get a prediction" endpoints:
10+
public typealias ReplicatePredictionResponseBody = ReplicatePrediction
11+
12+
/// Represents the current state of a replicate prediction
13+
///
14+
/// This type is used for both the "create a prediction" and "get a prediction" endpoints:
1215
/// https://replicate.com/docs/reference/http#get-a-prediction
1316
/// https://replicate.com/docs/reference/http#create-a-prediction
14-
public struct ReplicatePredictionResponseBody<T: Decodable>: Decodable {
17+
///
18+
/// And it is used for both the sync and polling API:
19+
/// https://replicate.com/docs/topics/predictions/create-a-prediction#sync-mode
20+
/// https://replicate.com/docs/topics/predictions/create-a-prediction#polling
21+
public struct ReplicatePrediction<Output: Decodable>: Decodable {
1522

1623
/// ISO8601 date stamp of when the prediction completed
1724
public let completedAt: String?
1825

26+
/// ISO8601 date stamp of when the prediction was created
27+
public let createdAt: String?
28+
29+
/// https://replicate.com/docs/topics/predictions/data-retention
30+
public let dataRemoved: Bool?
31+
1932
/// In the case of failure, error will contain the error encountered during the prediction
2033
public let error: String?
2134

35+
/// ID of the prediction
36+
public let id: String?
37+
38+
/// Prediction logs
39+
public let logs: String?
40+
41+
/// How long the prediction took
42+
public let metrics: Metrics?
43+
44+
/// The model being run
45+
public let model: String?
46+
2247
/// The output adheres to Replicate's "output schema" structure.
2348
/// Schemas can be found at the Replicate model's detail page by tapping on `API > Schema > Output Schema`.
2449
/// In the case of SDXL, the output is an array of URLs, which you can see here:
2550
/// https://replicate.com/stability-ai/sdxl/api/schema#output-schema
26-
public let output: T?
51+
public let output: Output?
2752

2853
/// ISO8601 timestamp of start of prediction
2954
public let startedAt: String?
3055

31-
/// One of `starting`, `processing`, `succeeded`, `failed`, `canceled`
56+
/// One of `starting`, `processing`, `succeeded`, `failed`, `canceled`.
57+
///
58+
/// In the `succeeded` case, the `output` property on this type will be an object containing the output of the model.
59+
/// In the `failed` case, the `error` property on this type will contain the error encountered during the prediction.
3260
public let status: Status?
3361

3462
/// URLs to cancel the prediction or get the result from the prediction
@@ -37,9 +65,20 @@ public struct ReplicatePredictionResponseBody<T: Decodable>: Decodable {
3765
/// The version of the model that ran
3866
public let version: String?
3967

68+
/// For compatibility with an older release of the library
69+
public var predictionResultURL: URL? {
70+
return urls?.get
71+
}
72+
4073
private enum CodingKeys: String, CodingKey {
4174
case completedAt = "completed_at"
75+
case createdAt = "created_at"
76+
case dataRemoved = "data_removed"
4277
case error
78+
case id
79+
case logs
80+
case metrics
81+
case model
4382
case output
4483
case startedAt = "started_at"
4584
case status
@@ -48,19 +87,42 @@ public struct ReplicatePredictionResponseBody<T: Decodable>: Decodable {
4887
}
4988
}
5089

51-
extension ReplicatePredictionResponseBody {
90+
extension ReplicatePrediction {
5291
public struct ActionURLs: Decodable {
5392
public let cancel: URL?
5493
public let get: URL?
5594
}
5695
}
5796

58-
extension ReplicatePredictionResponseBody {
97+
extension ReplicatePrediction {
98+
public struct Metrics: Decodable {
99+
public let predictTime: Double
100+
101+
enum CodingKeys: String, CodingKey {
102+
case predictTime = "predict_time"
103+
}
104+
}
105+
}
106+
107+
extension ReplicatePrediction {
59108
public enum Status: String, Decodable {
109+
/// The prediction is starting up. If this status lasts longer than a few seconds, then it's typically because a new worker is being started to run the prediction.
60110
case starting
111+
112+
/// The `predict()` method of the model is currently running.
61113
case processing
114+
115+
/// The prediction completed successfully.
62116
case succeeded
117+
118+
/// The prediction encountered an error during processing.
63119
case failed
120+
121+
/// The prediction was canceled by its creator
64122
case canceled
123+
124+
var isTerminal: Bool {
125+
return [.succeeded, .failed, .canceled].contains(self)
126+
}
65127
}
66128
}

0 commit comments

Comments
 (0)