Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package com.azure.ai.openai;

import com.azure.ai.openai.implementation.CompletionsUtils;
import com.azure.ai.openai.implementation.MultipartDataHelper;
import com.azure.ai.openai.implementation.MultipartField;
import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl;
import com.azure.ai.openai.implementation.OpenAIClientImpl;
import com.azure.ai.openai.implementation.OpenAIServerSentEvents;
Expand Down Expand Up @@ -824,6 +826,14 @@ public Response<BinaryData> getAudioTranslationWithResponse(
deploymentOrModelName, audioTranslationOptions, requestOptions);
}


/**
 * Sends an already-serialized multipart audio translation request.
 *
 * <p>This overload forwards its arguments to the service client, converting the numeric content length to
 * its decimal string form first.
 *
 * @param deploymentOrModelName deployment or model name forwarded to the service client.
 * @param audioTranslationOptions the serialized request body.
 * @param requestOptions per-request options forwarded unchanged.
 * @param boundary the multipart boundary that was used when serializing the body.
 * @param contentLength length of the serialized body; forwarded as a decimal string.
 * @return the raw response returned by the service client.
 */
@ServiceMethod(returns = ReturnType.SINGLE)
public Response<BinaryData> getAudioTranslationWithResponse(
    String deploymentOrModelName,
    BinaryData audioTranslationOptions,
    RequestOptions requestOptions,
    String boundary,
    long contentLength) {
    final String contentLengthText = Long.toString(contentLength);
    return serviceClient.getAudioTranslationWithResponse(
        deploymentOrModelName, audioTranslationOptions, requestOptions, boundary, contentLengthText);
}

/**
* Transcribes audio into the input language.
*
Expand Down Expand Up @@ -877,4 +887,30 @@ public AudioTranscription getAudioTranslation(
.getValue()
.toObject(AudioTranscription.class);
}

/**
 * Translates the audio carried by {@code audioTranslationOptions} by building a multipart body with
 * {@link MultipartDataHelper} and sending it through
 * {@code getAudioTranslationWithResponse(String, BinaryData, RequestOptions, String, long)}.
 *
 * <p>Besides the file part, only the {@code response_format} field is added, and only when the options
 * carry a non-null response format.
 *
 * @param deploymentOrModelName deployment or model name forwarded to the request.
 * @param audioTranslationOptions options holding the audio file bytes and the optional response format.
 * @param fileName file name written into the file part of the multipart body.
 * @return the response body converted to an {@link AudioTranscription}.
 */
public AudioTranscription getAudioTranslation(
    String deploymentOrModelName, AudioTranslationOptions audioTranslationOptions, String fileName) {
    MultipartDataHelper multipart = new MultipartDataHelper();
    multipart.addFields(fieldList -> {
        // Nothing to add when the caller did not pick a response format.
        if (audioTranslationOptions.getResponseFormat() == null) {
            return;
        }
        String formatText = audioTranslationOptions.getResponseFormat().toString();
        fieldList.add(new MultipartField("response_format", formatText));
    });

    MultipartDataHelper.SerializationResult serialized =
        multipart.serializeAudioTranscriptionOption(audioTranslationOptions, fileName);

    RequestOptions requestOptions = new RequestOptions();
    Response<BinaryData> response = getAudioTranslationWithResponse(
        deploymentOrModelName,
        serialized.getData(),
        requestOptions,
        multipart.getBoundary(),
        serialized.getDataLength());

    return response.getValue().toObject(AudioTranscription.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.azure.ai.openai.implementation;

import com.azure.ai.openai.models.AudioTranslationOptions;
import com.azure.core.util.BinaryData;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.function.Consumer;

/**
 * Builds a {@code multipart/form-data} request body consisting of one binary {@code file} part followed by
 * the plain-text fields registered through {@link #addFields(Consumer)}.
 *
 * <p>Each instance owns one boundary string, available from {@link #getBoundary()}, which the caller must
 * repeat in the request's {@code Content-Type} header.
 */
public class MultipartDataHelper {
    private static final String CRLF = "\r\n";

    // First 16 characters of a random UUID string; keeps the boundary unique per helper instance.
    private final String boundaryId = UUID.randomUUID().toString().substring(0, 16);

    private final String boundary = "AZ-OAI-JAVA--" + boundaryId;

    // Every part starts with "--" + boundary; the body ends with the same marker followed by "--".
    private final String partSeparator = "--" + boundary;
    private final String endMarker = partSeparator + "--";

    private final List<MultipartField> fields = new ArrayList<>();

    /**
     * Returns the boundary used to delimit the parts produced by this helper.
     *
     * @return the boundary string, without the leading {@code --}.
     */
    public String getBoundary() {
        return boundary;
    }

    /**
     * Serializes the audio file and all registered fields into a single multipart body.
     *
     * <p>The file part is written first with content type {@code application/octet-stream}, then each
     * registered field in insertion order, then the closing boundary marker.
     *
     * @param audioTranscriptionOptions options whose file bytes become the {@code file} part.
     * @param fileName file name advertised in the {@code file} part's {@code Content-Disposition} header.
     * @return the serialized body together with its length in bytes.
     * @throws RuntimeException if writing to the in-memory buffer fails.
     */
    public SerializationResult serializeAudioTranscriptionOption(
        AudioTranslationOptions audioTranscriptionOptions, String fileName) {
        ByteArrayOutputStream body = new ByteArrayOutputStream();
        String fileFieldPreamble = partSeparator
            + CRLF + "Content-Disposition: form-data; name=\"file\"; filename=\""
            + escapeQuotedParameter(fileName) + "\""
            + CRLF + "Content-Type: application/octet-stream" + CRLF + CRLF;
        try {
            // UTF-8 instead of US-ASCII: getBytes(US_ASCII) replaces every non-ASCII character with '?'.
            body.write(fileFieldPreamble.getBytes(StandardCharsets.UTF_8));
            body.write(audioTranscriptionOptions.getFile());
            for (MultipartField field : fields) {
                body.write(serializeField(field));
            }
            body.write((CRLF + endMarker).getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // The body contains raw audio bytes, so it is deliberately not printed or logged here.
        byte[] totalData = body.toByteArray();
        return new SerializationResult(BinaryData.fromBytes(totalData), totalData.length);
    }

    /**
     * Hands the internal field list to {@code fieldAdder} so the caller can register additional fields.
     *
     * @param fieldAdder callback that receives the list of fields to be serialized.
     */
    public void addFields(Consumer<List<MultipartField>> fieldAdder) {
        fieldAdder.accept(fields);
    }

    // Serializes one text field as a part; the leading CRLF terminates the content of the previous part.
    private byte[] serializeField(MultipartField field) {
        String serialized = CRLF + partSeparator
            + CRLF + "Content-Disposition: form-data; name=\""
            + escapeQuotedParameter(field.getWireName()) + "\"" + CRLF + CRLF
            + field.getValue();

        return serialized.getBytes(StandardCharsets.UTF_8);
    }

    // Percent-encodes the characters that would terminate a quoted header parameter or start a new header
    // line. A null value keeps its previous rendering as the text "null".
    private static String escapeQuotedParameter(String value) {
        return String.valueOf(value)
            .replace("\"", "%22")
            .replace("\r", "%0D")
            .replace("\n", "%0A");
    }

    /**
     * Serialized multipart body paired with its length in bytes.
     */
    public static class SerializationResult {
        private final long dataLength;
        private final BinaryData data;

        /**
         * Creates a result.
         *
         * @param data the serialized body.
         * @param contentLength length of {@code data} in bytes.
         */
        public SerializationResult(BinaryData data, long contentLength) {
            this.dataLength = contentLength;
            this.data = data;
        }

        /**
         * Returns the serialized body.
         *
         * @return the serialized body.
         */
        public BinaryData getData() {
            return data;
        }

        /**
         * Returns the length of the serialized body in bytes.
         *
         * @return the body length in bytes.
         */
        public long getDataLength() {
            return dataLength;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.azure.ai.openai.implementation;

/**
 * Immutable name/value pair describing one text field of a multipart form body.
 */
public class MultipartField {
    private final String wireName;
    private final String value;

    /**
     * Creates a field.
     *
     * @param wireName the field name as it should appear on the wire.
     * @param value the field's text value.
     */
    public MultipartField(String wireName, String value) {
        this.value = value;
        this.wireName = wireName;
    }

    /**
     * Returns the field's text value.
     *
     * @return the value passed to the constructor.
     */
    public String getValue() {
        return this.value;
    }

    /**
     * Returns the field name as it should appear on the wire.
     *
     * @return the wire name passed to the constructor.
     */
    public String getWireName() {
        return this.wireName;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,20 @@ Response<BinaryData> getAudioTranslationSync(
@BodyParam("multipart/form-data") BinaryData audioTranslationOptions,
RequestOptions requestOptions,
Context context);

/**
 * Synchronous proxy method for the audio translation operation with caller-supplied
 * {@code content-type} and {@code content-length} headers.
 *
 * <p>Issues a POST to {@code /deployments/{deploymentId}/audio/translations}. Status 200 is the only
 * expected response; other responses are mapped to {@link HttpResponseException}.
 *
 * @param endpoint value substituted for the {@code endpoint} host parameter.
 * @param apiVersion value sent as the {@code api-version} query parameter.
 * @param deploymentOrModelName value substituted for the {@code deploymentId} path segment.
 * @param contentType value sent as the {@code content-type} header.
 * @param accept value sent as the {@code accept} header.
 * @param contentLength value sent as the {@code content-length} header.
 * @param audioTranslationOptions the request body, declared as {@code multipart/form-data}.
 * @param requestOptions per-request options.
 * @param context the context passed along with the call.
 * @return the response with its body as {@link BinaryData}.
 */
@Post("/deployments/{deploymentId}/audio/translations")
@ExpectedResponses({200})
@UnexpectedResponseExceptionType(HttpResponseException.class)
Response<BinaryData> getAudioTranslationSync(
@HostParam("endpoint") String endpoint,
@QueryParam("api-version") String apiVersion,
@PathParam("deploymentId") String deploymentOrModelName,
@HeaderParam("content-type") String contentType,
@HeaderParam("accept") String accept,
@HeaderParam("content-length") String contentLength,
@BodyParam("multipart/form-data") BinaryData audioTranslationOptions,
RequestOptions requestOptions,
Context context);
}

/**
Expand Down Expand Up @@ -1928,4 +1942,21 @@ public Response<BinaryData> getAudioTranslationWithResponse(
requestOptions,
Context.NONE);
}

/**
 * Sends an already-serialized multipart audio translation request through the service proxy.
 *
 * <p>The {@code content-type} header is built from {@code boundary}, the {@code accept} header is
 * {@code *}{@code /*}, and the call is made with {@link Context#NONE}.
 *
 * @param deploymentOrModelName deployment or model name forwarded to the proxy.
 * @param audioTranslationOptions the serialized multipart request body.
 * @param requestOptions per-request options forwarded unchanged.
 * @param boundary the multipart boundary that was used when serializing the body.
 * @param contentLength value forwarded as the content length of the request.
 * @return the raw response returned by the service proxy.
 */
@ServiceMethod(returns = ReturnType.SINGLE)
public Response<BinaryData> getAudioTranslationWithResponse(
    String deploymentOrModelName,
    BinaryData audioTranslationOptions,
    RequestOptions requestOptions,
    String boundary,
    String contentLength) {
    return service.getAudioTranslationSync(
        getEndpoint(),
        getServiceVersion().getVersion(),
        deploymentOrModelName,
        "multipart/form-data; boundary=" + boundary, // content type
        "*/*", // accept
        contentLength,
        audioTranslationOptions,
        requestOptions,
        Context.NONE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.KeyCredential;
import com.azure.core.http.HttpClient;
import com.azure.core.http.policy.HttpLogDetailLevel;
import com.azure.core.http.policy.HttpLogOptions;
import com.azure.core.http.rest.Response;
import com.azure.core.test.TestMode;
import com.azure.core.test.TestProxyTestBase;
Expand Down Expand Up @@ -55,7 +57,7 @@ public abstract class OpenAIClientTestBase extends TestProxyTestBase {

OpenAIClientBuilder getOpenAIClientBuilder(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
OpenAIClientBuilder builder = new OpenAIClientBuilder()
// .httpLogOptions(new HttpLogOptions().setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS))
.httpLogOptions(new HttpLogOptions().setLogLevel(HttpLogDetailLevel.BODY_AND_HEADERS))
.httpClient(httpClient)
.serviceVersion(serviceVersion);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.azure.ai.openai.models.AudioTranscription;
import com.azure.ai.openai.models.AudioTranscriptionFormat;
import com.azure.ai.openai.models.AudioTranscriptionOptions;
import com.azure.ai.openai.models.AudioTranslationOptions;
import com.azure.ai.openai.models.AzureChatExtensionConfiguration;
import com.azure.ai.openai.models.AzureChatExtensionType;
import com.azure.ai.openai.models.AzureCognitiveSearchChatExtensionConfiguration;
Expand Down Expand Up @@ -420,4 +421,19 @@ public void testGetAudioTranscription(HttpClient httpClient, OpenAIServiceVersio
assertNotNull(transcription);
});
}

/**
 * Translates a bundled WAV file through {@code client.getAudioTranslation} and checks that a result
 * object comes back.
 */
@ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS)
@MethodSource("com.azure.ai.openai.TestUtils#getTestParameters")
public void testGetAudioTranslation(HttpClient httpClient, OpenAIServiceVersion serviceVersion) {
    client = getOpenAIClient(httpClient, serviceVersion);

    // The file name supplied by the runner is not used; this test always sends the fixed WAV file below.
    getAudioTranscriptionRunner((deploymentName, runnerFileName) -> {
        final Path audioPath = Path.of("src/test/resources/JP_it_is_rainy_today.wav");
        final byte[] audioBytes = BinaryData.fromFile(audioPath).toBytes();
        final AudioTranslationOptions options = new AudioTranslationOptions(audioBytes);

        final AudioTranscription result =
            client.getAudioTranslation(deploymentName, options, "JP_it_is_rainy_today.wav");

        assertNotNull(result);
    });
}
}