mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
go : adding features to the go-whisper example, go ci, etc (#384)
* Updated bindings so they can be used in third pary packages. * Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin * Added test script * Changes for examples * Reverted * Made the NewContext method private
This commit is contained in:
@@ -11,10 +11,11 @@ import (
|
||||
// ERRORS
|
||||
|
||||
var (
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrUnableToLoadModel = errors.New("unable to load model")
|
||||
ErrInternalAppError = errors.New("internal application error")
|
||||
ErrProcessingFailed = errors.New("processing failed")
|
||||
ErrUnsupportedLanguage = errors.New("unsupported language")
|
||||
ErrModelNotMultilingual = errors.New("model is not multilingual")
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -24,7 +24,7 @@ var _ Context = (*context)(nil)
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func NewContext(model *model, params whisper.Params) (Context, error) {
|
||||
func newContext(model *model, params whisper.Params) (Context, error) {
|
||||
context := new(context)
|
||||
context.model = model
|
||||
context.params = params
|
||||
@@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error {
|
||||
if context.model.ctx == nil {
|
||||
return ErrInternalAppError
|
||||
}
|
||||
if !context.model.IsMultilingual() {
|
||||
return ErrModelNotMultilingual
|
||||
}
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id < 0 {
|
||||
return ErrUnsupportedLanguage
|
||||
} else if err := context.params.SetLanguage(id); err != nil {
|
||||
@@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (context *context) IsMultilingual() bool {
|
||||
return context.model.IsMultilingual()
|
||||
}
|
||||
|
||||
// Get language
|
||||
func (context *context) Language() string {
|
||||
return whisper.Whisper_lang_str(context.params.Language())
|
||||
}
|
||||
|
||||
// Set translate flag
|
||||
func (context *context) SetTranslate(v bool) {
|
||||
context.params.SetTranslate(v)
|
||||
}
|
||||
|
||||
// Set speedup flag
|
||||
func (context *context) SetSpeedup(v bool) {
|
||||
context.params.SetSpeedup(v)
|
||||
}
|
||||
|
||||
// Set number of threads to use
|
||||
func (context *context) SetThreads(v uint) {
|
||||
context.params.SetThreads(int(v))
|
||||
}
|
||||
|
||||
// Set time offset
|
||||
func (context *context) SetOffset(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set duration of audio to process
|
||||
func (context *context) SetDuration(v time.Duration) {
|
||||
context.params.SetOffset(int(v.Milliseconds()))
|
||||
}
|
||||
|
||||
// Set timestamp token probability threshold (~0.01)
|
||||
func (context *context) SetTokenThreshold(t float32) {
|
||||
context.params.SetTokenThreshold(t)
|
||||
}
|
||||
|
||||
// Set timestamp token sum probability threshold (~0.01)
|
||||
func (context *context) SetTokenSumThreshold(t float32) {
|
||||
context.params.SetTokenSumThreshold(t)
|
||||
}
|
||||
|
||||
// Set max segment length in characters
|
||||
func (context *context) SetMaxSegmentLength(n uint) {
|
||||
context.params.SetMaxSegmentLength(int(n))
|
||||
}
|
||||
|
||||
// Set max tokens per segment (0 = no limit)
|
||||
func (context *context) SetMaxTokensPerSegment(n uint) {
|
||||
context.params.SetMaxTokensPerSegment(int(n))
|
||||
}
|
||||
|
||||
// Process new sample data and return any errors
|
||||
func (context *context) Process(data []float32, cb SegmentCallback) error {
|
||||
if context.model.ctx == nil {
|
||||
@@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Test for text tokens
|
||||
func (context *context) IsText(t Token) bool {
|
||||
switch {
|
||||
case context.IsBEG(t):
|
||||
return false
|
||||
case context.IsSOT(t):
|
||||
return false
|
||||
case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot():
|
||||
return false
|
||||
case context.IsPREV(t):
|
||||
return false
|
||||
case context.IsSOLM(t):
|
||||
return false
|
||||
case context.IsNOT(t):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Test for "begin" token
|
||||
func (context *context) IsBEG(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg()
|
||||
}
|
||||
|
||||
// Test for "start of transcription" token
|
||||
func (context *context) IsSOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot()
|
||||
}
|
||||
|
||||
// Test for "end of transcription" token
|
||||
func (context *context) IsEOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot()
|
||||
}
|
||||
|
||||
// Test for "start of prev" token
|
||||
func (context *context) IsPREV(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev()
|
||||
}
|
||||
|
||||
// Test for "start of lm" token
|
||||
func (context *context) IsSOLM(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm()
|
||||
}
|
||||
|
||||
// Test for "No timestamps" token
|
||||
func (context *context) IsNOT(t Token) bool {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not()
|
||||
}
|
||||
|
||||
// Test for token associated with a specific language
|
||||
func (context *context) IsLANG(t Token, lang string) bool {
|
||||
if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 {
|
||||
return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ type Model interface {
|
||||
// Return a new speech-to-text context.
|
||||
NewContext() (Context, error)
|
||||
|
||||
// Return true if the model is multilingual.
|
||||
IsMultilingual() bool
|
||||
|
||||
// Return all languages supported.
|
||||
Languages() []string
|
||||
}
|
||||
@@ -27,8 +30,18 @@ type Model interface {
|
||||
// Context is the speach recognition context.
|
||||
type Context interface {
|
||||
SetLanguage(string) error // Set the language to use for speech recognition.
|
||||
SetTranslate(bool) // Set translate flag
|
||||
IsMultilingual() bool // Return true if the model is multilingual.
|
||||
Language() string // Get language
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
|
||||
SetOffset(time.Duration) // Set offset
|
||||
SetDuration(time.Duration) // Set duration
|
||||
SetThreads(uint) // Set number of threads to use
|
||||
SetSpeedup(bool) // Set speedup flag
|
||||
SetTokenThreshold(float32) // Set timestamp token probability threshold
|
||||
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
|
||||
SetMaxSegmentLength(uint) // Set max segment length in characters
|
||||
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
|
||||
|
||||
// Process mono audio data and return any errors.
|
||||
// If defined, newly generated segments are passed to the
|
||||
@@ -38,6 +51,15 @@ type Context interface {
|
||||
// After process is called, return segments until the end of the stream
|
||||
// is reached, when io.EOF is returned.
|
||||
NextSegment() (Segment, error)
|
||||
|
||||
IsBEG(Token) bool // Test for "begin" token
|
||||
IsSOT(Token) bool // Test for "start of transcription" token
|
||||
IsEOT(Token) bool // Test for "end of transcription" token
|
||||
IsPREV(Token) bool // Test for "start of prev" token
|
||||
IsSOLM(Token) bool // Test for "start of lm" token
|
||||
IsNOT(Token) bool // Test for "No timestamps" token
|
||||
IsLANG(Token, string) bool // Test for token associated with a specific language
|
||||
IsText(Token) bool // Test for text token
|
||||
}
|
||||
|
||||
// Segment is the text result of a speech recognition.
|
||||
|
||||
@@ -23,7 +23,7 @@ var _ Model = (*model)(nil)
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// LIFECYCLE
|
||||
|
||||
func New(path string) (*model, error) {
|
||||
func New(path string) (Model, error) {
|
||||
model := new(model)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, err
|
||||
@@ -64,6 +64,11 @@ func (model *model) String() string {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Return true if model is multilingual (language and translation options are supported)
|
||||
func (model *model) IsMultilingual() bool {
|
||||
return model.ctx.Whisper_is_multilingual() != 0
|
||||
}
|
||||
|
||||
// Return all recognized languages. Initially it is set to auto-detect
|
||||
func (model *model) Languages() []string {
|
||||
result := make([]string, 0, whisper.Whisper_lang_max_id())
|
||||
@@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) {
|
||||
params.SetThreads(runtime.NumCPU())
|
||||
|
||||
// Return new context
|
||||
return NewContext(model, params)
|
||||
return newContext(model, params)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user