Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: elevenlabs sound-generation api #3355

Merged
merged 8 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ service Backend {
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
rpc Status(HealthMessage) returns (StatusResponse) {}

Expand Down Expand Up @@ -270,6 +271,17 @@ message TTSRequest {
optional string language = 5;
}

// SoundGenerationRequest carries the parameters of the Backend.SoundGeneration RPC.
// Fields marked optional let the backend tell "unset" apart from a zero value.
message SoundGenerationRequest {
// Text prompt describing the audio to generate.
string text = 1;
// Model identifier the backend should use for generation.
string model = 2;
// Destination path the backend writes the generated audio file to.
string dst = 3;
// Requested length of the generated audio.
// NOTE(review): the transformers-musicgen backend interprets this as seconds — confirm for other backends.
optional float duration = 4;
// NOTE(review): the transformers-musicgen backend forwards this value as guidance_scale,
// not as a sampling temperature — confirm the intended semantics.
optional float temperature = 5;
// Whether the backend should sample during generation (forwarded as do_sample by transformers-musicgen).
optional bool sample = 6;
// Presumably a path to a source audio file used to condition generation — confirm against the backends.
optional string src = 7;
// When set, transformers-musicgen keeps only the first len(samples) / src_divisor samples of the source audio.
optional int32 src_divisor = 8;
}

message TokenizationResponse {
int32 length = 1;
repeated int32 tokens = 2;
Expand Down
60 changes: 57 additions & 3 deletions backend/python/transformers-musicgen/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import grpc

from scipy.io.wavfile import write as write_wav
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
Expand Down Expand Up @@ -63,6 +63,61 @@ def LoadModel(self, request, context):

return backend_pb2.Result(message="Model loaded successfully", success=True)

def SoundGeneration(self, request, context):
    """
    Generate audio with MusicGen and write it to ``request.dst`` as a WAV file.

    The prompt is chosen as follows:
      * empty ``request.text``: unconditional generation;
      * ``request.src`` set: text plus the source audio file as conditioning;
      * otherwise: text only.

    Args:
        request: SoundGenerationRequest. ``model`` is required. The optional
            ``duration``, ``temperature``, ``sample``, ``src`` and
            ``src_divisor`` fields are honoured only when explicitly set.
        context: RPC context; not used by this method.

    Returns:
        backend_pb2.Result with ``success=True`` once the file is written, or
        ``success=False`` and a message describing what went wrong.
    """
    model_name = request.model
    if model_name == "":
        return backend_pb2.Result(success=False, message="request.model is required")
    try:
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        if request.text == "":
            inputs = self.model.get_unconditional_inputs(num_samples=1)
        elif request.HasField('src'):
            # TODO(security): request.src is a caller-supplied filesystem path and is
            # read without validation; restrict it to an allowed directory.
            sample_rate, wsamples = wavfile.read(request.src)

            if request.HasField('src_divisor'):
                # Reject zero/negative divisors explicitly instead of failing with a
                # ZeroDivisionError or silently slicing from the wrong end.
                if request.src_divisor <= 0:
                    return backend_pb2.Result(
                        success=False,
                        message="request.src_divisor must be a positive integer",
                    )
                # Keep only the leading 1/src_divisor of the source samples.
                wsamples = wsamples[: len(wsamples) // request.src_divisor]

            inputs = self.processor(
                audio=wsamples,
                sampling_rate=sample_rate,
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )

        tokens = 256
        if request.HasField('duration'):
            tokens = int(request.duration * 51.2)  # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
        # NOTE(review): the request's "temperature" is forwarded as guidance_scale.
        guidance = 3.0
        if request.HasField('temperature'):
            guidance = request.temperature
        dosample = True
        if request.HasField('sample'):
            dosample = request.sample
        audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
        print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
        sampling_rate = self.model.config.audio_encoder.sampling_rate
        wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
        print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
        print(request, file=sys.stderr)
    except Exception as err:
        # Service boundary: report any failure to the caller as a failed Result.
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    return backend_pb2.Result(success=True)


# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context):
dave-gray101 marked this conversation as resolved.
Show resolved Hide resolved
model_name = request.model
if model_name == "":
Expand All @@ -75,8 +130,7 @@ def TTS(self, request, context):
padding=True,
return_tensors="pt",
)
tokens = 256
# TODO get tokens from request?
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
Expand Down
21 changes: 20 additions & 1 deletion backend/python/transformers-musicgen/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_load_model(self):

def test_tts(self):
"""
This method tests if the embeddings are generated successfully
This method tests if TTS is generated successfully
"""
try:
self.setUp()
Expand All @@ -77,5 +77,24 @@ def test_tts(self):
except Exception as err:
print(err)
self.fail("TTS service failed")
finally:
self.tearDown()

def test_sound_generation(self):
    """
    Load the small MusicGen model over gRPC and issue one SoundGeneration call.

    Only the presence of a response is checked, not its ``success`` flag.
    """
    try:
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as conn:
            client = backend_pb2_grpc.BackendStub(conn)
            load_result = client.LoadModel(
                backend_pb2.ModelOptions(Model="facebook/musicgen-small")
            )
            self.assertTrue(load_result.success)
            generated = client.SoundGeneration(
                backend_pb2.SoundGenerationRequest(
                    text="80s TV news production music hit for tonight's biggest story"
                )
            )
            self.assertIsNotNone(generated)
    except Exception as err:
        print(err)
        self.fail("SoundGeneration service failed")
    finally:
        self.tearDown()
2 changes: 1 addition & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
case string:
protoMessages[i].Content = ct
default:
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
}
}
}
Expand Down
74 changes: 74 additions & 0 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package backend

import (
"context"
"fmt"
"os"
"path/filepath"

"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
)

// SoundGeneration loads the given backend/model and asks it to generate an
// audio file from text (and, optionally, a source audio file).
//
// The output is written to a uniquely named "sound_generation*.wav" file inside
// appConfig.AudioDir, which is created if missing. The optional pointer
// arguments (duration, temperature, doSample, sourceFile, sourceDivisor) are
// forwarded unchanged; nil means "not set".
//
// It returns the path of the generated file and the backend's result. On any
// failure (missing backend name, load error, RPC error, or a result with
// Success == false) it returns an empty path, a nil result and a non-nil error.
func SoundGeneration(
	backend string,
	modelFile string,
	text string,
	duration *float32,
	temperature *float32,
	doSample *bool,
	sourceFile *string,
	sourceDivisor *int32,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
	if backend == "" {
		return "", nil, fmt.Errorf("backend is a required parameter")
	}

	grpcOpts := gRPCModelOpts(backendConfig)
	// NOTE(review): modelOpts receives an empty BackendConfig while grpcOpts is
	// derived from backendConfig — confirm this is intended.
	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
		model.WithBackendString(backend),
		model.WithModel(modelFile),
		model.WithContext(appConfig.Context),
		model.WithAssetDir(appConfig.AssetsDestination),
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
	})

	soundGenModel, err := loader.BackendLoader(opts...)
	if err != nil {
		return "", nil, err
	}

	if soundGenModel == nil {
		return "", nil, fmt.Errorf("could not load sound generation model")
	}

	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
		return "", nil, fmt.Errorf("failed creating audio directory: %w", err)
	}

	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
	filePath := filepath.Join(appConfig.AudioDir, fileName)

	res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
		Text:        text,
		Model:       modelFile,
		Dst:         filePath,
		Sample:      doSample,
		Duration:    duration,
		Temperature: temperature,
		Src:         sourceFile,
		SrcDivisor:  sourceDivisor,
	})
	// Check the RPC error before touching res: res may be nil when err is set,
	// and dereferencing it would panic.
	if err != nil {
		return "", nil, err
	}
	if res == nil {
		return "", nil, fmt.Errorf("sound generation backend returned no result")
	}

	// Surface a backend-reported failure. Use an explicit %s verb so that any
	// '%' in the backend message is not interpreted as a format directive.
	if !res.Success {
		return "", nil, fmt.Errorf("%s", res.Message)
	}

	return filePath, res, nil
}
30 changes: 7 additions & 23 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,15 @@ import (
"github.com/mudler/LocalAI/core/config"

"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
)

func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext

for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}

counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}

func ModelTTS(
backend,
text,
modelFile,
voice ,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
Expand Down Expand Up @@ -66,7 +50,7 @@ func ModelTTS(
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)

// If the model file is not empty, we pass it joined with the model path
Expand All @@ -88,10 +72,10 @@ func ModelTTS(
}

res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Language: &language,
})

Expand Down
17 changes: 9 additions & 8 deletions core/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
// CLI is the root command definition: it embeds the shared cliContext.Context
// and declares one field per subcommand, with each subcommand's help text
// carried in its struct tag. Run is tagged as the default command.
var CLI struct {
cliContext.Context `embed:""`

Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
}
Loading