mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
go : adding features to the go-whisper example, go ci, etc (#384)
* Updated bindings so they can be used in third pary packages. * Updated makefiles to set FMA flag on optionally, for xeon E5 on Darwin * Added test script * Changes for examples * Reverted * Made the NewContext method private
This commit is contained in:
22
bindings/go/examples/go-whisper/color.go
Normal file
22
bindings/go/examples/go-whisper/color.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CONSTANTS
|
||||
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
|
||||
RGBSuffix = "m"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PUBLIC METHODS
|
||||
|
||||
// Colorize text with RGB values, from 0 to 23
|
||||
func Colorize(text string, v int) string {
|
||||
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
|
||||
// Grayscale colors are in the range 232-255
|
||||
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
|
||||
}
|
||||
@@ -2,6 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// Packages
|
||||
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
|
||||
)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
|
||||
return flags.Lookup("language").Value.String()
|
||||
}
|
||||
|
||||
func (flags *Flags) IsTranslate() bool {
|
||||
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOffset() time.Duration {
|
||||
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetDuration() time.Duration {
|
||||
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetThreads() uint {
|
||||
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetOut() string {
|
||||
return strings.ToLower(flags.Lookup("out").Value.String())
|
||||
}
|
||||
|
||||
func (flags *Flags) IsSpeedup() bool {
|
||||
return flags.Lookup("speedup").Value.String() == "true"
|
||||
}
|
||||
@@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
|
||||
return flags.Lookup("tokens").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) IsColorize() bool {
|
||||
return flags.Lookup("colorize").Value.String() == "true"
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxLen() uint {
|
||||
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetMaxTokens() uint {
|
||||
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
|
||||
}
|
||||
|
||||
func (flags *Flags) GetWordThreshold() float32 {
|
||||
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
|
||||
}
|
||||
|
||||
func (flags *Flags) SetParams(context whisper.Context) error {
|
||||
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
|
||||
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if flags.IsTranslate() && context.IsMultilingual() {
|
||||
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
|
||||
context.SetTranslate(true)
|
||||
}
|
||||
if offset := flags.GetOffset(); offset != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
|
||||
context.SetOffset(offset)
|
||||
}
|
||||
if duration := flags.GetDuration(); duration != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
|
||||
context.SetDuration(duration)
|
||||
}
|
||||
if flags.IsSpeedup() {
|
||||
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if threads := flags.GetThreads(); threads != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
|
||||
context.SetThreads(threads)
|
||||
}
|
||||
if max_len := flags.GetMaxLen(); max_len != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
|
||||
context.SetMaxSegmentLength(max_len)
|
||||
}
|
||||
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
|
||||
context.SetMaxTokensPerSegment(max_tokens)
|
||||
}
|
||||
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
|
||||
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
|
||||
context.SetTokenThreshold(word_threshold)
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// PRIVATE METHODS
|
||||
|
||||
func registerFlags(flag *Flags) {
|
||||
flag.String("model", "", "Path to the model file")
|
||||
flag.String("language", "", "Language")
|
||||
flag.String("language", "", "Spoken language")
|
||||
flag.Bool("translate", false, "Translate from source language to english")
|
||||
flag.Duration("offset", 0, "Time offset")
|
||||
flag.Duration("duration", 0, "Duration of audio to process")
|
||||
flag.Uint("threads", 0, "Number of threads to use")
|
||||
flag.Bool("speedup", false, "Enable speedup")
|
||||
flag.Uint("max-len", 0, "Maximum segment length in characters")
|
||||
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
|
||||
flag.Float64("word-thold", 0, "Maximum segment score")
|
||||
flag.Bool("tokens", false, "Display tokens")
|
||||
flag.Bool("colorize", false, "Colorize tokens")
|
||||
flag.String("out", "", "Output format (srt, none or leave as empty string)")
|
||||
}
|
||||
|
||||
@@ -35,8 +35,7 @@ func main() {
|
||||
|
||||
// Process files
|
||||
for _, filename := range flags.Args() {
|
||||
fmt.Println("Processing", filename)
|
||||
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
|
||||
if err := Process(model, filename, flags); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
wav "github.com/go-audio/wav"
|
||||
)
|
||||
|
||||
func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
|
||||
func Process(model whisper.Model, path string, flags *Flags) error {
|
||||
var data []float32
|
||||
|
||||
// Create processing context
|
||||
@@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
if err := flags.SetParams(context); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open the file
|
||||
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
|
||||
fh, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
// Decode the WAV file
|
||||
// Decode the WAV file - load the full buffer
|
||||
dec := wav.NewDecoder(fh)
|
||||
if buf, err := dec.FullPCMBuffer(); err != nil {
|
||||
return err
|
||||
@@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
|
||||
data = buf.AsFloat32Buffer().Data
|
||||
}
|
||||
|
||||
// Set the parameters
|
||||
// Segment callback when -tokens is specified
|
||||
var cb whisper.SegmentCallback
|
||||
if lang != "" {
|
||||
if err := context.SetLanguage(lang); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if speedup {
|
||||
context.SetSpeedup(true)
|
||||
}
|
||||
if tokens {
|
||||
if flags.IsTokens() {
|
||||
cb = func(segment whisper.Segment) {
|
||||
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
for _, token := range segment.Tokens {
|
||||
fmt.Printf("%q ", token.Text)
|
||||
if flags.IsColorize() && context.IsText(token) {
|
||||
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
|
||||
} else {
|
||||
fmt.Fprint(flags.Output(), token.Text, " ")
|
||||
}
|
||||
}
|
||||
fmt.Println("")
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
fmt.Fprintln(flags.Output(), "")
|
||||
}
|
||||
}
|
||||
|
||||
// Process the data
|
||||
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
|
||||
if err := context.Process(data, cb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Print out the results
|
||||
switch {
|
||||
case flags.GetOut() == "srt":
|
||||
return OutputSRT(os.Stdout, context)
|
||||
case flags.GetOut() == "none":
|
||||
return nil
|
||||
default:
|
||||
return Output(os.Stdout, context, flags.IsColorize())
|
||||
}
|
||||
}
|
||||
|
||||
// Output text as SRT file
|
||||
func OutputSRT(w io.Writer, context whisper.Context) error {
|
||||
n := 1
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
break
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
|
||||
fmt.Fprintln(w, n)
|
||||
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
|
||||
fmt.Fprintln(w, segment.Text)
|
||||
fmt.Fprintln(w, "")
|
||||
n++
|
||||
}
|
||||
|
||||
// Return success
|
||||
return nil
|
||||
}
|
||||
|
||||
// Output text to terminal
|
||||
func Output(w io.Writer, context whisper.Context, colorize bool) error {
|
||||
for {
|
||||
segment, err := context.NextSegment()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
|
||||
if colorize {
|
||||
for _, token := range segment.Tokens {
|
||||
if !context.IsText(token) {
|
||||
continue
|
||||
}
|
||||
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
|
||||
}
|
||||
fmt.Fprint(w, "\n")
|
||||
} else {
|
||||
fmt.Fprintln(w, " ", segment.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return srtTimestamp
|
||||
func srtTimestamp(t time.Duration) string {
|
||||
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user