Skip to content

Commit

Permalink
feat: elevenlabs sound-generation api (#3355)
Browse files Browse the repository at this point in the history
* initial version of elevenlabs compatible soundgeneration api and cli command

Signed-off-by: Dave Lee <[email protected]>

* minor cleanup

Signed-off-by: Dave Lee <[email protected]>

* restore TTS, add test

Signed-off-by: Dave Lee <[email protected]>

* remove stray s

Signed-off-by: Dave Lee <[email protected]>

* fix

Signed-off-by: Dave Lee <[email protected]>

---------

Signed-off-by: Dave Lee <[email protected]>
Signed-off-by: Ettore Di Giacinto <[email protected]>
Co-authored-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
dave-gray101 and mudler authored Aug 24, 2024
1 parent 84d6e5a commit 81ae92f
Show file tree
Hide file tree
Showing 20 changed files with 450 additions and 37 deletions.
12 changes: 12 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ service Backend {
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
rpc Status(HealthMessage) returns (StatusResponse) {}

Expand Down Expand Up @@ -270,6 +271,17 @@ message TTSRequest {
optional string language = 5;
}

// SoundGenerationRequest is the input of the Backend.SoundGeneration RPC.
// The notes on each field describe how the transformers-musicgen backend
// consumes it; other backends may interpret the fields differently.
message SoundGenerationRequest {
// Text prompt. The musicgen backend falls back to unconditional generation
// when this is empty.
string text = 1;
// Model name; the musicgen backend rejects the request when it is empty.
string model = 2;
// Destination path where the backend writes the generated WAV file.
string dst = 3;
// Requested length in seconds. The musicgen backend converts it to a token
// budget at 51.2 tokens per second and uses 256 tokens when unset.
optional float duration = 4;
// NOTE(review): the musicgen backend forwards this value as `guidance_scale`
// (default 3.0), not as a sampling temperature - confirm the intended meaning.
optional float temperature = 5;
// Forwarded as `do_sample` by the musicgen backend (default true).
optional bool sample = 6;
// Optional source audio file used to condition the generation.
optional string src = 7;
// When set, only the first len(samples) / src_divisor samples of `src` are
// used as the audio prompt.
optional int32 src_divisor = 8;
}

message TokenizationResponse {
int32 length = 1;
repeated int32 tokens = 2;
Expand Down
60 changes: 57 additions & 3 deletions backend/python/transformers-musicgen/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import grpc

from scipy.io.wavfile import write as write_wav
from scipy.io import wavfile
from transformers import AutoProcessor, MusicgenForConditionalGeneration

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
Expand Down Expand Up @@ -63,6 +63,61 @@ def LoadModel(self, request, context):

return backend_pb2.Result(message="Model loaded successfully", success=True)

def SoundGeneration(self, request, context):
    """
    Generate audio with MusicGen and write it to ``request.dst`` as a WAV file.

    The processor and model named by ``request.model`` are (re)loaded on every
    call. Generation is conditioned as follows:

    * empty ``request.text``: unconditional generation;
    * ``request.src`` set: the WAV file at that path is used as an audio
      prompt together with the text, optionally truncated to the first
      ``len(samples) // request.src_divisor`` samples;
    * otherwise: text-only conditioning.

    Optional request fields and their defaults:

    * ``duration`` (seconds) -> ``max_new_tokens`` at 51.2 tokens per second,
      256 tokens (about 5 seconds) when unset;
    * ``temperature`` -> forwarded as ``guidance_scale``, 3.0 when unset;
    * ``sample`` -> forwarded as ``do_sample``, True when unset.

    Args:
        request: The SoundGenerationRequest message.
        context: The gRPC servicer context (unused).

    Returns:
        backend_pb2.Result: ``success=True`` once the file has been written;
        ``success=False`` with a message when ``request.model`` is empty or
        any step raises.
    """
    model_name = request.model
    if model_name == "":
        return backend_pb2.Result(success=False, message="request.model is required")
    try:
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        inputs = None
        if request.text == "":
            inputs = self.model.get_unconditional_inputs(num_samples=1)
        elif request.HasField('src'):
            # NOTE(security): request.src is opened as-is. The caller is
            # responsible for making sure it points at a file the backend is
            # allowed to read.
            sample_rate, wsamples = wavfile.read(request.src)

            if request.HasField('src_divisor'):
                # A zero divisor would raise ZeroDivisionError and a negative
                # one would silently slice from the wrong end, so reject both.
                if request.src_divisor <= 0:
                    raise ValueError("request.src_divisor must be a positive integer")
                wsamples = wsamples[: len(wsamples) // request.src_divisor]

            inputs = self.processor(
                audio=wsamples,
                sampling_rate=sample_rate,
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=[request.text],
                padding=True,
                return_tensors="pt",
            )

        tokens = 256
        if request.HasField('duration'):
            tokens = int(request.duration * 51.2)  # 256 tokens = 5 seconds, therefore 51.2 tokens is one second
        guidance = 3.0
        if request.HasField('temperature'):
            guidance = request.temperature
        dosample = True
        if request.HasField('sample'):
            dosample = request.sample
        audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=tokens)
        print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr)
        sampling_rate = self.model.config.audio_encoder.sampling_rate
        wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy())
        print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration for", file=sys.stderr)
        print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr)
        print(request, file=sys.stderr)
    except Exception as err:
        # Report every failure through the Result message rather than
        # propagating the exception out of the RPC handler.
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
    return backend_pb2.Result(success=True)


# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context):
model_name = request.model
if model_name == "":
Expand All @@ -75,8 +130,7 @@ def TTS(self, request, context):
padding=True,
return_tensors="pt",
)
tokens = 256
# TODO get tokens from request?
tokens = 512 # No good place to set the "length" in TTS, so use 10s as a sane default
audio_values = self.model.generate(**inputs, max_new_tokens=tokens)
print("[transformers-musicgen] TTS generated!", file=sys.stderr)
sampling_rate = self.model.config.audio_encoder.sampling_rate
Expand Down
21 changes: 20 additions & 1 deletion backend/python/transformers-musicgen/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_load_model(self):

def test_tts(self):
"""
This method tests if the embeddings are generated successfully
This method tests if TTS is generated successfully
"""
try:
self.setUp()
Expand All @@ -77,5 +77,24 @@ def test_tts(self):
except Exception as err:
print(err)
self.fail("TTS service failed")
finally:
self.tearDown()

def test_sound_generation(self):
    """
    Check that a SoundGeneration call issued after LoadModel returns a response
    """
    try:
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as grpc_channel:
            client = backend_pb2_grpc.BackendStub(grpc_channel)

            # The model has to load before sound generation is attempted.
            load_result = client.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small"))
            self.assertTrue(load_result.success)

            prompt = "80s TV news production music hit for tonight's biggest story"
            generation_result = client.SoundGeneration(backend_pb2.SoundGenerationRequest(text=prompt))
            # Only the presence of a response is checked, not its success flag.
            self.assertIsNotNone(generation_result)
    except Exception as err:
        print(err)
        self.fail("SoundGeneration service failed")
    finally:
        self.tearDown()
2 changes: 1 addition & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
case string:
protoMessages[i].Content = ct
default:
return nil, fmt.Errorf("Unsupported type for schema.Message.Content for inference: %T", ct)
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
}
}
}
Expand Down
74 changes: 74 additions & 0 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package backend

import (
"context"
"fmt"
"os"
"path/filepath"

"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
)

// SoundGeneration loads the requested backend and asks it to generate an audio
// file from a text prompt, optionally conditioned on a source audio file.
//
// The output is written inside appConfig.AudioDir (created with mode 0750 when
// missing) under a name obtained from utils.GenerateUniqueFileName with the
// "sound_generation" base name and ".wav" extension. The optional pointer
// parameters (duration, temperature, doSample, sourceFile, sourceDivisor) are
// forwarded to the backend unchanged; nil leaves the field unset.
//
// It returns the path of the generated file and the backend result. An error
// is returned when backend is empty, the backend cannot be loaded, the audio
// directory cannot be created, the RPC itself fails, or the backend reports
// Success == false (in which case the error carries the backend message).
func SoundGeneration(
	backend string,
	modelFile string,
	text string,
	duration *float32,
	temperature *float32,
	doSample *bool,
	sourceFile *string,
	sourceDivisor *int32,
	loader *model.ModelLoader,
	appConfig *config.ApplicationConfig,
	backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
	if backend == "" {
		return "", nil, fmt.Errorf("backend is a required parameter")
	}

	grpcOpts := gRPCModelOpts(backendConfig)
	opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
		model.WithBackendString(backend),
		model.WithModel(modelFile),
		model.WithContext(appConfig.Context),
		model.WithAssetDir(appConfig.AssetsDestination),
		model.WithLoadGRPCLoadModelOpts(grpcOpts),
	})

	soundGenModel, err := loader.BackendLoader(opts...)
	if err != nil {
		return "", nil, err
	}

	if soundGenModel == nil {
		return "", nil, fmt.Errorf("could not load sound generation model")
	}

	if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
	}

	fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "sound_generation", ".wav")
	filePath := filepath.Join(appConfig.AudioDir, fileName)

	res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
		Text:        text,
		Model:       modelFile,
		Dst:         filePath,
		Sample:      doSample,
		Duration:    duration,
		Temperature: temperature,
		Src:         sourceFile,
		SrcDivisor:  sourceDivisor,
	})
	// The transport error must be checked before res is touched: res can be
	// nil when the RPC fails, and dereferencing it would panic.
	if err != nil {
		return "", nil, err
	}
	if res == nil {
		return "", nil, fmt.Errorf("sound generation backend returned an empty result")
	}

	// return RPC error if any; the message is passed as an argument so that
	// '%' characters in backend output are not interpreted as format verbs
	if !res.Success {
		return "", nil, fmt.Errorf("%s", res.Message)
	}

	return filePath, res, nil
}
30 changes: 7 additions & 23 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,15 @@ import (
"github.com/mudler/LocalAI/core/config"

"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
)

func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext

for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}

counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}

func ModelTTS(
backend,
text,
modelFile,
voice ,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
Expand Down Expand Up @@ -66,7 +50,7 @@ func ModelTTS(
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}

fileName := generateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)

// If the model file is not empty, we pass it joined with the model path
Expand All @@ -88,10 +72,10 @@ func ModelTTS(
}

res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Text: text,
Model: modelPath,
Voice: voice,
Dst: filePath,
Language: &language,
})

Expand Down
17 changes: 9 additions & 8 deletions core/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import (
var CLI struct {
cliContext.Context `embed:""`

Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. Run 'local-ai run --help' for more information" default:"withargs"`
Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"`
Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"`
TTS TTSCMD `cmd:"" help:"Convert text to speech"`
SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"`
Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"`
Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"`
Util UtilCMD `cmd:"" help:"Utility commands"`
Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"`
}
Loading

0 comments on commit 81ae92f

Please sign in to comment.