go : adding features to the go-whisper example, go ci, etc (#384)

* Updated bindings so they can be used in third pary packages.

* Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private
This commit is contained in:
David Thorpe
2023-01-07 19:21:43 +00:00
committed by GitHub
parent f30b5d322c
commit f078a6f20e
10 changed files with 370 additions and 31 deletions

View File

@@ -11,10 +11,11 @@ import (
// ERRORS
var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
)
///////////////////////////////////////////////////////////////////////////////

View File

@@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func NewContext(model *model, params whisper.Params) (Context, error) {
func newContext(model *model, params whisper.Params) (Context, error) {
context := new(context)
context.model = model
context.params = params
@@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
if !context.model.IsMultilingual() {
return ErrModelNotMultilingual
}
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
return ErrUnsupportedLanguage
} else if err := context.params.SetLanguage(id); err != nil {
@@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error {
return nil
}
func (context *context) IsMultilingual() bool {
return context.model.IsMultilingual()
}
// Get language
func (context *context) Language() string {
return whisper.Whisper_lang_str(context.params.Language())
}
// Set translate flag
func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v)
}
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
}
// Set time offset
func (context *context) SetOffset(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
func (context *context) SetTokenThreshold(t float32) {
context.params.SetTokenThreshold(t)
}
// Set timestamp token sum probability threshold (~0.01)
func (context *context) SetTokenSumThreshold(t float32) {
context.params.SetTokenSumThreshold(t)
}
// Set max segment length in characters
func (context *context) SetMaxSegmentLength(n uint) {
context.params.SetMaxSegmentLength(int(n))
}
// Set max tokens per segment (0 = no limit)
func (context *context) SetMaxTokensPerSegment(n uint) {
context.params.SetMaxTokensPerSegment(int(n))
}
// Process new sample data and return any errors
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
@@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) {
return result, nil
}
// Test for text tokens
func (context *context) IsText(t Token) bool {
switch {
case context.IsBEG(t):
return false
case context.IsSOT(t):
return false
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
return false
case context.IsPREV(t):
return false
case context.IsSOLM(t):
return false
case context.IsNOT(t):
return false
default:
return true
}
}
// Test for "begin" token
func (context *context) IsBEG(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
}
// Test for "start of transcription" token
func (context *context) IsSOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
}
// Test for "end of transcription" token
func (context *context) IsEOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
}
// Test for "start of prev" token
func (context *context) IsPREV(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
}
// Test for "start of lm" token
func (context *context) IsSOLM(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
}
// Test for "No timestamps" token
func (context *context) IsNOT(t Token) bool {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
}
// Test for token associated with a specific language
func (context *context) IsLANG(t Token, lang string) bool {
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
} else {
return false
}
}
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

View File

@@ -20,6 +20,9 @@ type Model interface {
// Return a new speech-to-text context.
NewContext() (Context, error)
// Return true if the model is multilingual.
IsMultilingual() bool
// Return all languages supported.
Languages() []string
}
@@ -27,8 +30,18 @@ type Model interface {
// Context is the speach recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
SetSpeedup(bool) // Set speedup flag
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
@@ -38,6 +51,15 @@ type Context interface {
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
NextSegment() (Segment, error)
IsBEG(Token) bool // Test for "begin" token
IsSOT(Token) bool // Test for "start of transcription" token
IsEOT(Token) bool // Test for "end of transcription" token
IsPREV(Token) bool // Test for "start of prev" token
IsSOLM(Token) bool // Test for "start of lm" token
IsNOT(Token) bool // Test for "No timestamps" token
IsLANG(Token, string) bool // Test for token associated with a specific language
IsText(Token) bool // Test for text token
}
// Segment is the text result of a speech recognition.

View File

@@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
///////////////////////////////////////////////////////////////////////////////
// LIFECYCLE
func New(path string) (*model, error) {
func New(path string) (Model, error) {
model := new(model)
if _, err := os.Stat(path); err != nil {
return nil, err
@@ -64,6 +64,11 @@ func (model *model) String() string {
///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS
// Return true if model is multilingual (language and translation options are supported)
func (model *model) IsMultilingual() bool {
return model.ctx.Whisper_is_multilingual() != 0
}
// Return all recognized languages. Initially it is set to auto-detect
func (model *model) Languages() []string {
result := make([]string, 0, whisper.Whisper_lang_max_id())
@@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) {
params.SetThreads(runtime.NumCPU())
// Return new context
return NewContext(model, params)
return newContext(model, params)
}