-
Notifications
You must be signed in to change notification settings - Fork 4
/
coqui.go
273 lines (235 loc) · 9.73 KB
/
coqui.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
package asticoqui
/*
#cgo CXXFLAGS: -std=c++11
#cgo LDFLAGS: -lstt
#include "coqui_wrap.h"
#include "stdlib.h"
*/
import "C"
import (
"errors"
"math"
"unsafe"
)
// Model provides an interface to a trained model.
// Instances are created by New; the resources they hold are released with
// Close.
type Model struct {
// w is the handle to the C-side model wrapper. Close passes it to
// Model_Close and then sets it to nil.
w *C.ModelWrapper
}
// New loads the model stored at modelPath (the frozen model graph) and
// returns a handle to it. The returned Model must be released with Close.
// If the underlying library reports a failure, a nil Model and the
// library's error message are returned.
func New(modelPath string) (*Model, error) {
	path := C.CString(modelPath)
	defer C.free(unsafe.Pointer(path))

	var status C.int
	wrapper := C.New(path, &status) // nil when status reports an error
	if err := errorFromCode(status); err != nil {
		return nil, err
	}
	return &Model{w: wrapper}, nil
}
// Close frees the resources associated with the model and destroys the
// underlying C object. The Model must not be used for inference afterwards.
//
// Calling Close more than once is safe: once the wrapper has been released
// (and the handle cleared), later calls return immediately instead of
// handing a nil wrapper to the C library.
func (m *Model) Close() {
	if m.w == nil {
		// Already closed. NOTE(review): nothing visible shows that
		// Model_Close tolerates a nil wrapper, so never pass one.
		return
	}
	C.Model_Close(m.w) // deletes m.w
	m.w = nil
}
// BeamWidth reports the beam width the model currently decodes with.
// Unless SetBeamWidth has been called, this is the default value stored in
// the model file.
func (m *Model) BeamWidth() uint {
	width := C.Model_BeamWidth(m.w)
	return uint(width)
}
// SetBeamWidth changes the beam width used by the model's decoder.
// A wider beam generally yields better results but takes longer to decode.
// A non-nil error carries the C library's message for the failure code.
func (m *Model) SetBeamWidth(width uint) error {
	code := C.Model_SetBeamWidth(m.w, C.uint(width))
	return errorFromCode(code)
}
// SampleRate reports the sample rate that was used to produce the model
// file. Audio passed to the model should use this rate.
func (m *Model) SampleRate() int {
	rate := C.Model_SampleRate(m.w)
	return int(rate)
}
// EnableExternalScorer turns on decoding with the external scorer file
// located at scorerPath. A non-nil error carries the C library's message.
func (m *Model) EnableExternalScorer(scorerPath string) error {
	path := C.CString(scorerPath)
	// The C string is only needed for the duration of the call below.
	defer C.free(unsafe.Pointer(path))
	code := C.Model_EnableExternalScorer(m.w, path)
	return errorFromCode(code)
}
// DisableExternalScorer turns off decoding with the external scorer.
// A non-nil error carries the C library's message.
func (m *Model) DisableExternalScorer() error {
	code := C.Model_DisableExternalScorer(m.w)
	return errorFromCode(code)
}
// SetScorerAlphaBeta sets the external scorer's hyperparameters:
// alpha is the language model weight and beta is the word insertion weight.
// A non-nil error carries the C library's message.
func (m *Model) SetScorerAlphaBeta(alpha, beta float32) error {
	cAlpha, cBeta := C.float(alpha), C.float(beta)
	return errorFromCode(C.Model_SetScorerAlphaBeta(m.w, cAlpha, cBeta))
}
// sliceHeader represents a slice header: it mirrors the runtime layout of a
// Go slice (data address, length, capacity) so that a *sliceHeader obtained
// via unsafe.Pointer(&someSlice) can expose the slice's backing-array
// address as a uintptr.
// NOTE(review): this depends on the runtime's slice layout and duplicates
// reflect.SliceHeader (deprecated since Go 1.20); converting Data back to an
// unsafe.Pointer is not one of the conversions the unsafe package documents
// as valid. Taking &s[0] is the portable way to get the element address.
type sliceHeader struct {
Data uintptr
Len int
Cap int
}
// SpeechToText uses the model to convert speech to text.
//
// buffer is a 16-bit, mono raw audio signal at the appropriate sample rate
// (matching what the model was trained on; see SampleRate). An empty buffer
// is passed to the C library as a nil pointer with length zero.
// An error is returned if the C library produces no result.
func (m *Model) SpeechToText(buffer []int16) (string, error) {
	// Pass the address of the first element directly. Reading the data
	// address out of a hand-rolled slice header yields a uintptr, which the
	// GC and cgo's pointer checks do not treat as a pointer, and turning it
	// back into an unsafe.Pointer is not a documented-valid conversion.
	var samples *C.short
	if len(buffer) > 0 {
		samples = (*C.short)(unsafe.Pointer(&buffer[0]))
	}
	str := C.Model_STT(m.w, samples, C.uint(len(buffer)))
	if str == nil {
		return "", errors.New("conversion failed")
	}
	defer C.FreeString(str)
	return C.GoString(str), nil
}
// TokenMetadata stores text of an individual token, along with its timing information.
//
// The type is the C struct itself rather than a Go copy, so the values in a
// slice returned by CandidateTranscript.Tokens live in memory owned by the C
// library. NOTE(review): presumably that memory is released when the owning
// Metadata is closed; confirm against the C wrapper before retaining tokens.
type TokenMetadata C.struct_TokenMetadata
// Text returns the text of this token, copied into a Go string.
func (tm *TokenMetadata) Text() string {
	raw := (*C.TokenMetadata)(unsafe.Pointer(tm))
	return C.GoString(C.TokenMetadata_Text(raw))
}
// Timestep reports where the token occurs, measured in units of 20ms.
func (tm *TokenMetadata) Timestep() uint {
	raw := (*C.TokenMetadata)(unsafe.Pointer(tm))
	step := C.TokenMetadata_Timestep(raw)
	return uint(step)
}
// StartTime reports where the token occurs, measured in seconds.
func (tm *TokenMetadata) StartTime() float32 {
	raw := (*C.TokenMetadata)(unsafe.Pointer(tm))
	seconds := C.TokenMetadata_StartTime(raw)
	return float32(seconds)
}
// CandidateTranscript is a single transcript computed by the model,
// including a confidence value and the metadata for its constituent tokens.
//
// The type is the C struct itself; the values in a slice returned by
// Metadata.Transcripts live in memory owned by the C library.
// NOTE(review): presumably invalid after the owning Metadata is closed —
// confirm against the C wrapper.
type CandidateTranscript C.struct_CandidateTranscript
// NumTokens returns the number of tokens that make up this transcript.
func (ct *CandidateTranscript) NumTokens() uint {
	raw := (*C.CandidateTranscript)(unsafe.Pointer(ct))
	count := C.CandidateTranscript_NumTokens(raw)
	return uint(count)
}
// Tokens returns the tokens that make up this transcript.
//
// The returned slice is not a copy: it aliases the token array owned by the
// C library. NOTE(review): presumably that array is freed by Metadata.Close,
// so copy anything you need to keep beyond that — confirm against the C side.
// A transcript with no tokens (or a NULL token array) yields a nil slice.
func (ct *CandidateTranscript) Tokens() []TokenMetadata {
	raw := (*C.CandidateTranscript)(unsafe.Pointer(ct))
	numTokens := uint(C.CandidateTranscript_NumTokens(raw))
	allTokens := C.CandidateTranscript_Tokens(raw)
	if numTokens == 0 || allTokens == nil {
		// Slicing a nil *[N]T panics, so never build a view over a NULL
		// array; with no tokens there is nothing to expose anyway.
		return nil
	}
	// View the C array as a large Go array and cap the slice at numTokens
	// so appends cannot run past the C allocation.
	// NOTE(review): this array type may be too large on 32-bit targets;
	// unsafe.Slice (Go 1.17+) avoids that if the module's Go version allows.
	return (*[math.MaxInt32 - 1]TokenMetadata)(unsafe.Pointer(allTokens))[:numTokens:numTokens]
}
// Confidence returns the approximated confidence value for this transcript:
// roughly the sum of the acoustic model logit values for each
// timestep/character that contributed to producing it.
func (ct *CandidateTranscript) Confidence() float64 {
	raw := (*C.CandidateTranscript)(unsafe.Pointer(ct))
	score := C.CandidateTranscript_Confidence(raw)
	return float64(score)
}
// Metadata holds an array of CandidateTranscript objects computed by the model.
//
// The type is the C struct itself. Values are returned by
// Model.SpeechToTextWithMetadata, Stream.IntermediateDecodeWithMetadata and
// Stream.FinishWithMetadata, and must be released with Close.
type Metadata C.struct_Metadata
// NumTranscripts returns the number of candidate transcripts held by this
// Metadata.
func (m *Metadata) NumTranscripts() uint {
	raw := (*C.Metadata)(unsafe.Pointer(m))
	count := C.Metadata_NumTranscripts(raw)
	return uint(count)
}
// Transcripts returns the candidate transcripts held by this Metadata.
//
// The returned slice is not a copy: it aliases the transcript array owned by
// the C library and must not be used after Close. A Metadata with no
// transcripts (or a NULL transcript array) yields a nil slice.
func (m *Metadata) Transcripts() []CandidateTranscript {
	raw := (*C.Metadata)(unsafe.Pointer(m))
	// Use uint like NumTranscripts does; converting to int32 could turn a
	// large count negative and make the slice expression panic.
	numTranscripts := uint(C.Metadata_NumTranscripts(raw))
	allTranscripts := C.Metadata_Transcripts(raw)
	if numTranscripts == 0 || allTranscripts == nil {
		// Slicing a nil *[N]T panics; there is nothing to expose anyway.
		return nil
	}
	// View the C array as a large Go array, capped so appends cannot run
	// past the C allocation.
	return (*[math.MaxInt32 - 1]CandidateTranscript)(unsafe.Pointer(allTranscripts))[:numTranscripts:numTranscripts]
}
// Close hands the Metadata back to the C library to be freed. The Metadata
// must not be used afterwards. NOTE(review): slices obtained from
// Transcripts/Tokens point into C memory and presumably become invalid too —
// confirm against the C wrapper.
func (m *Metadata) Close() {
	raw := (*C.Metadata)(unsafe.Pointer(m))
	C.Metadata_Close(raw)
}
// SpeechToTextWithMetadata uses the model to convert speech to text and
// output results including metadata.
//
// buffer is a 16-bit, mono raw audio signal at the appropriate sample rate
// (matching what the model was trained on). An empty buffer is passed to the
// C library as a nil pointer with length zero.
// numResults is the maximum number of CandidateTranscript structs to return;
// the returned value might hold fewer.
// If an error is not returned, the returned metadata's Close method must be
// called later to free resources.
func (m *Model) SpeechToTextWithMetadata(buffer []int16, numResults uint) (*Metadata, error) {
	// Take the element address directly instead of round-tripping it through
	// a uintptr read from a hand-rolled slice header.
	var samples *C.short
	if len(buffer) > 0 {
		samples = (*C.short)(unsafe.Pointer(&buffer[0]))
	}
	md := (*Metadata)(unsafe.Pointer(C.Model_STTWithMetadata(
		m.w, samples, C.uint(len(buffer)), C.uint(numResults))))
	if md == nil {
		return nil, errors.New("conversion failed")
	}
	return md, nil
}
// Stream represents a streaming inference state, created by Model.NewStream.
type Stream struct {
// sw is the C-side stream wrapper. Finish, FinishWithMetadata and Discard
// pass it to the C library and then set it to nil.
sw *C.StreamWrapper
}
// NewStream starts a new streaming inference on the model.
// When no error is returned, exactly one of the stream's Finish,
// FinishWithMetadata, or Discard methods must later be called to release it.
func (m *Model) NewStream() (*Stream, error) {
	var status C.int
	wrapper := C.Model_NewStream(m.w, &status) // nil when status reports an error
	if err := errorFromCode(status); err != nil {
		return nil, err
	}
	return &Stream{sw: wrapper}, nil
}
// FeedAudioContent feeds audio samples to an ongoing streaming inference.
// buffer is an array of 16-bit, mono raw audio samples at the appropriate
// sample rate (matching what the model was trained on). An empty buffer is
// passed to the C library as a nil pointer with length zero.
func (s *Stream) FeedAudioContent(buffer []int16) {
	// Take the element address directly instead of round-tripping it through
	// a uintptr read from a hand-rolled slice header, which is not a valid
	// unsafe.Pointer conversion.
	var samples *C.short
	if len(buffer) > 0 {
		samples = (*C.short)(unsafe.Pointer(&buffer[0]))
	}
	C.Stream_FeedAudioContent(s.sw, samples, C.uint(len(buffer)))
}
// IntermediateDecode decodes everything fed to the stream so far without
// ending it. This is expensive: the decoder cannot currently stream, so
// every call starts again from the beginning of the audio.
func (s *Stream) IntermediateDecode() (string, error) {
	// The C API isn't documented as returning NULL here; checking anyway is
	// cheap future-proofing.
	cText := C.Stream_IntermediateDecode(s.sw)
	if cText == nil {
		return "", errors.New("decoding failed")
	}
	text := C.GoString(cText)
	C.FreeString(cText)
	return text, nil
}
// IntermediateDecodeWithMetadata decodes everything fed to the stream so far
// without ending it, returning up to numResults candidate transcripts with
// metadata. When no error is returned, the Metadata's Close method must be
// called.
func (s *Stream) IntermediateDecodeWithMetadata(numResults uint) (*Metadata, error) {
	raw := C.Stream_IntermediateDecodeWithMetadata(s.sw, C.uint(numResults))
	if raw == nil {
		return nil, errors.New("decoding failed")
	}
	return (*Metadata)(unsafe.Pointer(raw)), nil
}
// Finish ends the streaming inference, performs the final decode and returns
// the transcription. The stream's C state is released whether or not an
// error is returned.
func (s *Stream) Finish() (string, error) {
	// The C API isn't documented as returning NULL here; checking anyway is
	// cheap future-proofing.
	cText := C.Stream_Finish(s.sw) // deletes s.sw
	s.sw = nil
	if cText == nil {
		return "", errors.New("decoding failed")
	}
	text := C.GoString(cText)
	C.FreeString(cText)
	return text, nil
}
// FinishWithMetadata ends the streaming inference, performs the final decode
// and returns up to numResults candidate transcripts with metadata. The
// stream's C state is released whether or not an error is returned; when no
// error is returned, the Metadata's Close method must be called.
func (s *Stream) FinishWithMetadata(numResults uint) (*Metadata, error) {
	raw := C.Stream_FinishWithMetadata(s.sw, C.uint(numResults)) // deletes s.sw
	s.sw = nil
	if raw == nil {
		return nil, errors.New("decoding failed")
	}
	return (*Metadata)(unsafe.Pointer(raw)), nil
}
// Discard destroys a streaming state without decoding the computed logits.
// Use it when the result of an ongoing streaming inference is no longer
// needed and the costly final decode should be skipped.
//
// Discard is a no-op if the stream's C state has already been released by
// Finish, FinishWithMetadata, or an earlier Discard (all of which clear the
// handle). This makes it safe to defer right after NewStream as a safety net.
func (s *Stream) Discard() {
	if s.sw == nil {
		// Already released. NOTE(review): nothing visible shows that
		// Stream_Discard tolerates a nil wrapper, so never pass one.
		return
	}
	C.Stream_Discard(s.sw) // deletes s.sw
	s.sw = nil
}
// Version reports the version of the underlying C library as a semantic
// version string (SemVer 2.0.0).
func Version() string {
	cVersion := C.Version()
	version := C.GoString(cVersion)
	C.FreeString(cVersion)
	return version
}
// errorFromCode translates a status code returned by the C library into a Go
// error. A zero code means success and yields nil; any other code yields an
// error whose text is the library's message for that code.
func errorFromCode(code C.int) error {
	if code != 0 {
		msg := C.ErrorCodeToErrorMessage(code)
		defer C.FreeString(msg)
		return errors.New(C.GoString(msg))
	}
	return nil
}