Inverting deps on SQL, Log and MQ plugins to make them optional dependencies of extended servers, Removing some dead code that brought in unused dependencies Filtering out some non-linux transitive deps. (#1057)

* initial Db helper split - make SQL and datastore packages optional

* abstracting log store

* break out DB, MQ and log drivers as extensions

* cleanup

* fewer deps

* fixing docker test

* hmm dbness

* updating db startup

* Consolidate all your extensions into one convenient package

* cleanup

* clean up dep constraints
This commit is contained in:
Owen Cliffe
2018-06-11 18:23:28 +01:00
committed by Reed Allman
parent 5b2691037e
commit 1ad27f4f0d
519 changed files with 907 additions and 91042 deletions

59
Gopkg.lock generated
View File

@@ -7,21 +7,6 @@
revision = "272470790ad6db791bd6f9db399b2cd2d5879f74" revision = "272470790ad6db791bd6f9db399b2cd2d5879f74"
source = "github.com/apache/thrift" source = "github.com/apache/thrift"
[[projects]]
branch = "master"
name = "github.com/Azure/go-ansiterm"
packages = [
".",
"winterm"
]
revision = "d6e3b3328b783f23731bc4d058875b0371ff8109"
[[projects]]
name = "github.com/Microsoft/go-winio"
packages = ["."]
revision = "7da180ee92d8bd8bb8c37fc560e673e6557c392f"
version = "v0.4.7"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/Nvveen/Gotty" name = "github.com/Nvveen/Gotty"
@@ -96,7 +81,7 @@
branch = "master" branch = "master"
name = "github.com/containerd/continuity" name = "github.com/containerd/continuity"
packages = ["pathdriver"] packages = ["pathdriver"]
revision = "3e8f2ea4b190484acb976a5b378d373429639a1a" revision = "d3c23511c1bf5851696cba83143d9cbcd666869b"
[[projects]] [[projects]]
name = "github.com/coreos/go-semver" name = "github.com/coreos/go-semver"
@@ -110,26 +95,6 @@
revision = "4ebf1de738443ea7f45f02dc394c4df1942a126d" revision = "4ebf1de738443ea7f45f02dc394c4df1942a126d"
version = "v1.1.0" version = "v1.1.0"
[[projects]]
name = "github.com/docker/distribution"
packages = [
".",
"digestset",
"manifest",
"manifest/schema1",
"manifest/schema2",
"reference",
"registry/api/errcode",
"registry/api/v2",
"registry/client",
"registry/client/auth",
"registry/client/auth/challenge",
"registry/client/transport",
"registry/storage/cache",
"registry/storage/cache/memory"
]
revision = "bc3c7b0525e59d3ecfab3e1568350895fd4a462f"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/docker/docker" name = "github.com/docker/docker"
@@ -174,12 +139,6 @@
revision = "0dadbb0345b35ec7ef35e228dabb8de89a65bf52" revision = "0dadbb0345b35ec7ef35e228dabb8de89a65bf52"
version = "v0.3.2" version = "v0.3.2"
[[projects]]
branch = "master"
name = "github.com/docker/libtrust"
packages = ["."]
revision = "aabc10ec26b754e797f9028f4589c5b7bd90dc20"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "github.com/fnproject/fdk-go" name = "github.com/fnproject/fdk-go"
@@ -343,18 +302,6 @@
packages = ["."] packages = ["."]
revision = "e89373fe6b4a7413d7acd6da1725b83ef713e6e4" revision = "e89373fe6b4a7413d7acd6da1725b83ef713e6e4"
[[projects]]
name = "github.com/gorilla/context"
packages = ["."]
revision = "1ea25387ff6f684839d82767c1733ff4d4d15d0a"
version = "v1.1"
[[projects]]
name = "github.com/gorilla/mux"
packages = ["."]
revision = "53c1911da2b537f792e7cafcb446b05ffe33b996"
version = "v1.6.1"
[[projects]] [[projects]]
name = "github.com/jmespath/go-jmespath" name = "github.com/jmespath/go-jmespath"
packages = ["."] packages = ["."]
@@ -376,7 +323,7 @@
".", ".",
"oid" "oid"
] ]
revision = "614cb7963ff8ee90114d039a0f92dcd6a79292f9" revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8"
[[projects]] [[projects]]
branch = "master" branch = "master"
@@ -666,6 +613,6 @@
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
inputs-digest = "5ff01d4a02d97ec5447f99d45f47e593bb94c4581f07baefad209f25d0b88785" inputs-digest = "c988af213ce5e034443bccfd948933a7f8ac9caf98a7c8a4730202dbf32372e7"
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@@ -19,7 +19,10 @@
# name = "github.com/x/y" # name = "github.com/x/y"
# version = "2.4.0" # version = "2.4.0"
ignored = ["github.com/fnproject/fn/cli"] # filter out unused transitive deps
ignored = ["github.com/fnproject/fn/cli",
"github.com/Microsoft/go-winio*",
"github.com/Azure/go-ansiterm*"]
[[constraint]] [[constraint]]
name = "github.com/boltdb/bolt" name = "github.com/boltdb/bolt"
@@ -49,9 +52,6 @@ ignored = ["github.com/fnproject/fn/cli"]
name = "github.com/sirupsen/logrus" name = "github.com/sirupsen/logrus"
revision = "89742aefa4b206dcf400792f3bd35b542998eb3b" revision = "89742aefa4b206dcf400792f3bd35b542998eb3b"
[[constraint]]
name = "github.com/docker/distribution"
revision = "bc3c7b0525e59d3ecfab3e1568350895fd4a462f"
[[constraint]] [[constraint]]
# NOTE: locked for a reason - https://github.com/go-sql-driver/mysql/issues/657 # NOTE: locked for a reason - https://github.com/go-sql-driver/mysql/issues/657

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/fnproject/fn/api/agent/drivers" "github.com/fnproject/fn/api/agent/drivers"
"github.com/fsouza/go-dockerclient"
) )
type taskDockerTest struct { type taskDockerTest struct {
@@ -170,20 +169,21 @@ func TestRunnerDockerStdin(t *testing.T) {
} }
} }
func TestRegistry(t *testing.T) { //
image := "fnproject/fn-test-utils" //func TestRegistry(t *testing.T) {
// image := "fnproject/fn-test-utils"
sizer, err := CheckRegistry(context.Background(), image, docker.AuthConfiguration{}) //
if err != nil { // sizer, err := CheckRegistry(context.Background(), image, docker.AuthConfiguration{})
t.Fatal("expected registry check not to fail, got:", err) // if err != nil {
} // t.Fatal("expected registry check not to fail, got:", err)
// }
size, err := sizer.Size() //
if err != nil { // size, err := sizer.Size()
t.Fatal("expected sizer not to fail, got:", err) // if err != nil {
} // t.Fatal("expected sizer not to fail, got:", err)
// }
if size <= 0 { //
t.Fatal("expected positive size for image that exists, got size:", size) // if size <= 0 {
} // t.Fatal("expected positive size for image that exists, got size:", size)
} // }
//}

View File

@@ -1,193 +1,12 @@
package docker package docker
import ( import (
"context"
"crypto/tls"
"io"
"io/ioutil"
"net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/docker/distribution"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/manifest/schema2"
"github.com/docker/distribution/reference"
registry "github.com/docker/distribution/registry/client"
"github.com/docker/distribution/registry/client/auth"
"github.com/docker/distribution/registry/client/auth/challenge"
"github.com/docker/distribution/registry/client/transport"
"github.com/fnproject/fn/api/agent/drivers"
docker "github.com/fsouza/go-dockerclient"
)
var (
// we need these imported so that they can be unmarshaled properly (yes, docker is mean)
_ = schema1.SchemaVersion
_ = schema2.SchemaVersion
registryTransport = &http.Transport{
Dial: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 2 * time.Minute,
}).Dial,
TLSClientConfig: &tls.Config{
ClientSessionCache: tls.NewLRUClientSessionCache(8192),
},
TLSHandshakeTimeout: 10 * time.Second,
MaxIdleConnsPerHost: 32, // TODO tune; we will likely be making lots of requests to same place
Proxy: http.ProxyFromEnvironment,
IdleConnTimeout: 90 * time.Second,
MaxIdleConns: 512,
ExpectContinueTimeout: 1 * time.Second,
}
) )
const hubURL = "https://registry.hub.docker.com" const hubURL = "https://registry.hub.docker.com"
// CheckRegistry will return a sizer, which can be used to check the size of an
// image if the returned error is nil. If the error returned is nil, then
// authentication against the given credentials was successful, if the
// configuration or image do not specify a config.ServerAddress,
// https://hub.docker.com will be tried. CheckRegistry is a package level
// method since rkt can also use docker images, we may be interested in using
// rkt w/o a docker driver configured; also, we don't have to tote around a
// driver in any tasker that may be interested in registry information (2/2
// cases thus far).
func CheckRegistry(ctx context.Context, image string, config docker.AuthConfiguration) (Sizer, error) {
regURL, repoName, tag := drivers.ParseImage(image)
repoNamed, err := reference.WithName(repoName)
if err != nil {
return nil, err
}
if regURL == "" {
// image address overrides credential address
regURL = config.ServerAddress
}
regURL, err = registryURL(regURL)
if err != nil {
return nil, err
}
cm := challenge.NewSimpleManager()
creds := newCreds(config.Username, config.Password)
tran := transport.NewTransport(registryTransport,
auth.NewAuthorizer(cm,
auth.NewTokenHandler(registryTransport,
creds,
repoNamed.Name(),
"pull",
),
auth.NewBasicHandler(creds),
),
)
tran = &retryWrap{cm, tran}
repo, err := registry.NewRepository(repoNamed, regURL, tran)
if err != nil {
return nil, err
}
manis, err := repo.Manifests(ctx)
if err != nil {
return nil, err
}
mani, err := manis.Get(context.TODO(), "", distribution.WithTag(tag))
if err != nil {
return nil, err
}
blobs := repo.Blobs(ctx)
// most registries aren't that great, and won't provide a size for the top
// level digest, so we need to sum up all the layers. let this be optional
// with the sizer, since tag is good enough to check existence / auth.
return &sizer{mani, blobs}, nil
}
type retryWrap struct {
cm challenge.Manager
tran http.RoundTripper
}
func (d *retryWrap) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := d.tran.RoundTrip(req)
// if it's not authed, we have to add this to the challenge manager,
// and then retry it (it will get authed and the challenge then accepted).
// why the docker distribution transport doesn't do this for you is
// a real testament to what sadists those docker people are.
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
pingPath := req.URL.Path
if v2Root := strings.Index(req.URL.Path, "/v2/"); v2Root != -1 {
pingPath = pingPath[:v2Root+4]
} else if v1Root := strings.Index(req.URL.Path, "/v1/"); v1Root != -1 {
pingPath = pingPath[:v1Root] + "/v2/"
}
// seriously, we have to rewrite this to the ping path,
// since looking up challenges strips to this path. YUP. GLHF.
ogURL := req.URL.Path
resp.Request.URL.Path = pingPath
d.cm.AddResponse(resp)
io.Copy(ioutil.Discard, resp.Body)
resp.Body.Close()
// put the original URL path back and try again now...
req.URL.Path = ogURL
resp, err = d.tran.RoundTrip(req)
}
return resp, err
}
func newCreds(user, pass string) *creds {
return &creds{m: make(map[string]string), user: user, pass: pass}
}
// implement auth.CredentialStore
type creds struct {
m map[string]string
user, pass string
}
func (c *creds) Basic(u *url.URL) (string, string) { return c.user, c.pass }
func (c *creds) RefreshToken(u *url.URL, service string) string { return c.m[service] }
func (c *creds) SetRefreshToken(u *url.URL, service, token string) { c.m[service] = token }
// Sizer returns size information. This interface is liable to contain more
// than a size at some point, change as needed.
type Sizer interface {
Size() (int64, error)
}
type sizer struct {
mani distribution.Manifest
blobs distribution.BlobStore
}
func (s *sizer) Size() (int64, error) {
var sum int64
for _, r := range s.mani.References() {
desc, err := s.blobs.Stat(context.TODO(), r.Digest)
if err != nil {
return 0, err
}
sum += desc.Size
}
return sum, nil
}
func registryURL(addr string) (string, error) { func registryURL(addr string) (string, error) {
if addr == "" || strings.Contains(addr, "hub.docker.com") || strings.Contains(addr, "index.docker.io") { if addr == "" || strings.Contains(addr, "hub.docker.com") || strings.Contains(addr, "index.docker.io") {
return hubURL, nil return hubURL, nil

View File

@@ -2,35 +2,49 @@ package datastore
import ( import (
"context" "context"
"fmt"
"net/url" "net/url"
"fmt"
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/datastore/internal/datastoreutil" "github.com/fnproject/fn/api/datastore/internal/datastoreutil"
"github.com/fnproject/fn/api/datastore/sql"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// New creates a DataStore from the specified URL
func New(ctx context.Context, dbURL string) (models.Datastore, error) { func New(ctx context.Context, dbURL string) (models.Datastore, error) {
return newds(ctx, dbURL) // teehee
}
func Wrap(ds models.Datastore) models.Datastore {
return datastoreutil.MetricDS(datastoreutil.NewValidator(ds))
}
func newds(ctx context.Context, dbURL string) (models.Datastore, error) {
log := common.Logger(ctx) log := common.Logger(ctx)
u, err := url.Parse(dbURL) u, err := url.Parse(dbURL)
if err != nil { if err != nil {
log.WithError(err).WithFields(logrus.Fields{"url": dbURL}).Fatal("bad DB URL") log.WithError(err).WithFields(logrus.Fields{"url": dbURL}).Fatal("bad DB URL")
} }
log.WithFields(logrus.Fields{"db": u.Scheme}).Debug("creating new datastore") log.WithFields(logrus.Fields{"db": u.Scheme}).Debug("creating new datastore")
switch u.Scheme {
case "sqlite3", "postgres", "mysql": for _, provider := range providers {
return sql.New(ctx, u) if provider.Supports(u) {
default: return provider.New(ctx, u)
return nil, fmt.Errorf("db type not supported %v", u.Scheme) }
} }
return nil, fmt.Errorf("no data store provider found for storage url %s", u)
}
func Wrap(ds models.Datastore) models.Datastore {
return datastoreutil.MetricDS(datastoreutil.NewValidator(ds))
}
// Provider is a datastore provider
type Provider interface {
fmt.Stringer
// Supports indicates if this provider can handle a given data store.
Supports(url *url.URL) bool
// New creates a new data store from the specified URL
New(ctx context.Context, url *url.URL) (models.Datastore, error)
}
var providers []Provider
// AddProvider globally registers a data store provider
func AddProvider(provider Provider) {
logrus.Infof("Adding DataStore provider %s", provider)
providers = append(providers, provider)
} }

View File

@@ -0,0 +1,43 @@
// dbhelpers wrap SQL and specific capabilities of an SQL db
package dbhelper
import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus"
"net/url"
)
var sqlHelpers []Helper
//Add registers a new SQL helper
func Add(helper Helper) {
logrus.Infof("Registering DB helper %s", helper)
sqlHelpers = append(sqlHelpers, helper)
}
//Helper provides DB-specific SQL capabilities
type Helper interface {
fmt.Stringer
//Supports indicates if this helper supports this driver name
Supports(driverName string) bool
//PreConnect calculates the connect URL for the db from a canonical URL used in Fn config
PreConnect(url *url.URL) (string, error)
//PostCreate Apply any configuration to the DB prior to use
PostCreate(db *sqlx.DB) (*sqlx.DB, error)
//CheckTableExists checks if a table exists in the DB
CheckTableExists(tx *sqlx.Tx, table string) (bool, error)
//IsDuplicateKeyError determines if an error indicates if the prior error was caused by a duplicate key insert
IsDuplicateKeyError(err error) bool
}
//GetHelper returns a helper for a specific driver
func GetHelper(driverName string) (Helper, bool) {
for _, helper := range sqlHelpers {
if helper.Supports(driverName) {
return helper, true
}
logrus.Printf("%s does not support %s", helper, driverName)
}
return nil, false
}

View File

@@ -9,9 +9,8 @@ import (
"sort" "sort"
"strings" "strings"
"github.com/go-sql-driver/mysql" "github.com/fnproject/fn/api/datastore/sql/dbhelper"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@@ -302,28 +301,29 @@ func SetVersion(ctx context.Context, tx *sqlx.Tx, version int64, dirty bool) err
} }
func Version(ctx context.Context, tx *sqlx.Tx) (version int64, dirty bool, err error) { func Version(ctx context.Context, tx *sqlx.Tx) (version int64, dirty bool, err error) {
helper, ok := dbhelper.GetHelper(tx.DriverName())
if !ok {
return 0, false, fmt.Errorf("no db helper registered for for %s", tx.DriverName())
}
tableExists, err := helper.CheckTableExists(tx, MigrationsTable)
if err != nil {
return 0, false, err
}
if !tableExists {
return NilVersion, false, nil
}
query := tx.Rebind(`SELECT version, dirty FROM ` + MigrationsTable + ` LIMIT 1`) query := tx.Rebind(`SELECT version, dirty FROM ` + MigrationsTable + ` LIMIT 1`)
err = tx.QueryRowContext(ctx, query).Scan(&version, &dirty) err = tx.QueryRowContext(ctx, query).Scan(&version, &dirty)
switch { switch {
case err == sql.ErrNoRows: case err == sql.ErrNoRows:
return NilVersion, false, nil return NilVersion, false, nil
case err != nil: case err != nil:
if e, ok := err.(*mysql.MySQLError); ok {
if e.Number == 0 {
return NilVersion, false, nil
}
}
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return NilVersion, false, nil
}
}
// sqlite3 returns 'no such table' but the error is not typed
if strings.Contains(err.Error(), "no such table") {
return NilVersion, false, nil
}
return 0, false, err return 0, false, err
default: default:

View File

@@ -6,8 +6,8 @@ import (
"fmt" "fmt"
"testing" "testing"
_ "github.com/fnproject/fn/api/datastore/sql/sqlite"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
) )
const testsqlite3 = "file::memory:?mode=memory&cache=shared" const testsqlite3 = "file::memory:?mode=memory&cache=shared"

View File

@@ -0,0 +1,61 @@
package mysql
import (
"fmt"
"github.com/fnproject/fn/api/datastore/sql/dbhelper"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"net/url"
"strings"
)
type mysqlHelper int
func (mysqlHelper) Supports(scheme string) bool {
return scheme == "mysql"
}
func (mysqlHelper) PreConnect(url *url.URL) (string, error) {
return strings.TrimPrefix(url.String(), url.Scheme+"://"), nil
}
func (mysqlHelper) PostCreate(db *sqlx.DB) (*sqlx.DB, error) {
return db, nil
}
func (mysqlHelper) CheckTableExists(tx *sqlx.Tx, table string) (bool, error) {
query := tx.Rebind(fmt.Sprintf(`SELECT count(*)
FROM information_schema.TABLES
WHERE TABLE_NAME = '%s'
`, table))
row := tx.QueryRow(query)
var count int
err := row.Scan(&count)
if err != nil {
return false, err
}
exists := count > 0
return exists, nil
}
func (mysqlHelper) String() string {
return "mysql"
}
func (mysqlHelper) IsDuplicateKeyError(err error) bool {
switch mErr := err.(type) {
case *mysql.MySQLError:
if mErr.Number == 1062 {
return true
}
}
return false
}
func init() {
dbhelper.Add(mysqlHelper(0))
}

View File

@@ -0,0 +1,63 @@
package postgres
import (
"github.com/fnproject/fn/api/datastore/sql/dbhelper"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"net/url"
)
type postgresHelper int
func (postgresHelper) Supports(scheme string) bool {
switch scheme {
case "postgres", "pgx":
return true
}
return false
}
func (postgresHelper) PreConnect(url *url.URL) (string, error) {
return url.String(), nil
}
func (postgresHelper) PostCreate(db *sqlx.DB) (*sqlx.DB, error) {
return db, nil
}
func (postgresHelper) CheckTableExists(tx *sqlx.Tx, table string) (bool, error) {
query := tx.Rebind(`SELECT count(*)
FROM information_schema.TABLES
WHERE TABLE_NAME = 'apps'
`)
row := tx.QueryRow(query)
var count int
err := row.Scan(&count)
if err != nil {
return false, err
}
exists := count > 0
return exists, nil
}
func (postgresHelper) String() string {
return "postgres"
}
func (postgresHelper) IsDuplicateKeyError(err error) bool {
switch dbErr := err.(type) {
case *pq.Error:
if dbErr.Code == "23505" {
return true
}
}
return false
}
func init() {
dbhelper.Add(postgresHelper(0))
}

View File

@@ -4,12 +4,10 @@ import (
"bytes" "bytes"
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -18,13 +16,11 @@ import (
"github.com/fnproject/fn/api/datastore/sql/migratex" "github.com/fnproject/fn/api/datastore/sql/migratex"
"github.com/fnproject/fn/api/datastore/sql/migrations" "github.com/fnproject/fn/api/datastore/sql/migrations"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/lib/pq"
_ "github.com/lib/pq" "github.com/fnproject/fn/api/datastore"
"github.com/mattn/go-sqlite3" "github.com/fnproject/fn/api/datastore/sql/dbhelper"
_ "github.com/mattn/go-sqlite3" "github.com/fnproject/fn/api/logs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@@ -103,13 +99,44 @@ var ( // compiler will yell nice things about our upbringing as a child
) )
type SQLStore struct { type SQLStore struct {
db *sqlx.DB helper dbhelper.Helper
db *sqlx.DB
} }
type sqlDsProvider int
// New will open the db specified by url, create any tables necessary // New will open the db specified by url, create any tables necessary
// and return a models.Datastore safe for concurrent usage. // and return a models.Datastore safe for concurrent usage.
func New(ctx context.Context, url *url.URL) (*SQLStore, error) { func New(ctx context.Context, u *url.URL) (*SQLStore, error) {
return newDS(ctx, url) return newDS(ctx, u)
}
func (sqlDsProvider) Supports(u *url.URL) bool {
_, ok := dbhelper.GetHelper(u.Scheme)
return ok
}
func (sqlDsProvider) New(ctx context.Context, u *url.URL) (models.Datastore, error) {
return newDS(ctx, u)
}
func (sqlDsProvider) String() string {
return "sql"
}
type sqlLogsProvider int
func (sqlLogsProvider) String() string {
return "sql"
}
func (sqlLogsProvider) Supports(u *url.URL) bool {
_, ok := dbhelper.GetHelper(u.Scheme)
return ok
}
func (sqlLogsProvider) New(ctx context.Context, u *url.URL) (models.LogStore, error) {
return newDS(ctx, u)
} }
// for test methods, return concrete type, but don't expose // for test methods, return concrete type, but don't expose
@@ -117,27 +144,19 @@ func newDS(ctx context.Context, url *url.URL) (*SQLStore, error) {
driver := url.Scheme driver := url.Scheme
log := common.Logger(ctx) log := common.Logger(ctx)
// driver must be one of these for sqlx to work, double check: helper, ok := dbhelper.GetHelper(driver)
switch driver {
case "postgres", "pgx", "mysql", "sqlite3": if !ok {
default: return nil, fmt.Errorf("DB helper '%s' is not supported", driver)
return nil, errors.New("invalid db driver, refer to the code")
} }
if driver == "sqlite3" { uri, err := helper.PreConnect(url)
// make all the dirs so we can make the file..
dir := filepath.Dir(url.Path) if err != nil {
err := os.MkdirAll(dir, 0755) return nil, fmt.Errorf("failed to initialise db helper %s : %s", driver, err)
if err != nil {
return nil, err
}
} }
uri := url.String() log.WithFields(logrus.Fields{"url": uri}).Info("Connecting to DB")
if driver != "postgres" {
// postgres seems to need this as a prefix in lib/pq, everyone else wants it stripped of scheme
uri = strings.TrimPrefix(url.String(), url.Scheme+"://")
}
sqldb, err := sql.Open(driver, uri) sqldb, err := sql.Open(driver, uri)
if err != nil { if err != nil {
@@ -158,12 +177,12 @@ func newDS(ctx context.Context, url *url.URL) (*SQLStore, error) {
db.SetMaxIdleConns(maxIdleConns) db.SetMaxIdleConns(maxIdleConns)
log.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed") log.WithFields(logrus.Fields{"max_idle_connections": maxIdleConns, "datastore": driver}).Info("datastore dialed")
switch driver { // NOTE: fixes weird sqlite3 behavior db, err = helper.PostCreate(db)
case "sqlite3": if err != nil {
db.SetMaxOpenConns(1) log.WithFields(logrus.Fields{"url": uri}).WithError(err).Error("couldn't initialize db")
return nil, err
} }
sdb := &SQLStore{db: db, helper: helper}
sdb := &SQLStore{db: db}
// NOTE: runMigrations happens before we create all the tables, so that it // NOTE: runMigrations happens before we create all the tables, so that it
// can detect whether the db did not exist and insert the latest version of // can detect whether the db did not exist and insert the latest version of
@@ -231,41 +250,11 @@ func pingWithRetry(ctx context.Context, db *sqlx.DB) (err error) {
return err return err
} }
// checkExistence checks if tables have been created yet, it is not concerned
// about the existence of the schema migration version (since migrations were
// added to existing dbs, we need to know whether the db exists without migrations
// or if it's brand new).
func checkExistence(tx *sqlx.Tx) (bool, error) {
query := tx.Rebind(`SELECT count(*)
FROM information_schema.TABLES
WHERE TABLE_NAME = 'apps'
`)
if tx.DriverName() == "sqlite3" {
// sqlite3 is special, of course
query = tx.Rebind(`SELECT count(*)
FROM sqlite_master
WHERE name = 'apps'
`)
}
row := tx.QueryRow(query)
var count int
err := row.Scan(&count)
if err != nil {
return false, err
}
exists := count > 0
return exists, nil
}
// check if the db already existed, if the db is brand new then we can skip // check if the db already existed, if the db is brand new then we can skip
// over all the migrations BUT we must be sure to set the right migration // over all the migrations BUT we must be sure to set the right migration
// number so that only current migrations are skipped, not any future ones. // number so that only current migrations are skipped, not any future ones.
func (ds *SQLStore) runMigrations(ctx context.Context, tx *sqlx.Tx, migrations []migratex.Migration) error { func (ds *SQLStore) runMigrations(ctx context.Context, tx *sqlx.Tx, migrations []migratex.Migration) error {
dbExists, err := checkExistence(tx) dbExists, err := ds.helper.CheckTableExists(tx, "apps")
if err != nil { if err != nil {
return err return err
} }
@@ -355,19 +344,8 @@ func (ds *SQLStore) InsertApp(ctx context.Context, app *models.App) (*models.App
);`) );`)
_, err := ds.db.NamedExecContext(ctx, query, app) _, err := ds.db.NamedExecContext(ctx, query, app)
if err != nil { if err != nil {
switch err := err.(type) { if ds.helper.IsDuplicateKeyError(err) {
case *mysql.MySQLError: return nil, models.ErrAppsAlreadyExists
if err.Number == 1062 {
return nil, models.ErrAppsAlreadyExists
}
case *pq.Error:
if err.Code == "23505" {
return nil, models.ErrAppsAlreadyExists
}
case sqlite3.Error:
if err.ExtendedCode == sqlite3.ErrConstraintUnique || err.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
return nil, models.ErrAppsAlreadyExists
}
} }
return nil, err return nil, err
} }
@@ -924,3 +902,8 @@ func (ds *SQLStore) GetDatabase() *sqlx.DB {
func (ds *SQLStore) Close() error { func (ds *SQLStore) Close() error {
return ds.db.Close() return ds.db.Close()
} }
func init() {
datastore.AddProvider(sqlDsProvider(0))
logs.AddProvider(sqlLogsProvider(0))
}

View File

@@ -10,6 +10,9 @@ import (
"github.com/fnproject/fn/api/datastore/internal/datastoreutil" "github.com/fnproject/fn/api/datastore/internal/datastoreutil"
"github.com/fnproject/fn/api/datastore/sql/migratex" "github.com/fnproject/fn/api/datastore/sql/migratex"
"github.com/fnproject/fn/api/datastore/sql/migrations" "github.com/fnproject/fn/api/datastore/sql/migrations"
_ "github.com/fnproject/fn/api/datastore/sql/mysql"
_ "github.com/fnproject/fn/api/datastore/sql/postgres"
_ "github.com/fnproject/fn/api/datastore/sql/sqlite"
logstoretest "github.com/fnproject/fn/api/logs/testing" logstoretest "github.com/fnproject/fn/api/logs/testing"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"

View File

@@ -0,0 +1,74 @@
package sqlite
import (
"fmt"
"github.com/fnproject/fn/api/datastore/sql/dbhelper"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"net/url"
"os"
"path/filepath"
"strings"
)
type sqliteHelper int
func (sqliteHelper) Supports(scheme string) bool {
switch scheme {
case "sqlite3", "sqlite":
return true
}
return false
}
func (sqliteHelper) PreConnect(url *url.URL) (string, error) {
// make all the dirs so we can make the file..
dir := filepath.Dir(url.Path)
err := os.MkdirAll(dir, 0755)
if err != nil {
return "", err
}
return strings.TrimPrefix(url.String(), url.Scheme+"://"), nil
}
func (sqliteHelper) PostCreate(db *sqlx.DB) (*sqlx.DB, error) {
db.SetMaxOpenConns(1)
return db, nil
}
func (sqliteHelper) CheckTableExists(tx *sqlx.Tx, table string) (bool, error) {
query := tx.Rebind(fmt.Sprintf(`SELECT count(*)
FROM sqlite_master
WHERE name = '%s'
`, table))
row := tx.QueryRow(query)
var count int
err := row.Scan(&count)
if err != nil {
return false, err
}
exists := count > 0
return exists, nil
}
func (sqliteHelper) String() string {
return "sqlite"
}
func (sqliteHelper) IsDuplicateKeyError(err error) bool {
sqliteErr, ok := err.(sqlite3.Error)
if ok {
if sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique || sqliteErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey {
return true
}
}
return false
}
func init() {
dbhelper.Add(sqliteHelper(0))
}

View File

@@ -7,14 +7,30 @@ import (
"context" "context"
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/datastore/sql"
"github.com/fnproject/fn/api/logs/metrics" "github.com/fnproject/fn/api/logs/metrics"
"github.com/fnproject/fn/api/logs/s3"
"github.com/fnproject/fn/api/logs/validator" "github.com/fnproject/fn/api/logs/validator"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Provider defines a source that can create log stores
type Provider interface {
fmt.Stringer
// Supports indicates if this provider can handle a specific URL scheme
Supports(url *url.URL) bool
//Create a new log store from the corresponding URL
New(ctx context.Context, url *url.URL) (models.LogStore, error)
}
var providers []Provider
// AddProvider globally registers a new LogStore provider
func AddProvider(pf Provider) {
logrus.Info("Adding log provider %s", pf)
providers = append(providers, pf)
}
// New Creates a new log store based on a given URL
func New(ctx context.Context, dbURL string) (models.LogStore, error) { func New(ctx context.Context, dbURL string) (models.LogStore, error) {
log := common.Logger(ctx) log := common.Logger(ctx)
u, err := url.Parse(dbURL) u, err := url.Parse(dbURL)
@@ -22,17 +38,13 @@ func New(ctx context.Context, dbURL string) (models.LogStore, error) {
log.WithError(err).WithFields(logrus.Fields{"url": dbURL}).Fatal("bad DB URL") log.WithError(err).WithFields(logrus.Fields{"url": dbURL}).Fatal("bad DB URL")
} }
log.WithFields(logrus.Fields{"db": u.Scheme}).Debug("creating log store") log.WithFields(logrus.Fields{"db": u.Scheme}).Debug("creating log store")
var ls models.LogStore
switch u.Scheme {
case "sqlite3", "postgres", "mysql":
ls, err = sql.New(ctx, u)
case "s3":
ls, err = s3.New(u)
default:
err = fmt.Errorf("db type not supported %v", u.Scheme)
}
return ls, err for _, p := range providers {
if p.Supports(u) {
return p.New(ctx, u)
}
}
return nil, fmt.Errorf("no log store provider available for url %s", dbURL)
} }
func Wrap(ls models.LogStore) models.LogStore { func Wrap(ls models.LogStore) models.LogStore {

View File

@@ -21,6 +21,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/id" "github.com/fnproject/fn/api/id"
"github.com/fnproject/fn/api/logs"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.opencensus.io/stats" "go.opencensus.io/stats"
@@ -49,6 +50,8 @@ type store struct {
bucket string bucket string
} }
type s3StoreProvider int
// decorator around the Reader interface that keeps track of the number of bytes read // decorator around the Reader interface that keeps track of the number of bytes read
// in order to avoid double buffering and track Reader size // in order to avoid double buffering and track Reader size
type countingReader struct { type countingReader struct {
@@ -80,10 +83,18 @@ func createStore(bucketName, endpoint, region, accessKeyID, secretAccessKey stri
} }
} }
func (s3StoreProvider) String() string {
return "s3"
}
func (s3StoreProvider) Supports(u *url.URL) bool {
return u.Scheme == "s3"
}
// New returns an s3 api compatible log store. // New returns an s3 api compatible log store.
// url format: s3://access_key_id:secret_access_key@host/region/bucket_name?ssl=true // url format: s3://access_key_id:secret_access_key@host/region/bucket_name?ssl=true
// Note that access_key_id and secret_access_key must be URL encoded if they contain unsafe characters! // Note that access_key_id and secret_access_key must be URL encoded if they contain unsafe characters!
func New(u *url.URL) (models.LogStore, error) { func (s3StoreProvider) New(ctx context.Context, u *url.URL) (models.LogStore, error) {
endpoint := u.Host endpoint := u.Host
var accessKeyID, secretAccessKey string var accessKeyID, secretAccessKey string
@@ -464,3 +475,7 @@ func init() {
func (s *store) Close() error { func (s *store) Close() error {
return nil return nil
} }
func init() {
logs.AddProvider(s3StoreProvider(0))
}

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"testing" "testing"
"context"
logTesting "github.com/fnproject/fn/api/logs/testing" logTesting "github.com/fnproject/fn/api/logs/testing"
) )
@@ -20,7 +21,7 @@ func TestS3(t *testing.T) {
t.Fatalf("failed to parse url: %v", err) t.Fatalf("failed to parse url: %v", err)
} }
ls, err := New(uLog) ls, err := s3StoreProvider(0).New(context.Background(), uLog)
if err != nil { if err != nil {
t.Fatalf("failed to create s3 datastore: %v", err) t.Fatalf("failed to create s3 datastore: %v", err)
} }

View File

@@ -1,42 +0,0 @@
package models
import (
"encoding/json"
strfmt "github.com/go-openapi/strfmt"
"github.com/go-openapi/validate"
)
/*Reason Machine usable reason for job being in this state.
Valid values for error status are `timeout | killed | bad_exit`.
Valid values for cancelled status are `client_request`.
For everything else, this is undefined.
swagger:model Reason
*/
type Reason string
// for schema
var reasonEnum []interface{}
func (m Reason) validateReasonEnum(path, location string, value Reason) error {
if reasonEnum == nil {
var res []Reason
if err := json.Unmarshal([]byte(`["timeout","killed","bad_exit","client_request"]`), &res); err != nil {
return err
}
for _, v := range res {
reasonEnum = append(reasonEnum, v)
}
}
err := validate.Enum(path, location, value, reasonEnum)
return err
}
// Validate validates this reason
func (m Reason) Validate(formats strfmt.Registry) error {
// value enum
return m.validateReasonEnum("", "body", m)
}

View File

@@ -1,4 +1,4 @@
package mqs package bolt
import ( import (
"context" "context"
@@ -14,9 +14,12 @@ import (
"github.com/boltdb/bolt" "github.com/boltdb/bolt"
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/fnproject/fn/api/mqs"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
type boltProvider int
type BoltDbMQ struct { type BoltDbMQ struct {
db *bolt.DB db *bolt.DB
ticker *time.Ticker ticker *time.Ticker
@@ -52,7 +55,19 @@ func timeoutName(i int) []byte {
return []byte(fmt.Sprintf("functions_%d_timeout", i)) return []byte(fmt.Sprintf("functions_%d_timeout", i))
} }
func NewBoltMQ(url *url.URL) (*BoltDbMQ, error) { func (boltProvider) Supports(url *url.URL) bool {
switch url.Scheme {
case "bolt":
return true
}
return false
}
func (boltProvider) String() string {
return "bolt"
}
func (boltProvider) New(url *url.URL) (models.MessageQueue, error) {
dir := filepath.Dir(url.Path) dir := filepath.Dir(url.Path)
log := logrus.WithFields(logrus.Fields{"mq": url.Scheme, "dir": dir}) log := logrus.WithFields(logrus.Fields{"mq": url.Scheme, "dir": dir})
err := os.MkdirAll(dir, 0755) err := os.MkdirAll(dir, 0755)
@@ -354,3 +369,7 @@ func (mq *BoltDbMQ) Close() error {
mq.ticker.Stop() mq.ticker.Stop()
return mq.db.Close() return mq.db.Close()
} }
func init() {
mqs.AddProvider(boltProvider(0))
}

View File

@@ -1,18 +1,21 @@
package mqs package memory
import ( import (
"context" "context"
"errors" "errors"
"math/rand"
"sync" "sync"
"time" "time"
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/fnproject/fn/api/mqs"
"github.com/google/btree" "github.com/google/btree"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"net/url"
) )
type memoryProvider int
type MemoryMQ struct { type MemoryMQ struct {
// WorkQueue A buffered channel that we can send work requests on. // WorkQueue A buffered channel that we can send work requests on.
PriorityQueues []chan *models.Call PriorityQueues []chan *models.Call
@@ -26,20 +29,17 @@ type MemoryMQ struct {
Mutex sync.Mutex Mutex sync.Mutex
} }
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randSeq(n int) string {
rand.Seed(time.Now().Unix())
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}
const NumPriorities = 3 const NumPriorities = 3
func NewMemoryMQ() *MemoryMQ { func (memoryProvider) Supports(url *url.URL) bool {
switch url.Scheme {
case "memory":
return true
}
return false
}
func (memoryProvider) New(url *url.URL) (models.MessageQueue, error) {
var queues []chan *models.Call var queues []chan *models.Call
for i := 0; i < NumPriorities; i++ { for i := 0; i < NumPriorities; i++ {
queues = append(queues, make(chan *models.Call, 5000)) queues = append(queues, make(chan *models.Call, 5000))
@@ -53,7 +53,11 @@ func NewMemoryMQ() *MemoryMQ {
} }
mq.start() mq.start()
logrus.Info("MemoryMQ initialized") logrus.Info("MemoryMQ initialized")
return mq return mq, nil
}
func (memoryProvider) String() string {
return "memory"
} }
func (mq *MemoryMQ) start() { func (mq *MemoryMQ) start() {
@@ -198,3 +202,7 @@ func (mq *MemoryMQ) Close() error {
mq.Ticker.Stop() mq.Ticker.Stop()
return nil return nil
} }
func init() {
mqs.AddProvider(memoryProvider(0))
}

View File

@@ -11,6 +11,22 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Provider for message queue extensions
type Provider interface {
fmt.Stringer
//Supports indicates if this provider can handle a specific URL scheme
Supports(url *url.URL) bool
//New creates a new message queue from a given URL
New(url *url.URL) (models.MessageQueue, error)
}
var mqProviders []Provider
// AddProvider registers a new global message queue provider
func AddProvider(p Provider) {
mqProviders = append(mqProviders, p)
}
// New will parse the URL and return the correct MQ implementation. // New will parse the URL and return the correct MQ implementation.
func New(mqURL string) (models.MessageQueue, error) { func New(mqURL string) (models.MessageQueue, error) {
mq, err := newmq(mqURL) mq, err := newmq(mqURL)
@@ -27,15 +43,11 @@ func newmq(mqURL string) (models.MessageQueue, error) {
logrus.WithError(err).WithFields(logrus.Fields{"url": mqURL}).Fatal("bad MQ URL") logrus.WithError(err).WithFields(logrus.Fields{"url": mqURL}).Fatal("bad MQ URL")
} }
logrus.WithFields(logrus.Fields{"mq": u.Scheme}).Debug("selecting MQ") logrus.WithFields(logrus.Fields{"mq": u.Scheme}).Debug("selecting MQ")
switch u.Scheme { for _, p := range mqProviders {
case "memory": if p.Supports(u) {
return NewMemoryMQ(), nil return p.New(u)
case "redis": }
return NewRedisMQ(u)
case "bolt":
return NewBoltMQ(u)
} }
return nil, fmt.Errorf("mq type not supported %v", u.Scheme) return nil, fmt.Errorf("mq type not supported %v", u.Scheme)
} }

View File

@@ -1,4 +1,4 @@
package mqs package redis
import ( import (
"context" "context"
@@ -11,6 +11,7 @@ import (
"github.com/fnproject/fn/api/common" "github.com/fnproject/fn/api/common"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"
"github.com/fnproject/fn/api/mqs"
"github.com/garyburd/redigo/redis" "github.com/garyburd/redigo/redis"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@@ -22,7 +23,21 @@ type RedisMQ struct {
prefix string prefix string
} }
func NewRedisMQ(url *url.URL) (*RedisMQ, error) { type redisProvider int
func (redisProvider) Supports(url *url.URL) bool {
switch url.Scheme {
case "redis":
return true
}
return false
}
func (redisProvider) String() string {
return "redis"
}
func (redisProvider) New(url *url.URL) (models.MessageQueue, error) {
pool := &redis.Pool{ pool := &redis.Pool{
MaxIdle: 512, MaxIdle: 512,
// I'm not sure if allowing the pool to block if more than 16 connections are required is a good idea. // I'm not sure if allowing the pool to block if more than 16 connections are required is a good idea.
@@ -314,3 +329,7 @@ func (mq *RedisMQ) Close() error {
mq.ticker.Stop() mq.ticker.Stop()
return mq.pool.Close() return mq.pool.Close()
} }
func init() {
mqs.AddProvider(redisProvider(0))
}

View File

@@ -0,0 +1,14 @@
// defaultexts are the extensions that are auto-loaded in to the default fnserver binary
// included here as a package to simplify inclusion in testing
package defaultexts
import (
_ "github.com/fnproject/fn/api/datastore/sql"
_ "github.com/fnproject/fn/api/datastore/sql/mysql"
_ "github.com/fnproject/fn/api/datastore/sql/postgres"
_ "github.com/fnproject/fn/api/datastore/sql/sqlite"
_ "github.com/fnproject/fn/api/logs/s3"
_ "github.com/fnproject/fn/api/mqs/bolt"
_ "github.com/fnproject/fn/api/mqs/memory"
_ "github.com/fnproject/fn/api/mqs/redis"
)

View File

@@ -16,6 +16,7 @@ import (
"github.com/fnproject/fn/api/agent" "github.com/fnproject/fn/api/agent"
"github.com/fnproject/fn/api/datastore" "github.com/fnproject/fn/api/datastore"
"github.com/fnproject/fn/api/datastore/sql" "github.com/fnproject/fn/api/datastore/sql"
_ "github.com/fnproject/fn/api/datastore/sql/sqlite"
"github.com/fnproject/fn/api/id" "github.com/fnproject/fn/api/id"
"github.com/fnproject/fn/api/logs" "github.com/fnproject/fn/api/logs"
"github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/models"

View File

@@ -5,6 +5,8 @@ import (
"github.com/fnproject/fn/api/server" "github.com/fnproject/fn/api/server"
// EXTENSIONS: Add extension imports here or use `fn build-server`. Learn more: https://github.com/fnproject/fn/blob/master/docs/operating/extending.md // EXTENSIONS: Add extension imports here or use `fn build-server`. Learn more: https://github.com/fnproject/fn/blob/master/docs/operating/extending.md
_ "github.com/fnproject/fn/api/server/defaultexts"
) )
func main() { func main() {

View File

@@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/fnproject/fn/api/server" "github.com/fnproject/fn/api/server"
_ "github.com/fnproject/fn/api/server/defaultexts"
) )
func stopServer(done chan struct{}, stop func()) { func stopServer(done chan struct{}, stop func()) {

View File

@@ -9,6 +9,7 @@ import (
"github.com/fnproject/fn/api/agent/hybrid" "github.com/fnproject/fn/api/agent/hybrid"
pool "github.com/fnproject/fn/api/runnerpool" pool "github.com/fnproject/fn/api/runnerpool"
"github.com/fnproject/fn/api/server" "github.com/fnproject/fn/api/server"
_ "github.com/fnproject/fn/api/server/defaultexts"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2015 Microsoft Corporation
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@@ -1,12 +0,0 @@
# go-ansiterm
This is a cross platform Ansi Terminal Emulation library. It reads a stream of Ansi characters and produces the appropriate function calls. The results of the function calls are platform dependent.
For example the parser might receive "ESC, [, A" as a stream of three characters. This is the code for Cursor Up (http://www.vt100.net/docs/vt510-rm/CUU). The parser then calls the cursor up function (CUU()) on an event handler. The event handler determines what platform specific work must be done to cause the cursor to move up one position.
The parser (parser.go) is a partial implementation of this state machine (http://vt100.net/emu/vt500_parser.png). There are also two event handler implementations, one for tests (test_event_handler.go) to validate that the expected events are being produced and called, the other is a Windows implementation (winterm/win_event_handler.go).
See parser_test.go for examples exercising the state machine and generating appropriate function calls.
-----
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

View File

@@ -1,188 +0,0 @@
package ansiterm
const LogEnv = "DEBUG_TERMINAL"
// ANSI constants
// References:
// -- http://www.ecma-international.org/publications/standards/Ecma-048.htm
// -- http://man7.org/linux/man-pages/man4/console_codes.4.html
// -- http://manpages.ubuntu.com/manpages/intrepid/man4/console_codes.4.html
// -- http://en.wikipedia.org/wiki/ANSI_escape_code
// -- http://vt100.net/emu/dec_ansi_parser
// -- http://vt100.net/emu/vt500_parser.svg
// -- http://invisible-island.net/xterm/ctlseqs/ctlseqs.html
// -- http://www.inwap.com/pdp10/ansicode.txt
const (
// ECMA-48 Set Graphics Rendition
// Note:
// -- Constants leading with an underscore (e.g., _ANSI_xxx) are unsupported or reserved
// -- Fonts could possibly be supported via SetCurrentConsoleFontEx
// -- Windows does not expose the per-window cursor (i.e., caret) blink times
ANSI_SGR_RESET = 0
ANSI_SGR_BOLD = 1
ANSI_SGR_DIM = 2
_ANSI_SGR_ITALIC = 3
ANSI_SGR_UNDERLINE = 4
_ANSI_SGR_BLINKSLOW = 5
_ANSI_SGR_BLINKFAST = 6
ANSI_SGR_REVERSE = 7
_ANSI_SGR_INVISIBLE = 8
_ANSI_SGR_LINETHROUGH = 9
_ANSI_SGR_FONT_00 = 10
_ANSI_SGR_FONT_01 = 11
_ANSI_SGR_FONT_02 = 12
_ANSI_SGR_FONT_03 = 13
_ANSI_SGR_FONT_04 = 14
_ANSI_SGR_FONT_05 = 15
_ANSI_SGR_FONT_06 = 16
_ANSI_SGR_FONT_07 = 17
_ANSI_SGR_FONT_08 = 18
_ANSI_SGR_FONT_09 = 19
_ANSI_SGR_FONT_10 = 20
_ANSI_SGR_DOUBLEUNDERLINE = 21
ANSI_SGR_BOLD_DIM_OFF = 22
_ANSI_SGR_ITALIC_OFF = 23
ANSI_SGR_UNDERLINE_OFF = 24
_ANSI_SGR_BLINK_OFF = 25
_ANSI_SGR_RESERVED_00 = 26
ANSI_SGR_REVERSE_OFF = 27
_ANSI_SGR_INVISIBLE_OFF = 28
_ANSI_SGR_LINETHROUGH_OFF = 29
ANSI_SGR_FOREGROUND_BLACK = 30
ANSI_SGR_FOREGROUND_RED = 31
ANSI_SGR_FOREGROUND_GREEN = 32
ANSI_SGR_FOREGROUND_YELLOW = 33
ANSI_SGR_FOREGROUND_BLUE = 34
ANSI_SGR_FOREGROUND_MAGENTA = 35
ANSI_SGR_FOREGROUND_CYAN = 36
ANSI_SGR_FOREGROUND_WHITE = 37
_ANSI_SGR_RESERVED_01 = 38
ANSI_SGR_FOREGROUND_DEFAULT = 39
ANSI_SGR_BACKGROUND_BLACK = 40
ANSI_SGR_BACKGROUND_RED = 41
ANSI_SGR_BACKGROUND_GREEN = 42
ANSI_SGR_BACKGROUND_YELLOW = 43
ANSI_SGR_BACKGROUND_BLUE = 44
ANSI_SGR_BACKGROUND_MAGENTA = 45
ANSI_SGR_BACKGROUND_CYAN = 46
ANSI_SGR_BACKGROUND_WHITE = 47
_ANSI_SGR_RESERVED_02 = 48
ANSI_SGR_BACKGROUND_DEFAULT = 49
// 50 - 65: Unsupported
ANSI_MAX_CMD_LENGTH = 4096
MAX_INPUT_EVENTS = 128
DEFAULT_WIDTH = 80
DEFAULT_HEIGHT = 24
ANSI_BEL = 0x07
ANSI_BACKSPACE = 0x08
ANSI_TAB = 0x09
ANSI_LINE_FEED = 0x0A
ANSI_VERTICAL_TAB = 0x0B
ANSI_FORM_FEED = 0x0C
ANSI_CARRIAGE_RETURN = 0x0D
ANSI_ESCAPE_PRIMARY = 0x1B
ANSI_ESCAPE_SECONDARY = 0x5B
ANSI_OSC_STRING_ENTRY = 0x5D
ANSI_COMMAND_FIRST = 0x40
ANSI_COMMAND_LAST = 0x7E
DCS_ENTRY = 0x90
CSI_ENTRY = 0x9B
OSC_STRING = 0x9D
ANSI_PARAMETER_SEP = ";"
ANSI_CMD_G0 = '('
ANSI_CMD_G1 = ')'
ANSI_CMD_G2 = '*'
ANSI_CMD_G3 = '+'
ANSI_CMD_DECPNM = '>'
ANSI_CMD_DECPAM = '='
ANSI_CMD_OSC = ']'
ANSI_CMD_STR_TERM = '\\'
KEY_CONTROL_PARAM_2 = ";2"
KEY_CONTROL_PARAM_3 = ";3"
KEY_CONTROL_PARAM_4 = ";4"
KEY_CONTROL_PARAM_5 = ";5"
KEY_CONTROL_PARAM_6 = ";6"
KEY_CONTROL_PARAM_7 = ";7"
KEY_CONTROL_PARAM_8 = ";8"
KEY_ESC_CSI = "\x1B["
KEY_ESC_N = "\x1BN"
KEY_ESC_O = "\x1BO"
FILL_CHARACTER = ' '
)
func getByteRange(start byte, end byte) []byte {
bytes := make([]byte, 0, 32)
for i := start; i <= end; i++ {
bytes = append(bytes, byte(i))
}
return bytes
}
var toGroundBytes = getToGroundBytes()
var executors = getExecuteBytes()
// SPACE 20+A0 hex Always and everywhere a blank space
// Intermediate 20-2F hex !"#$%&'()*+,-./
var intermeds = getByteRange(0x20, 0x2F)
// Parameters 30-3F hex 0123456789:;<=>?
// CSI Parameters 30-39, 3B hex 0123456789;
var csiParams = getByteRange(0x30, 0x3F)
var csiCollectables = append(getByteRange(0x30, 0x39), getByteRange(0x3B, 0x3F)...)
// Uppercase 40-5F hex @ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_
var upperCase = getByteRange(0x40, 0x5F)
// Lowercase 60-7E hex `abcdefghijlkmnopqrstuvwxyz{|}~
var lowerCase = getByteRange(0x60, 0x7E)
// Alphabetics 40-7E hex (all of upper and lower case)
var alphabetics = append(upperCase, lowerCase...)
var printables = getByteRange(0x20, 0x7F)
var escapeIntermediateToGroundBytes = getByteRange(0x30, 0x7E)
var escapeToGroundBytes = getEscapeToGroundBytes()
// See http://www.vt100.net/emu/vt500_parser.png for description of the complex
// byte ranges below
func getEscapeToGroundBytes() []byte {
escapeToGroundBytes := getByteRange(0x30, 0x4F)
escapeToGroundBytes = append(escapeToGroundBytes, getByteRange(0x51, 0x57)...)
escapeToGroundBytes = append(escapeToGroundBytes, 0x59)
escapeToGroundBytes = append(escapeToGroundBytes, 0x5A)
escapeToGroundBytes = append(escapeToGroundBytes, 0x5C)
escapeToGroundBytes = append(escapeToGroundBytes, getByteRange(0x60, 0x7E)...)
return escapeToGroundBytes
}
func getExecuteBytes() []byte {
executeBytes := getByteRange(0x00, 0x17)
executeBytes = append(executeBytes, 0x19)
executeBytes = append(executeBytes, getByteRange(0x1C, 0x1F)...)
return executeBytes
}
func getToGroundBytes() []byte {
groundBytes := []byte{0x18}
groundBytes = append(groundBytes, 0x1A)
groundBytes = append(groundBytes, getByteRange(0x80, 0x8F)...)
groundBytes = append(groundBytes, getByteRange(0x91, 0x97)...)
groundBytes = append(groundBytes, 0x99)
groundBytes = append(groundBytes, 0x9A)
groundBytes = append(groundBytes, 0x9C)
return groundBytes
}
// Delete 7F hex Always and everywhere ignored
// C1 Control 80-9F hex 32 additional control characters
// G1 Displayable A1-FE hex 94 additional displayable characters
// Special A0+FF hex Same as SPACE and DELETE

View File

@@ -1,7 +0,0 @@
package ansiterm
type ansiContext struct {
currentChar byte
paramBuffer []byte
interBuffer []byte
}

View File

@@ -1,49 +0,0 @@
package ansiterm
type csiEntryState struct {
baseState
}
func (csiState csiEntryState) Handle(b byte) (s state, e error) {
csiState.parser.logf("CsiEntry::Handle %#x", b)
nextState, err := csiState.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case sliceContains(alphabetics, b):
return csiState.parser.ground, nil
case sliceContains(csiCollectables, b):
return csiState.parser.csiParam, nil
case sliceContains(executors, b):
return csiState, csiState.parser.execute()
}
return csiState, nil
}
func (csiState csiEntryState) Transition(s state) error {
csiState.parser.logf("CsiEntry::Transition %s --> %s", csiState.Name(), s.Name())
csiState.baseState.Transition(s)
switch s {
case csiState.parser.ground:
return csiState.parser.csiDispatch()
case csiState.parser.csiParam:
switch {
case sliceContains(csiParams, csiState.parser.context.currentChar):
csiState.parser.collectParam()
case sliceContains(intermeds, csiState.parser.context.currentChar):
csiState.parser.collectInter()
}
}
return nil
}
func (csiState csiEntryState) Enter() error {
csiState.parser.clear()
return nil
}

View File

@@ -1,38 +0,0 @@
package ansiterm
type csiParamState struct {
baseState
}
func (csiState csiParamState) Handle(b byte) (s state, e error) {
csiState.parser.logf("CsiParam::Handle %#x", b)
nextState, err := csiState.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case sliceContains(alphabetics, b):
return csiState.parser.ground, nil
case sliceContains(csiCollectables, b):
csiState.parser.collectParam()
return csiState, nil
case sliceContains(executors, b):
return csiState, csiState.parser.execute()
}
return csiState, nil
}
func (csiState csiParamState) Transition(s state) error {
csiState.parser.logf("CsiParam::Transition %s --> %s", csiState.Name(), s.Name())
csiState.baseState.Transition(s)
switch s {
case csiState.parser.ground:
return csiState.parser.csiDispatch()
}
return nil
}

View File

@@ -1,36 +0,0 @@
package ansiterm
type escapeIntermediateState struct {
baseState
}
func (escState escapeIntermediateState) Handle(b byte) (s state, e error) {
escState.parser.logf("escapeIntermediateState::Handle %#x", b)
nextState, err := escState.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case sliceContains(intermeds, b):
return escState, escState.parser.collectInter()
case sliceContains(executors, b):
return escState, escState.parser.execute()
case sliceContains(escapeIntermediateToGroundBytes, b):
return escState.parser.ground, nil
}
return escState, nil
}
func (escState escapeIntermediateState) Transition(s state) error {
escState.parser.logf("escapeIntermediateState::Transition %s --> %s", escState.Name(), s.Name())
escState.baseState.Transition(s)
switch s {
case escState.parser.ground:
return escState.parser.escDispatch()
}
return nil
}

View File

@@ -1,47 +0,0 @@
package ansiterm
type escapeState struct {
baseState
}
func (escState escapeState) Handle(b byte) (s state, e error) {
escState.parser.logf("escapeState::Handle %#x", b)
nextState, err := escState.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case b == ANSI_ESCAPE_SECONDARY:
return escState.parser.csiEntry, nil
case b == ANSI_OSC_STRING_ENTRY:
return escState.parser.oscString, nil
case sliceContains(executors, b):
return escState, escState.parser.execute()
case sliceContains(escapeToGroundBytes, b):
return escState.parser.ground, nil
case sliceContains(intermeds, b):
return escState.parser.escapeIntermediate, nil
}
return escState, nil
}
func (escState escapeState) Transition(s state) error {
escState.parser.logf("Escape::Transition %s --> %s", escState.Name(), s.Name())
escState.baseState.Transition(s)
switch s {
case escState.parser.ground:
return escState.parser.escDispatch()
case escState.parser.escapeIntermediate:
return escState.parser.collectInter()
}
return nil
}
func (escState escapeState) Enter() error {
escState.parser.clear()
return nil
}

View File

@@ -1,90 +0,0 @@
package ansiterm
type AnsiEventHandler interface {
// Print
Print(b byte) error
// Execute C0 commands
Execute(b byte) error
// CUrsor Up
CUU(int) error
// CUrsor Down
CUD(int) error
// CUrsor Forward
CUF(int) error
// CUrsor Backward
CUB(int) error
// Cursor to Next Line
CNL(int) error
// Cursor to Previous Line
CPL(int) error
// Cursor Horizontal position Absolute
CHA(int) error
// Vertical line Position Absolute
VPA(int) error
// CUrsor Position
CUP(int, int) error
// Horizontal and Vertical Position (depends on PUM)
HVP(int, int) error
// Text Cursor Enable Mode
DECTCEM(bool) error
// Origin Mode
DECOM(bool) error
// 132 Column Mode
DECCOLM(bool) error
// Erase in Display
ED(int) error
// Erase in Line
EL(int) error
// Insert Line
IL(int) error
// Delete Line
DL(int) error
// Insert Character
ICH(int) error
// Delete Character
DCH(int) error
// Set Graphics Rendition
SGR([]int) error
// Pan Down
SU(int) error
// Pan Up
SD(int) error
// Device Attributes
DA([]string) error
// Set Top and Bottom Margins
DECSTBM(int, int) error
// Index
IND() error
// Reverse Index
RI() error
// Flush updates from previous commands
Flush() error
}

View File

@@ -1,24 +0,0 @@
package ansiterm
type groundState struct {
baseState
}
func (gs groundState) Handle(b byte) (s state, e error) {
gs.parser.context.currentChar = b
nextState, err := gs.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case sliceContains(printables, b):
return gs, gs.parser.print()
case sliceContains(executors, b):
return gs, gs.parser.execute()
}
return gs, nil
}

View File

@@ -1,31 +0,0 @@
package ansiterm
type oscStringState struct {
baseState
}
func (oscState oscStringState) Handle(b byte) (s state, e error) {
oscState.parser.logf("OscString::Handle %#x", b)
nextState, err := oscState.baseState.Handle(b)
if nextState != nil || err != nil {
return nextState, err
}
switch {
case isOscStringTerminator(b):
return oscState.parser.ground, nil
}
return oscState, nil
}
// See below for OSC string terminators for linux
// http://man7.org/linux/man-pages/man4/console_codes.4.html
func isOscStringTerminator(b byte) bool {
if b == ANSI_BEL || b == 0x5C {
return true
}
return false
}

View File

@@ -1,151 +0,0 @@
package ansiterm
import (
"errors"
"log"
"os"
)
type AnsiParser struct {
currState state
eventHandler AnsiEventHandler
context *ansiContext
csiEntry state
csiParam state
dcsEntry state
escape state
escapeIntermediate state
error state
ground state
oscString state
stateMap []state
logf func(string, ...interface{})
}
type Option func(*AnsiParser)
func WithLogf(f func(string, ...interface{})) Option {
return func(ap *AnsiParser) {
ap.logf = f
}
}
func CreateParser(initialState string, evtHandler AnsiEventHandler, opts ...Option) *AnsiParser {
ap := &AnsiParser{
eventHandler: evtHandler,
context: &ansiContext{},
}
for _, o := range opts {
o(ap)
}
if isDebugEnv := os.Getenv(LogEnv); isDebugEnv == "1" {
logFile, _ := os.Create("ansiParser.log")
logger := log.New(logFile, "", log.LstdFlags)
if ap.logf != nil {
l := ap.logf
ap.logf = func(s string, v ...interface{}) {
l(s, v...)
logger.Printf(s, v...)
}
} else {
ap.logf = logger.Printf
}
}
if ap.logf == nil {
ap.logf = func(string, ...interface{}) {}
}
ap.csiEntry = csiEntryState{baseState{name: "CsiEntry", parser: ap}}
ap.csiParam = csiParamState{baseState{name: "CsiParam", parser: ap}}
ap.dcsEntry = dcsEntryState{baseState{name: "DcsEntry", parser: ap}}
ap.escape = escapeState{baseState{name: "Escape", parser: ap}}
ap.escapeIntermediate = escapeIntermediateState{baseState{name: "EscapeIntermediate", parser: ap}}
ap.error = errorState{baseState{name: "Error", parser: ap}}
ap.ground = groundState{baseState{name: "Ground", parser: ap}}
ap.oscString = oscStringState{baseState{name: "OscString", parser: ap}}
ap.stateMap = []state{
ap.csiEntry,
ap.csiParam,
ap.dcsEntry,
ap.escape,
ap.escapeIntermediate,
ap.error,
ap.ground,
ap.oscString,
}
ap.currState = getState(initialState, ap.stateMap)
ap.logf("CreateParser: parser %p", ap)
return ap
}
func getState(name string, states []state) state {
for _, el := range states {
if el.Name() == name {
return el
}
}
return nil
}
func (ap *AnsiParser) Parse(bytes []byte) (int, error) {
for i, b := range bytes {
if err := ap.handle(b); err != nil {
return i, err
}
}
return len(bytes), ap.eventHandler.Flush()
}
func (ap *AnsiParser) handle(b byte) error {
ap.context.currentChar = b
newState, err := ap.currState.Handle(b)
if err != nil {
return err
}
if newState == nil {
ap.logf("WARNING: newState is nil")
return errors.New("New state of 'nil' is invalid.")
}
if newState != ap.currState {
if err := ap.changeState(newState); err != nil {
return err
}
}
return nil
}
func (ap *AnsiParser) changeState(newState state) error {
ap.logf("ChangeState %s --> %s", ap.currState.Name(), newState.Name())
// Exit old state
if err := ap.currState.Exit(); err != nil {
ap.logf("Exit state '%s' failed with : '%v'", ap.currState.Name(), err)
return err
}
// Perform transition action
if err := ap.currState.Transition(newState); err != nil {
ap.logf("Transition from '%s' to '%s' failed with: '%v'", ap.currState.Name(), newState.Name, err)
return err
}
// Enter new state
if err := newState.Enter(); err != nil {
ap.logf("Enter state '%s' failed with: '%v'", newState.Name(), err)
return err
}
ap.currState = newState
return nil
}

View File

@@ -1,99 +0,0 @@
package ansiterm
import (
"strconv"
)
func parseParams(bytes []byte) ([]string, error) {
paramBuff := make([]byte, 0, 0)
params := []string{}
for _, v := range bytes {
if v == ';' {
if len(paramBuff) > 0 {
// Completed parameter, append it to the list
s := string(paramBuff)
params = append(params, s)
paramBuff = make([]byte, 0, 0)
}
} else {
paramBuff = append(paramBuff, v)
}
}
// Last parameter may not be terminated with ';'
if len(paramBuff) > 0 {
s := string(paramBuff)
params = append(params, s)
}
return params, nil
}
func parseCmd(context ansiContext) (string, error) {
return string(context.currentChar), nil
}
func getInt(params []string, dflt int) int {
i := getInts(params, 1, dflt)[0]
return i
}
func getInts(params []string, minCount int, dflt int) []int {
ints := []int{}
for _, v := range params {
i, _ := strconv.Atoi(v)
// Zero is mapped to the default value in VT100.
if i == 0 {
i = dflt
}
ints = append(ints, i)
}
if len(ints) < minCount {
remaining := minCount - len(ints)
for i := 0; i < remaining; i++ {
ints = append(ints, dflt)
}
}
return ints
}
func (ap *AnsiParser) modeDispatch(param string, set bool) error {
switch param {
case "?3":
return ap.eventHandler.DECCOLM(set)
case "?6":
return ap.eventHandler.DECOM(set)
case "?25":
return ap.eventHandler.DECTCEM(set)
}
return nil
}
func (ap *AnsiParser) hDispatch(params []string) error {
if len(params) == 1 {
return ap.modeDispatch(params[0], true)
}
return nil
}
func (ap *AnsiParser) lDispatch(params []string) error {
if len(params) == 1 {
return ap.modeDispatch(params[0], false)
}
return nil
}
func getEraseParam(params []string) int {
param := getInt(params, 0)
if param < 0 || 3 < param {
param = 0
}
return param
}

View File

@@ -1,119 +0,0 @@
package ansiterm
func (ap *AnsiParser) collectParam() error {
currChar := ap.context.currentChar
ap.logf("collectParam %#x", currChar)
ap.context.paramBuffer = append(ap.context.paramBuffer, currChar)
return nil
}
func (ap *AnsiParser) collectInter() error {
currChar := ap.context.currentChar
ap.logf("collectInter %#x", currChar)
ap.context.paramBuffer = append(ap.context.interBuffer, currChar)
return nil
}
func (ap *AnsiParser) escDispatch() error {
cmd, _ := parseCmd(*ap.context)
intermeds := ap.context.interBuffer
ap.logf("escDispatch currentChar: %#x", ap.context.currentChar)
ap.logf("escDispatch: %v(%v)", cmd, intermeds)
switch cmd {
case "D": // IND
return ap.eventHandler.IND()
case "E": // NEL, equivalent to CRLF
err := ap.eventHandler.Execute(ANSI_CARRIAGE_RETURN)
if err == nil {
err = ap.eventHandler.Execute(ANSI_LINE_FEED)
}
return err
case "M": // RI
return ap.eventHandler.RI()
}
return nil
}
func (ap *AnsiParser) csiDispatch() error {
cmd, _ := parseCmd(*ap.context)
params, _ := parseParams(ap.context.paramBuffer)
ap.logf("Parsed params: %v with length: %d", params, len(params))
ap.logf("csiDispatch: %v(%v)", cmd, params)
switch cmd {
case "@":
return ap.eventHandler.ICH(getInt(params, 1))
case "A":
return ap.eventHandler.CUU(getInt(params, 1))
case "B":
return ap.eventHandler.CUD(getInt(params, 1))
case "C":
return ap.eventHandler.CUF(getInt(params, 1))
case "D":
return ap.eventHandler.CUB(getInt(params, 1))
case "E":
return ap.eventHandler.CNL(getInt(params, 1))
case "F":
return ap.eventHandler.CPL(getInt(params, 1))
case "G":
return ap.eventHandler.CHA(getInt(params, 1))
case "H":
ints := getInts(params, 2, 1)
x, y := ints[0], ints[1]
return ap.eventHandler.CUP(x, y)
case "J":
param := getEraseParam(params)
return ap.eventHandler.ED(param)
case "K":
param := getEraseParam(params)
return ap.eventHandler.EL(param)
case "L":
return ap.eventHandler.IL(getInt(params, 1))
case "M":
return ap.eventHandler.DL(getInt(params, 1))
case "P":
return ap.eventHandler.DCH(getInt(params, 1))
case "S":
return ap.eventHandler.SU(getInt(params, 1))
case "T":
return ap.eventHandler.SD(getInt(params, 1))
case "c":
return ap.eventHandler.DA(params)
case "d":
return ap.eventHandler.VPA(getInt(params, 1))
case "f":
ints := getInts(params, 2, 1)
x, y := ints[0], ints[1]
return ap.eventHandler.HVP(x, y)
case "h":
return ap.hDispatch(params)
case "l":
return ap.lDispatch(params)
case "m":
return ap.eventHandler.SGR(getInts(params, 1, 0))
case "r":
ints := getInts(params, 2, 1)
top, bottom := ints[0], ints[1]
return ap.eventHandler.DECSTBM(top, bottom)
default:
ap.logf("ERROR: Unsupported CSI command: '%s', with full context: %v", cmd, ap.context)
return nil
}
}
func (ap *AnsiParser) print() error {
return ap.eventHandler.Print(ap.context.currentChar)
}
func (ap *AnsiParser) clear() error {
ap.context = &ansiContext{}
return nil
}
func (ap *AnsiParser) execute() error {
return ap.eventHandler.Execute(ap.context.currentChar)
}

View File

@@ -1,141 +0,0 @@
package ansiterm
import (
"fmt"
"testing"
)
func TestStateTransitions(t *testing.T) {
stateTransitionHelper(t, "CsiEntry", "Ground", alphabetics)
stateTransitionHelper(t, "CsiEntry", "CsiParam", csiCollectables)
stateTransitionHelper(t, "Escape", "CsiEntry", []byte{ANSI_ESCAPE_SECONDARY})
stateTransitionHelper(t, "Escape", "OscString", []byte{0x5D})
stateTransitionHelper(t, "Escape", "Ground", escapeToGroundBytes)
stateTransitionHelper(t, "Escape", "EscapeIntermediate", intermeds)
stateTransitionHelper(t, "EscapeIntermediate", "EscapeIntermediate", intermeds)
stateTransitionHelper(t, "EscapeIntermediate", "EscapeIntermediate", executors)
stateTransitionHelper(t, "EscapeIntermediate", "Ground", escapeIntermediateToGroundBytes)
stateTransitionHelper(t, "OscString", "Ground", []byte{ANSI_BEL})
stateTransitionHelper(t, "OscString", "Ground", []byte{0x5C})
stateTransitionHelper(t, "Ground", "Ground", executors)
}
func TestAnyToX(t *testing.T) {
anyToXHelper(t, []byte{ANSI_ESCAPE_PRIMARY}, "Escape")
anyToXHelper(t, []byte{DCS_ENTRY}, "DcsEntry")
anyToXHelper(t, []byte{OSC_STRING}, "OscString")
anyToXHelper(t, []byte{CSI_ENTRY}, "CsiEntry")
anyToXHelper(t, toGroundBytes, "Ground")
}
func TestCollectCsiParams(t *testing.T) {
parser, _ := createTestParser("CsiEntry")
parser.Parse(csiCollectables)
buffer := parser.context.paramBuffer
bufferCount := len(buffer)
if bufferCount != len(csiCollectables) {
t.Errorf("Buffer: %v", buffer)
t.Errorf("CsiParams: %v", csiCollectables)
t.Errorf("Buffer count failure: %d != %d", bufferCount, len(csiParams))
return
}
for i, v := range csiCollectables {
if v != buffer[i] {
t.Errorf("Buffer: %v", buffer)
t.Errorf("CsiParams: %v", csiParams)
t.Errorf("Mismatch at buffer[%d] = %d", i, buffer[i])
}
}
}
func TestParseParams(t *testing.T) {
parseParamsHelper(t, []byte{}, []string{})
parseParamsHelper(t, []byte{';'}, []string{})
parseParamsHelper(t, []byte{';', ';'}, []string{})
parseParamsHelper(t, []byte{'7'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';', ';'}, []string{"7"})
parseParamsHelper(t, []byte{'7', ';', ';', '8'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', ';', '8', ';'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', ';', ';', '8', ';', ';'}, []string{"7", "8"})
parseParamsHelper(t, []byte{'7', '8'}, []string{"78"})
parseParamsHelper(t, []byte{'7', '8', ';'}, []string{"78"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', ';', '9', '0'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0', ';'}, []string{"78", "90"})
parseParamsHelper(t, []byte{'7', '8', ';', '9', '0', ';', ';'}, []string{"78", "90"})
}
func TestCursor(t *testing.T) {
cursorSingleParamHelper(t, 'A', "CUU")
cursorSingleParamHelper(t, 'B', "CUD")
cursorSingleParamHelper(t, 'C', "CUF")
cursorSingleParamHelper(t, 'D', "CUB")
cursorSingleParamHelper(t, 'E', "CNL")
cursorSingleParamHelper(t, 'F', "CPL")
cursorSingleParamHelper(t, 'G', "CHA")
cursorTwoParamHelper(t, 'H', "CUP")
cursorTwoParamHelper(t, 'f', "HVP")
funcCallParamHelper(t, []byte{'?', '2', '5', 'h'}, "CsiEntry", "Ground", []string{"DECTCEM([true])"})
funcCallParamHelper(t, []byte{'?', '2', '5', 'l'}, "CsiEntry", "Ground", []string{"DECTCEM([false])"})
}
func TestErase(t *testing.T) {
// Erase in Display
eraseHelper(t, 'J', "ED")
// Erase in Line
eraseHelper(t, 'K', "EL")
}
func TestSelectGraphicRendition(t *testing.T) {
funcCallParamHelper(t, []byte{'m'}, "CsiEntry", "Ground", []string{"SGR([0])"})
funcCallParamHelper(t, []byte{'0', 'm'}, "CsiEntry", "Ground", []string{"SGR([0])"})
funcCallParamHelper(t, []byte{'0', ';', '1', 'm'}, "CsiEntry", "Ground", []string{"SGR([0 1])"})
funcCallParamHelper(t, []byte{'0', ';', '1', ';', '2', 'm'}, "CsiEntry", "Ground", []string{"SGR([0 1 2])"})
}
func TestScroll(t *testing.T) {
scrollHelper(t, 'S', "SU")
scrollHelper(t, 'T', "SD")
}
func TestPrint(t *testing.T) {
parser, evtHandler := createTestParser("Ground")
parser.Parse(printables)
validateState(t, parser.currState, "Ground")
for i, v := range printables {
expectedCall := fmt.Sprintf("Print([%s])", string(v))
actualCall := evtHandler.FunctionCalls[i]
if actualCall != expectedCall {
t.Errorf("Actual != Expected: %v != %v at %d", actualCall, expectedCall, i)
}
}
}
func TestClear(t *testing.T) {
p, _ := createTestParser("Ground")
fillContext(p.context)
p.clear()
validateEmptyContext(t, p.context)
}
func TestClearOnStateChange(t *testing.T) {
clearOnStateChangeHelper(t, "Ground", "Escape", []byte{ANSI_ESCAPE_PRIMARY})
clearOnStateChangeHelper(t, "Ground", "CsiEntry", []byte{CSI_ENTRY})
}
func TestC0(t *testing.T) {
expectedCall := "Execute([" + string(ANSI_LINE_FEED) + "])"
c0Helper(t, []byte{ANSI_LINE_FEED}, "Ground", []string{expectedCall})
expectedCall = "Execute([" + string(ANSI_CARRIAGE_RETURN) + "])"
c0Helper(t, []byte{ANSI_CARRIAGE_RETURN}, "Ground", []string{expectedCall})
}
func TestEscDispatch(t *testing.T) {
funcCallParamHelper(t, []byte{'M'}, "Escape", "Ground", []string{"RI([])"})
}

View File

@@ -1,114 +0,0 @@
package ansiterm
import (
"fmt"
"testing"
)
func getStateNames() []string {
parser, _ := createTestParser("Ground")
stateNames := []string{}
for _, state := range parser.stateMap {
stateNames = append(stateNames, state.Name())
}
return stateNames
}
func stateTransitionHelper(t *testing.T, start string, end string, bytes []byte) {
for _, b := range bytes {
bytes := []byte{byte(b)}
parser, _ := createTestParser(start)
parser.Parse(bytes)
validateState(t, parser.currState, end)
}
}
func anyToXHelper(t *testing.T, bytes []byte, expectedState string) {
for _, s := range getStateNames() {
stateTransitionHelper(t, s, expectedState, bytes)
}
}
func funcCallParamHelper(t *testing.T, bytes []byte, start string, expected string, expectedCalls []string) {
parser, evtHandler := createTestParser(start)
parser.Parse(bytes)
validateState(t, parser.currState, expected)
validateFuncCalls(t, evtHandler.FunctionCalls, expectedCalls)
}
func parseParamsHelper(t *testing.T, bytes []byte, expectedParams []string) {
params, err := parseParams(bytes)
if err != nil {
t.Errorf("Parameter parse error: %v", err)
return
}
if len(params) != len(expectedParams) {
t.Errorf("Parsed parameters: %v", params)
t.Errorf("Expected parameters: %v", expectedParams)
t.Errorf("Parameter length failure: %d != %d", len(params), len(expectedParams))
return
}
for i, v := range expectedParams {
if v != params[i] {
t.Errorf("Parsed parameters: %v", params)
t.Errorf("Expected parameters: %v", expectedParams)
t.Errorf("Parameter parse failure: %s != %s at position %d", v, params[i], i)
}
}
}
func cursorSingleParamHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'2', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([23])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', ';', '4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
}
func cursorTwoParamHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1 1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1 1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 1])", funcName)})
funcCallParamHelper(t, []byte{'2', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([23 1])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 3])", funcName)})
funcCallParamHelper(t, []byte{'2', ';', '3', ';', '4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2 3])", funcName)})
}
func eraseHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'1', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([2])", funcName)})
funcCallParamHelper(t, []byte{'3', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([3])", funcName)})
funcCallParamHelper(t, []byte{'4', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([0])", funcName)})
funcCallParamHelper(t, []byte{'1', ';', '2', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
}
func scrollHelper(t *testing.T, command byte, funcName string) {
funcCallParamHelper(t, []byte{command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'0', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'1', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([1])", funcName)})
funcCallParamHelper(t, []byte{'5', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([5])", funcName)})
funcCallParamHelper(t, []byte{'4', ';', '6', command}, "CsiEntry", "Ground", []string{fmt.Sprintf("%s([4])", funcName)})
}
func clearOnStateChangeHelper(t *testing.T, start string, end string, bytes []byte) {
p, _ := createTestParser(start)
fillContext(p.context)
p.Parse(bytes)
validateState(t, p.currState, end)
validateEmptyContext(t, p.context)
}
func c0Helper(t *testing.T, bytes []byte, expectedState string, expectedCalls []string) {
parser, evtHandler := createTestParser("Ground")
parser.Parse(bytes)
validateState(t, parser.currState, expectedState)
validateFuncCalls(t, evtHandler.FunctionCalls, expectedCalls)
}

View File

@@ -1,66 +0,0 @@
package ansiterm
import (
"testing"
)
func createTestParser(s string) (*AnsiParser, *TestAnsiEventHandler) {
evtHandler := CreateTestAnsiEventHandler()
parser := CreateParser(s, evtHandler)
return parser, evtHandler
}
func validateState(t *testing.T, actualState state, expectedStateName string) {
actualName := "Nil"
if actualState != nil {
actualName = actualState.Name()
}
if actualName != expectedStateName {
t.Errorf("Invalid state: '%s' != '%s'", actualName, expectedStateName)
}
}
func validateFuncCalls(t *testing.T, actualCalls []string, expectedCalls []string) {
actualCount := len(actualCalls)
expectedCount := len(expectedCalls)
if actualCount != expectedCount {
t.Errorf("Actual calls: %v", actualCalls)
t.Errorf("Expected calls: %v", expectedCalls)
t.Errorf("Call count error: %d != %d", actualCount, expectedCount)
return
}
for i, v := range actualCalls {
if v != expectedCalls[i] {
t.Errorf("Actual calls: %v", actualCalls)
t.Errorf("Expected calls: %v", expectedCalls)
t.Errorf("Mismatched calls: %s != %s with lengths %d and %d", v, expectedCalls[i], len(v), len(expectedCalls[i]))
}
}
}
func fillContext(context *ansiContext) {
context.currentChar = 'A'
context.paramBuffer = []byte{'C', 'D', 'E'}
context.interBuffer = []byte{'F', 'G', 'H'}
}
func validateEmptyContext(t *testing.T, context *ansiContext) {
var expectedCurrChar byte = 0x0
if context.currentChar != expectedCurrChar {
t.Errorf("Currentchar mismatch '%#x' != '%#x'", context.currentChar, expectedCurrChar)
}
if len(context.paramBuffer) != 0 {
t.Errorf("Non-empty parameter buffer: %v", context.paramBuffer)
}
if len(context.paramBuffer) != 0 {
t.Errorf("Non-empty intermediate buffer: %v", context.interBuffer)
}
}

View File

@@ -1,71 +0,0 @@
package ansiterm
type stateID int
type state interface {
Enter() error
Exit() error
Handle(byte) (state, error)
Name() string
Transition(state) error
}
type baseState struct {
name string
parser *AnsiParser
}
func (base baseState) Enter() error {
return nil
}
func (base baseState) Exit() error {
return nil
}
func (base baseState) Handle(b byte) (s state, e error) {
switch {
case b == CSI_ENTRY:
return base.parser.csiEntry, nil
case b == DCS_ENTRY:
return base.parser.dcsEntry, nil
case b == ANSI_ESCAPE_PRIMARY:
return base.parser.escape, nil
case b == OSC_STRING:
return base.parser.oscString, nil
case sliceContains(toGroundBytes, b):
return base.parser.ground, nil
}
return nil, nil
}
func (base baseState) Name() string {
return base.name
}
func (base baseState) Transition(s state) error {
if s == base.parser.ground {
execBytes := []byte{0x18}
execBytes = append(execBytes, 0x1A)
execBytes = append(execBytes, getByteRange(0x80, 0x8F)...)
execBytes = append(execBytes, getByteRange(0x91, 0x97)...)
execBytes = append(execBytes, 0x99)
execBytes = append(execBytes, 0x9A)
if sliceContains(execBytes, base.parser.context.currentChar) {
return base.parser.execute()
}
}
return nil
}
type dcsEntryState struct {
baseState
}
type errorState struct {
baseState
}

View File

@@ -1,173 +0,0 @@
package ansiterm
import (
"fmt"
"strconv"
)
type TestAnsiEventHandler struct {
FunctionCalls []string
}
func CreateTestAnsiEventHandler() *TestAnsiEventHandler {
evtHandler := TestAnsiEventHandler{}
evtHandler.FunctionCalls = make([]string, 0)
return &evtHandler
}
func (h *TestAnsiEventHandler) recordCall(call string, params []string) {
s := fmt.Sprintf("%s(%v)", call, params)
h.FunctionCalls = append(h.FunctionCalls, s)
}
func (h *TestAnsiEventHandler) Print(b byte) error {
h.recordCall("Print", []string{string(b)})
return nil
}
func (h *TestAnsiEventHandler) Execute(b byte) error {
h.recordCall("Execute", []string{string(b)})
return nil
}
func (h *TestAnsiEventHandler) CUU(param int) error {
h.recordCall("CUU", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUD(param int) error {
h.recordCall("CUD", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUF(param int) error {
h.recordCall("CUF", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUB(param int) error {
h.recordCall("CUB", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CNL(param int) error {
h.recordCall("CNL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CPL(param int) error {
h.recordCall("CPL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CHA(param int) error {
h.recordCall("CHA", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) VPA(param int) error {
h.recordCall("VPA", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) CUP(x int, y int) error {
xS, yS := strconv.Itoa(x), strconv.Itoa(y)
h.recordCall("CUP", []string{xS, yS})
return nil
}
func (h *TestAnsiEventHandler) HVP(x int, y int) error {
xS, yS := strconv.Itoa(x), strconv.Itoa(y)
h.recordCall("HVP", []string{xS, yS})
return nil
}
func (h *TestAnsiEventHandler) DECTCEM(visible bool) error {
h.recordCall("DECTCEM", []string{strconv.FormatBool(visible)})
return nil
}
func (h *TestAnsiEventHandler) DECOM(visible bool) error {
h.recordCall("DECOM", []string{strconv.FormatBool(visible)})
return nil
}
func (h *TestAnsiEventHandler) DECCOLM(use132 bool) error {
h.recordCall("DECOLM", []string{strconv.FormatBool(use132)})
return nil
}
func (h *TestAnsiEventHandler) ED(param int) error {
h.recordCall("ED", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) EL(param int) error {
h.recordCall("EL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) IL(param int) error {
h.recordCall("IL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DL(param int) error {
h.recordCall("DL", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) ICH(param int) error {
h.recordCall("ICH", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DCH(param int) error {
h.recordCall("DCH", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) SGR(params []int) error {
strings := []string{}
for _, v := range params {
strings = append(strings, strconv.Itoa(v))
}
h.recordCall("SGR", strings)
return nil
}
func (h *TestAnsiEventHandler) SU(param int) error {
h.recordCall("SU", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) SD(param int) error {
h.recordCall("SD", []string{strconv.Itoa(param)})
return nil
}
func (h *TestAnsiEventHandler) DA(params []string) error {
h.recordCall("DA", params)
return nil
}
func (h *TestAnsiEventHandler) DECSTBM(top int, bottom int) error {
topS, bottomS := strconv.Itoa(top), strconv.Itoa(bottom)
h.recordCall("DECSTBM", []string{topS, bottomS})
return nil
}
func (h *TestAnsiEventHandler) RI() error {
h.recordCall("RI", nil)
return nil
}
func (h *TestAnsiEventHandler) IND() error {
h.recordCall("IND", nil)
return nil
}
func (h *TestAnsiEventHandler) Flush() error {
return nil
}

View File

@@ -1,21 +0,0 @@
package ansiterm
import (
"strconv"
)
func sliceContains(bytes []byte, b byte) bool {
for _, v := range bytes {
if v == b {
return true
}
}
return false
}
func convertBytesToInteger(bytes []byte) int {
s := string(bytes)
i, _ := strconv.Atoi(s)
return i
}

View File

@@ -1,182 +0,0 @@
// +build windows
package winterm
import (
"fmt"
"os"
"strconv"
"strings"
"syscall"
"github.com/Azure/go-ansiterm"
)
// Windows keyboard constants
// See https://msdn.microsoft.com/en-us/library/windows/desktop/dd375731(v=vs.85).aspx.
const (
VK_PRIOR = 0x21 // PAGE UP key
VK_NEXT = 0x22 // PAGE DOWN key
VK_END = 0x23 // END key
VK_HOME = 0x24 // HOME key
VK_LEFT = 0x25 // LEFT ARROW key
VK_UP = 0x26 // UP ARROW key
VK_RIGHT = 0x27 // RIGHT ARROW key
VK_DOWN = 0x28 // DOWN ARROW key
VK_SELECT = 0x29 // SELECT key
VK_PRINT = 0x2A // PRINT key
VK_EXECUTE = 0x2B // EXECUTE key
VK_SNAPSHOT = 0x2C // PRINT SCREEN key
VK_INSERT = 0x2D // INS key
VK_DELETE = 0x2E // DEL key
VK_HELP = 0x2F // HELP key
VK_F1 = 0x70 // F1 key
VK_F2 = 0x71 // F2 key
VK_F3 = 0x72 // F3 key
VK_F4 = 0x73 // F4 key
VK_F5 = 0x74 // F5 key
VK_F6 = 0x75 // F6 key
VK_F7 = 0x76 // F7 key
VK_F8 = 0x77 // F8 key
VK_F9 = 0x78 // F9 key
VK_F10 = 0x79 // F10 key
VK_F11 = 0x7A // F11 key
VK_F12 = 0x7B // F12 key
RIGHT_ALT_PRESSED = 0x0001
LEFT_ALT_PRESSED = 0x0002
RIGHT_CTRL_PRESSED = 0x0004
LEFT_CTRL_PRESSED = 0x0008
SHIFT_PRESSED = 0x0010
NUMLOCK_ON = 0x0020
SCROLLLOCK_ON = 0x0040
CAPSLOCK_ON = 0x0080
ENHANCED_KEY = 0x0100
)
type ansiCommand struct {
CommandBytes []byte
Command string
Parameters []string
IsSpecial bool
}
func newAnsiCommand(command []byte) *ansiCommand {
if isCharacterSelectionCmdChar(command[1]) {
// Is Character Set Selection commands
return &ansiCommand{
CommandBytes: command,
Command: string(command),
IsSpecial: true,
}
}
// last char is command character
lastCharIndex := len(command) - 1
ac := &ansiCommand{
CommandBytes: command,
Command: string(command[lastCharIndex]),
IsSpecial: false,
}
// more than a single escape
if lastCharIndex != 0 {
start := 1
// skip if double char escape sequence
if command[0] == ansiterm.ANSI_ESCAPE_PRIMARY && command[1] == ansiterm.ANSI_ESCAPE_SECONDARY {
start++
}
// convert this to GetNextParam method
ac.Parameters = strings.Split(string(command[start:lastCharIndex]), ansiterm.ANSI_PARAMETER_SEP)
}
return ac
}
func (ac *ansiCommand) paramAsSHORT(index int, defaultValue int16) int16 {
if index < 0 || index >= len(ac.Parameters) {
return defaultValue
}
param, err := strconv.ParseInt(ac.Parameters[index], 10, 16)
if err != nil {
return defaultValue
}
return int16(param)
}
func (ac *ansiCommand) String() string {
return fmt.Sprintf("0x%v \"%v\" (\"%v\")",
bytesToHex(ac.CommandBytes),
ac.Command,
strings.Join(ac.Parameters, "\",\""))
}
// isAnsiCommandChar returns true if the passed byte falls within the range of ANSI commands.
// See http://manpages.ubuntu.com/manpages/intrepid/man4/console_codes.4.html.
func isAnsiCommandChar(b byte) bool {
switch {
case ansiterm.ANSI_COMMAND_FIRST <= b && b <= ansiterm.ANSI_COMMAND_LAST && b != ansiterm.ANSI_ESCAPE_SECONDARY:
return true
case b == ansiterm.ANSI_CMD_G1 || b == ansiterm.ANSI_CMD_OSC || b == ansiterm.ANSI_CMD_DECPAM || b == ansiterm.ANSI_CMD_DECPNM:
// non-CSI escape sequence terminator
return true
case b == ansiterm.ANSI_CMD_STR_TERM || b == ansiterm.ANSI_BEL:
// String escape sequence terminator
return true
}
return false
}
func isXtermOscSequence(command []byte, current byte) bool {
return (len(command) >= 2 && command[0] == ansiterm.ANSI_ESCAPE_PRIMARY && command[1] == ansiterm.ANSI_CMD_OSC && current != ansiterm.ANSI_BEL)
}
func isCharacterSelectionCmdChar(b byte) bool {
return (b == ansiterm.ANSI_CMD_G0 || b == ansiterm.ANSI_CMD_G1 || b == ansiterm.ANSI_CMD_G2 || b == ansiterm.ANSI_CMD_G3)
}
// bytesToHex converts a slice of bytes to a human-readable string.
func bytesToHex(b []byte) string {
hex := make([]string, len(b))
for i, ch := range b {
hex[i] = fmt.Sprintf("%X", ch)
}
return strings.Join(hex, "")
}
// ensureInRange adjusts the passed value, if necessary, to ensure it is within
// the passed min / max range.
func ensureInRange(n int16, min int16, max int16) int16 {
if n < min {
return min
} else if n > max {
return max
} else {
return n
}
}
func GetStdFile(nFile int) (*os.File, uintptr) {
var file *os.File
switch nFile {
case syscall.STD_INPUT_HANDLE:
file = os.Stdin
case syscall.STD_OUTPUT_HANDLE:
file = os.Stdout
case syscall.STD_ERROR_HANDLE:
file = os.Stderr
default:
panic(fmt.Errorf("Invalid standard handle identifier: %v", nFile))
}
fd, err := syscall.GetStdHandle(nFile)
if err != nil {
panic(fmt.Errorf("Invalid standard handle identifier: %v -- %v", nFile, err))
}
return file, uintptr(fd)
}

View File

@@ -1,327 +0,0 @@
// +build windows
package winterm
import (
"fmt"
"syscall"
"unsafe"
)
//===========================================================================================================
// IMPORTANT NOTE:
//
// The methods below make extensive use of the "unsafe" package to obtain the required pointers.
// Beginning in Go 1.3, the garbage collector may release local variables (e.g., incoming arguments, stack
// variables) the pointers reference *before* the API completes.
//
// As a result, in those cases, the code must hint that the variables remain in active by invoking the
// dummy method "use" (see below). Newer versions of Go are planned to change the mechanism to no longer
// require unsafe pointers.
//
// If you add or modify methods, ENSURE protection of local variables through the "use" builtin to inform
// the garbage collector the variables remain in use if:
//
// -- The value is not a pointer (e.g., int32, struct)
// -- The value is not referenced by the method after passing the pointer to Windows
//
// See http://golang.org/doc/go1.3.
//===========================================================================================================
var (
kernel32DLL = syscall.NewLazyDLL("kernel32.dll")
getConsoleCursorInfoProc = kernel32DLL.NewProc("GetConsoleCursorInfo")
setConsoleCursorInfoProc = kernel32DLL.NewProc("SetConsoleCursorInfo")
setConsoleCursorPositionProc = kernel32DLL.NewProc("SetConsoleCursorPosition")
setConsoleModeProc = kernel32DLL.NewProc("SetConsoleMode")
getConsoleScreenBufferInfoProc = kernel32DLL.NewProc("GetConsoleScreenBufferInfo")
setConsoleScreenBufferSizeProc = kernel32DLL.NewProc("SetConsoleScreenBufferSize")
scrollConsoleScreenBufferProc = kernel32DLL.NewProc("ScrollConsoleScreenBufferA")
setConsoleTextAttributeProc = kernel32DLL.NewProc("SetConsoleTextAttribute")
setConsoleWindowInfoProc = kernel32DLL.NewProc("SetConsoleWindowInfo")
writeConsoleOutputProc = kernel32DLL.NewProc("WriteConsoleOutputW")
readConsoleInputProc = kernel32DLL.NewProc("ReadConsoleInputW")
waitForSingleObjectProc = kernel32DLL.NewProc("WaitForSingleObject")
)
// Windows Console constants
const (
// Console modes
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms686033(v=vs.85).aspx.
ENABLE_PROCESSED_INPUT = 0x0001
ENABLE_LINE_INPUT = 0x0002
ENABLE_ECHO_INPUT = 0x0004
ENABLE_WINDOW_INPUT = 0x0008
ENABLE_MOUSE_INPUT = 0x0010
ENABLE_INSERT_MODE = 0x0020
ENABLE_QUICK_EDIT_MODE = 0x0040
ENABLE_EXTENDED_FLAGS = 0x0080
ENABLE_AUTO_POSITION = 0x0100
ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200
ENABLE_PROCESSED_OUTPUT = 0x0001
ENABLE_WRAP_AT_EOL_OUTPUT = 0x0002
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004
DISABLE_NEWLINE_AUTO_RETURN = 0x0008
ENABLE_LVB_GRID_WORLDWIDE = 0x0010
// Character attributes
// Note:
// -- The attributes are combined to produce various colors (e.g., Blue + Green will create Cyan).
// Clearing all foreground or background colors results in black; setting all creates white.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms682088(v=vs.85).aspx#_win32_character_attributes.
FOREGROUND_BLUE uint16 = 0x0001
FOREGROUND_GREEN uint16 = 0x0002
FOREGROUND_RED uint16 = 0x0004
FOREGROUND_INTENSITY uint16 = 0x0008
FOREGROUND_MASK uint16 = 0x000F
BACKGROUND_BLUE uint16 = 0x0010
BACKGROUND_GREEN uint16 = 0x0020
BACKGROUND_RED uint16 = 0x0040
BACKGROUND_INTENSITY uint16 = 0x0080
BACKGROUND_MASK uint16 = 0x00F0
COMMON_LVB_MASK uint16 = 0xFF00
COMMON_LVB_REVERSE_VIDEO uint16 = 0x4000
COMMON_LVB_UNDERSCORE uint16 = 0x8000
// Input event types
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms683499(v=vs.85).aspx.
KEY_EVENT = 0x0001
MOUSE_EVENT = 0x0002
WINDOW_BUFFER_SIZE_EVENT = 0x0004
MENU_EVENT = 0x0008
FOCUS_EVENT = 0x0010
// WaitForSingleObject return codes
WAIT_ABANDONED = 0x00000080
WAIT_FAILED = 0xFFFFFFFF
WAIT_SIGNALED = 0x0000000
WAIT_TIMEOUT = 0x00000102
// WaitForSingleObject wait duration
WAIT_INFINITE = 0xFFFFFFFF
WAIT_ONE_SECOND = 1000
WAIT_HALF_SECOND = 500
WAIT_QUARTER_SECOND = 250
)
// Windows API Console types
// -- See https://msdn.microsoft.com/en-us/library/windows/desktop/ms682101(v=vs.85).aspx for Console specific types (e.g., COORD)
// -- See https://msdn.microsoft.com/en-us/library/aa296569(v=vs.60).aspx for comments on alignment
type (
CHAR_INFO struct {
UnicodeChar uint16
Attributes uint16
}
CONSOLE_CURSOR_INFO struct {
Size uint32
Visible int32
}
CONSOLE_SCREEN_BUFFER_INFO struct {
Size COORD
CursorPosition COORD
Attributes uint16
Window SMALL_RECT
MaximumWindowSize COORD
}
COORD struct {
X int16
Y int16
}
SMALL_RECT struct {
Left int16
Top int16
Right int16
Bottom int16
}
// INPUT_RECORD is a C/C++ union of which KEY_EVENT_RECORD is one case, it is also the largest
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms683499(v=vs.85).aspx.
INPUT_RECORD struct {
EventType uint16
KeyEvent KEY_EVENT_RECORD
}
KEY_EVENT_RECORD struct {
KeyDown int32
RepeatCount uint16
VirtualKeyCode uint16
VirtualScanCode uint16
UnicodeChar uint16
ControlKeyState uint32
}
WINDOW_BUFFER_SIZE struct {
Size COORD
}
)
// boolToBOOL converts a Go bool into a Windows int32.
func boolToBOOL(f bool) int32 {
if f {
return int32(1)
} else {
return int32(0)
}
}
// GetConsoleCursorInfo retrieves information about the size and visiblity of the console cursor.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms683163(v=vs.85).aspx.
func GetConsoleCursorInfo(handle uintptr, cursorInfo *CONSOLE_CURSOR_INFO) error {
r1, r2, err := getConsoleCursorInfoProc.Call(handle, uintptr(unsafe.Pointer(cursorInfo)), 0)
return checkError(r1, r2, err)
}
// SetConsoleCursorInfo sets the size and visiblity of the console cursor.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms686019(v=vs.85).aspx.
func SetConsoleCursorInfo(handle uintptr, cursorInfo *CONSOLE_CURSOR_INFO) error {
r1, r2, err := setConsoleCursorInfoProc.Call(handle, uintptr(unsafe.Pointer(cursorInfo)), 0)
return checkError(r1, r2, err)
}
// SetConsoleCursorPosition location of the console cursor.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms686025(v=vs.85).aspx.
func SetConsoleCursorPosition(handle uintptr, coord COORD) error {
r1, r2, err := setConsoleCursorPositionProc.Call(handle, coordToPointer(coord))
use(coord)
return checkError(r1, r2, err)
}
// GetConsoleMode gets the console mode for given file descriptor
// See http://msdn.microsoft.com/en-us/library/windows/desktop/ms683167(v=vs.85).aspx.
func GetConsoleMode(handle uintptr) (mode uint32, err error) {
err = syscall.GetConsoleMode(syscall.Handle(handle), &mode)
return mode, err
}
// SetConsoleMode sets the console mode for given file descriptor
// See http://msdn.microsoft.com/en-us/library/windows/desktop/ms686033(v=vs.85).aspx.
func SetConsoleMode(handle uintptr, mode uint32) error {
r1, r2, err := setConsoleModeProc.Call(handle, uintptr(mode), 0)
use(mode)
return checkError(r1, r2, err)
}
// GetConsoleScreenBufferInfo retrieves information about the specified console screen buffer.
// See http://msdn.microsoft.com/en-us/library/windows/desktop/ms683171(v=vs.85).aspx.
func GetConsoleScreenBufferInfo(handle uintptr) (*CONSOLE_SCREEN_BUFFER_INFO, error) {
info := CONSOLE_SCREEN_BUFFER_INFO{}
err := checkError(getConsoleScreenBufferInfoProc.Call(handle, uintptr(unsafe.Pointer(&info)), 0))
if err != nil {
return nil, err
}
return &info, nil
}
func ScrollConsoleScreenBuffer(handle uintptr, scrollRect SMALL_RECT, clipRect SMALL_RECT, destOrigin COORD, char CHAR_INFO) error {
r1, r2, err := scrollConsoleScreenBufferProc.Call(handle, uintptr(unsafe.Pointer(&scrollRect)), uintptr(unsafe.Pointer(&clipRect)), coordToPointer(destOrigin), uintptr(unsafe.Pointer(&char)))
use(scrollRect)
use(clipRect)
use(destOrigin)
use(char)
return checkError(r1, r2, err)
}
// SetConsoleScreenBufferSize sets the size of the console screen buffer.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms686044(v=vs.85).aspx.
func SetConsoleScreenBufferSize(handle uintptr, coord COORD) error {
r1, r2, err := setConsoleScreenBufferSizeProc.Call(handle, coordToPointer(coord))
use(coord)
return checkError(r1, r2, err)
}
// SetConsoleTextAttribute sets the attributes of characters written to the
// console screen buffer by the WriteFile or WriteConsole function.
// See http://msdn.microsoft.com/en-us/library/windows/desktop/ms686047(v=vs.85).aspx.
func SetConsoleTextAttribute(handle uintptr, attribute uint16) error {
r1, r2, err := setConsoleTextAttributeProc.Call(handle, uintptr(attribute), 0)
use(attribute)
return checkError(r1, r2, err)
}
// SetConsoleWindowInfo sets the size and position of the console screen buffer's window.
// Note that the size and location must be within and no larger than the backing console screen buffer.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms686125(v=vs.85).aspx.
func SetConsoleWindowInfo(handle uintptr, isAbsolute bool, rect SMALL_RECT) error {
r1, r2, err := setConsoleWindowInfoProc.Call(handle, uintptr(boolToBOOL(isAbsolute)), uintptr(unsafe.Pointer(&rect)))
use(isAbsolute)
use(rect)
return checkError(r1, r2, err)
}
// WriteConsoleOutput writes the CHAR_INFOs from the provided buffer to the active console buffer.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms687404(v=vs.85).aspx.
func WriteConsoleOutput(handle uintptr, buffer []CHAR_INFO, bufferSize COORD, bufferCoord COORD, writeRegion *SMALL_RECT) error {
r1, r2, err := writeConsoleOutputProc.Call(handle, uintptr(unsafe.Pointer(&buffer[0])), coordToPointer(bufferSize), coordToPointer(bufferCoord), uintptr(unsafe.Pointer(writeRegion)))
use(buffer)
use(bufferSize)
use(bufferCoord)
return checkError(r1, r2, err)
}
// ReadConsoleInput reads (and removes) data from the console input buffer.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms684961(v=vs.85).aspx.
func ReadConsoleInput(handle uintptr, buffer []INPUT_RECORD, count *uint32) error {
r1, r2, err := readConsoleInputProc.Call(handle, uintptr(unsafe.Pointer(&buffer[0])), uintptr(len(buffer)), uintptr(unsafe.Pointer(count)))
use(buffer)
return checkError(r1, r2, err)
}
// WaitForSingleObject waits for the passed handle to be signaled.
// It returns true if the handle was signaled; false otherwise.
// See https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032(v=vs.85).aspx.
func WaitForSingleObject(handle uintptr, msWait uint32) (bool, error) {
r1, _, err := waitForSingleObjectProc.Call(handle, uintptr(uint32(msWait)))
switch r1 {
case WAIT_ABANDONED, WAIT_TIMEOUT:
return false, nil
case WAIT_SIGNALED:
return true, nil
}
use(msWait)
return false, err
}
// String helpers
func (info CONSOLE_SCREEN_BUFFER_INFO) String() string {
return fmt.Sprintf("Size(%v) Cursor(%v) Window(%v) Max(%v)", info.Size, info.CursorPosition, info.Window, info.MaximumWindowSize)
}
func (coord COORD) String() string {
return fmt.Sprintf("%v,%v", coord.X, coord.Y)
}
func (rect SMALL_RECT) String() string {
return fmt.Sprintf("(%v,%v),(%v,%v)", rect.Left, rect.Top, rect.Right, rect.Bottom)
}
// checkError evaluates the results of a Windows API call and returns the error if it failed.
func checkError(r1, r2 uintptr, err error) error {
// Windows APIs return non-zero to indicate success
if r1 != 0 {
return nil
}
// Return the error if provided, otherwise default to EINVAL
if err != nil {
return err
}
return syscall.EINVAL
}
// coordToPointer converts a COORD into a uintptr (by fooling the type system).
func coordToPointer(c COORD) uintptr {
// Note: This code assumes the two SHORTs are correctly laid out; the "cast" to uint32 is just to get a pointer to pass.
return uintptr(*((*uint32)(unsafe.Pointer(&c))))
}
// use is a no-op, but the compiler cannot see that it is.
// Calling use(p) ensures that p is kept live until that point.
func use(p interface{}) {}

View File

@@ -1,100 +0,0 @@
// +build windows
package winterm
import "github.com/Azure/go-ansiterm"
const (
FOREGROUND_COLOR_MASK = FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE
BACKGROUND_COLOR_MASK = BACKGROUND_RED | BACKGROUND_GREEN | BACKGROUND_BLUE
)
// collectAnsiIntoWindowsAttributes modifies the passed Windows text mode flags to reflect the
// request represented by the passed ANSI mode.
func collectAnsiIntoWindowsAttributes(windowsMode uint16, inverted bool, baseMode uint16, ansiMode int16) (uint16, bool) {
switch ansiMode {
// Mode styles
case ansiterm.ANSI_SGR_BOLD:
windowsMode = windowsMode | FOREGROUND_INTENSITY
case ansiterm.ANSI_SGR_DIM, ansiterm.ANSI_SGR_BOLD_DIM_OFF:
windowsMode &^= FOREGROUND_INTENSITY
case ansiterm.ANSI_SGR_UNDERLINE:
windowsMode = windowsMode | COMMON_LVB_UNDERSCORE
case ansiterm.ANSI_SGR_REVERSE:
inverted = true
case ansiterm.ANSI_SGR_REVERSE_OFF:
inverted = false
case ansiterm.ANSI_SGR_UNDERLINE_OFF:
windowsMode &^= COMMON_LVB_UNDERSCORE
// Foreground colors
case ansiterm.ANSI_SGR_FOREGROUND_DEFAULT:
windowsMode = (windowsMode &^ FOREGROUND_MASK) | (baseMode & FOREGROUND_MASK)
case ansiterm.ANSI_SGR_FOREGROUND_BLACK:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK)
case ansiterm.ANSI_SGR_FOREGROUND_RED:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_RED
case ansiterm.ANSI_SGR_FOREGROUND_GREEN:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_GREEN
case ansiterm.ANSI_SGR_FOREGROUND_YELLOW:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_RED | FOREGROUND_GREEN
case ansiterm.ANSI_SGR_FOREGROUND_BLUE:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_BLUE
case ansiterm.ANSI_SGR_FOREGROUND_MAGENTA:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_RED | FOREGROUND_BLUE
case ansiterm.ANSI_SGR_FOREGROUND_CYAN:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_GREEN | FOREGROUND_BLUE
case ansiterm.ANSI_SGR_FOREGROUND_WHITE:
windowsMode = (windowsMode &^ FOREGROUND_COLOR_MASK) | FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE
// Background colors
case ansiterm.ANSI_SGR_BACKGROUND_DEFAULT:
// Black with no intensity
windowsMode = (windowsMode &^ BACKGROUND_MASK) | (baseMode & BACKGROUND_MASK)
case ansiterm.ANSI_SGR_BACKGROUND_BLACK:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK)
case ansiterm.ANSI_SGR_BACKGROUND_RED:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_RED
case ansiterm.ANSI_SGR_BACKGROUND_GREEN:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_GREEN
case ansiterm.ANSI_SGR_BACKGROUND_YELLOW:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_RED | BACKGROUND_GREEN
case ansiterm.ANSI_SGR_BACKGROUND_BLUE:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_BLUE
case ansiterm.ANSI_SGR_BACKGROUND_MAGENTA:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_RED | BACKGROUND_BLUE
case ansiterm.ANSI_SGR_BACKGROUND_CYAN:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_GREEN | BACKGROUND_BLUE
case ansiterm.ANSI_SGR_BACKGROUND_WHITE:
windowsMode = (windowsMode &^ BACKGROUND_COLOR_MASK) | BACKGROUND_RED | BACKGROUND_GREEN | BACKGROUND_BLUE
}
return windowsMode, inverted
}
// invertAttributes inverts the foreground and background colors of a Windows attributes value
func invertAttributes(windowsMode uint16) uint16 {
return (COMMON_LVB_MASK & windowsMode) | ((FOREGROUND_MASK & windowsMode) << 4) | ((BACKGROUND_MASK & windowsMode) >> 4)
}

View File

@@ -1,101 +0,0 @@
// +build windows
package winterm
const (
horizontal = iota
vertical
)
func (h *windowsAnsiEventHandler) getCursorWindow(info *CONSOLE_SCREEN_BUFFER_INFO) SMALL_RECT {
if h.originMode {
sr := h.effectiveSr(info.Window)
return SMALL_RECT{
Top: sr.top,
Bottom: sr.bottom,
Left: 0,
Right: info.Size.X - 1,
}
} else {
return SMALL_RECT{
Top: info.Window.Top,
Bottom: info.Window.Bottom,
Left: 0,
Right: info.Size.X - 1,
}
}
}
// setCursorPosition sets the cursor to the specified position, bounded to the screen size
func (h *windowsAnsiEventHandler) setCursorPosition(position COORD, window SMALL_RECT) error {
position.X = ensureInRange(position.X, window.Left, window.Right)
position.Y = ensureInRange(position.Y, window.Top, window.Bottom)
err := SetConsoleCursorPosition(h.fd, position)
if err != nil {
return err
}
h.logf("Cursor position set: (%d, %d)", position.X, position.Y)
return err
}
func (h *windowsAnsiEventHandler) moveCursorVertical(param int) error {
return h.moveCursor(vertical, param)
}
func (h *windowsAnsiEventHandler) moveCursorHorizontal(param int) error {
return h.moveCursor(horizontal, param)
}
func (h *windowsAnsiEventHandler) moveCursor(moveMode int, param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
position := info.CursorPosition
switch moveMode {
case horizontal:
position.X += int16(param)
case vertical:
position.Y += int16(param)
}
if err = h.setCursorPosition(position, h.getCursorWindow(info)); err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) moveCursorLine(param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
position := info.CursorPosition
position.X = 0
position.Y += int16(param)
if err = h.setCursorPosition(position, h.getCursorWindow(info)); err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) moveCursorColumn(param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
position := info.CursorPosition
position.X = int16(param) - 1
if err = h.setCursorPosition(position, h.getCursorWindow(info)); err != nil {
return err
}
return nil
}

View File

@@ -1,84 +0,0 @@
// +build windows
package winterm
import "github.com/Azure/go-ansiterm"
func (h *windowsAnsiEventHandler) clearRange(attributes uint16, fromCoord COORD, toCoord COORD) error {
// Ignore an invalid (negative area) request
if toCoord.Y < fromCoord.Y {
return nil
}
var err error
var coordStart = COORD{}
var coordEnd = COORD{}
xCurrent, yCurrent := fromCoord.X, fromCoord.Y
xEnd, yEnd := toCoord.X, toCoord.Y
// Clear any partial initial line
if xCurrent > 0 {
coordStart.X, coordStart.Y = xCurrent, yCurrent
coordEnd.X, coordEnd.Y = xEnd, yCurrent
err = h.clearRect(attributes, coordStart, coordEnd)
if err != nil {
return err
}
xCurrent = 0
yCurrent += 1
}
// Clear intervening rectangular section
if yCurrent < yEnd {
coordStart.X, coordStart.Y = xCurrent, yCurrent
coordEnd.X, coordEnd.Y = xEnd, yEnd-1
err = h.clearRect(attributes, coordStart, coordEnd)
if err != nil {
return err
}
xCurrent = 0
yCurrent = yEnd
}
// Clear remaining partial ending line
coordStart.X, coordStart.Y = xCurrent, yCurrent
coordEnd.X, coordEnd.Y = xEnd, yEnd
err = h.clearRect(attributes, coordStart, coordEnd)
if err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) clearRect(attributes uint16, fromCoord COORD, toCoord COORD) error {
region := SMALL_RECT{Top: fromCoord.Y, Left: fromCoord.X, Bottom: toCoord.Y, Right: toCoord.X}
width := toCoord.X - fromCoord.X + 1
height := toCoord.Y - fromCoord.Y + 1
size := uint32(width) * uint32(height)
if size <= 0 {
return nil
}
buffer := make([]CHAR_INFO, size)
char := CHAR_INFO{ansiterm.FILL_CHARACTER, attributes}
for i := 0; i < int(size); i++ {
buffer[i] = char
}
err := WriteConsoleOutput(h.fd, buffer, COORD{X: width, Y: height}, COORD{X: 0, Y: 0}, &region)
if err != nil {
return err
}
return nil
}

View File

@@ -1,118 +0,0 @@
// +build windows
package winterm
// effectiveSr gets the current effective scroll region in buffer coordinates
func (h *windowsAnsiEventHandler) effectiveSr(window SMALL_RECT) scrollRegion {
top := addInRange(window.Top, h.sr.top, window.Top, window.Bottom)
bottom := addInRange(window.Top, h.sr.bottom, window.Top, window.Bottom)
if top >= bottom {
top = window.Top
bottom = window.Bottom
}
return scrollRegion{top: top, bottom: bottom}
}
func (h *windowsAnsiEventHandler) scrollUp(param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
sr := h.effectiveSr(info.Window)
return h.scroll(param, sr, info)
}
func (h *windowsAnsiEventHandler) scrollDown(param int) error {
return h.scrollUp(-param)
}
func (h *windowsAnsiEventHandler) deleteLines(param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
start := info.CursorPosition.Y
sr := h.effectiveSr(info.Window)
// Lines cannot be inserted or deleted outside the scrolling region.
if start >= sr.top && start <= sr.bottom {
sr.top = start
return h.scroll(param, sr, info)
} else {
return nil
}
}
func (h *windowsAnsiEventHandler) insertLines(param int) error {
return h.deleteLines(-param)
}
// scroll scrolls the provided scroll region by param lines. The scroll region is in buffer coordinates.
func (h *windowsAnsiEventHandler) scroll(param int, sr scrollRegion, info *CONSOLE_SCREEN_BUFFER_INFO) error {
h.logf("scroll: scrollTop: %d, scrollBottom: %d", sr.top, sr.bottom)
h.logf("scroll: windowTop: %d, windowBottom: %d", info.Window.Top, info.Window.Bottom)
// Copy from and clip to the scroll region (full buffer width)
scrollRect := SMALL_RECT{
Top: sr.top,
Bottom: sr.bottom,
Left: 0,
Right: info.Size.X - 1,
}
// Origin to which area should be copied
destOrigin := COORD{
X: 0,
Y: sr.top - int16(param),
}
char := CHAR_INFO{
UnicodeChar: ' ',
Attributes: h.attributes,
}
if err := ScrollConsoleScreenBuffer(h.fd, scrollRect, scrollRect, destOrigin, char); err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) deleteCharacters(param int) error {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
return h.scrollLine(param, info.CursorPosition, info)
}
func (h *windowsAnsiEventHandler) insertCharacters(param int) error {
return h.deleteCharacters(-param)
}
// scrollLine scrolls a line horizontally starting at the provided position by a number of columns.
func (h *windowsAnsiEventHandler) scrollLine(columns int, position COORD, info *CONSOLE_SCREEN_BUFFER_INFO) error {
// Copy from and clip to the scroll region (full buffer width)
scrollRect := SMALL_RECT{
Top: position.Y,
Bottom: position.Y,
Left: position.X,
Right: info.Size.X - 1,
}
// Origin to which area should be copied
destOrigin := COORD{
X: position.X - int16(columns),
Y: position.Y,
}
char := CHAR_INFO{
UnicodeChar: ' ',
Attributes: h.attributes,
}
if err := ScrollConsoleScreenBuffer(h.fd, scrollRect, scrollRect, destOrigin, char); err != nil {
return err
}
return nil
}

View File

@@ -1,9 +0,0 @@
// +build windows
package winterm
// AddInRange increments a value by the passed quantity while ensuring the values
// always remain within the supplied min / max range.
func addInRange(n int16, increment int16, min int16, max int16) int16 {
return ensureInRange(n+increment, min, max)
}

View File

@@ -1,743 +0,0 @@
// +build windows
package winterm
import (
"bytes"
"log"
"os"
"strconv"
"github.com/Azure/go-ansiterm"
)
type windowsAnsiEventHandler struct {
fd uintptr
file *os.File
infoReset *CONSOLE_SCREEN_BUFFER_INFO
sr scrollRegion
buffer bytes.Buffer
attributes uint16
inverted bool
wrapNext bool
drewMarginByte bool
originMode bool
marginByte byte
curInfo *CONSOLE_SCREEN_BUFFER_INFO
curPos COORD
logf func(string, ...interface{})
}
type Option func(*windowsAnsiEventHandler)
func WithLogf(f func(string, ...interface{})) Option {
return func(w *windowsAnsiEventHandler) {
w.logf = f
}
}
func CreateWinEventHandler(fd uintptr, file *os.File, opts ...Option) ansiterm.AnsiEventHandler {
infoReset, err := GetConsoleScreenBufferInfo(fd)
if err != nil {
return nil
}
h := &windowsAnsiEventHandler{
fd: fd,
file: file,
infoReset: infoReset,
attributes: infoReset.Attributes,
}
for _, o := range opts {
o(h)
}
if isDebugEnv := os.Getenv(ansiterm.LogEnv); isDebugEnv == "1" {
logFile, _ := os.Create("winEventHandler.log")
logger := log.New(logFile, "", log.LstdFlags)
if h.logf != nil {
l := h.logf
h.logf = func(s string, v ...interface{}) {
l(s, v...)
logger.Printf(s, v...)
}
} else {
h.logf = logger.Printf
}
}
if h.logf == nil {
h.logf = func(string, ...interface{}) {}
}
return h
}
type scrollRegion struct {
top int16
bottom int16
}
// simulateLF simulates a LF or CR+LF by scrolling if necessary to handle the
// current cursor position and scroll region settings, in which case it returns
// true. If no special handling is necessary, then it does nothing and returns
// false.
//
// In the false case, the caller should ensure that a carriage return
// and line feed are inserted or that the text is otherwise wrapped.
func (h *windowsAnsiEventHandler) simulateLF(includeCR bool) (bool, error) {
if h.wrapNext {
if err := h.Flush(); err != nil {
return false, err
}
h.clearWrap()
}
pos, info, err := h.getCurrentInfo()
if err != nil {
return false, err
}
sr := h.effectiveSr(info.Window)
if pos.Y == sr.bottom {
// Scrolling is necessary. Let Windows automatically scroll if the scrolling region
// is the full window.
if sr.top == info.Window.Top && sr.bottom == info.Window.Bottom {
if includeCR {
pos.X = 0
h.updatePos(pos)
}
return false, nil
}
// A custom scroll region is active. Scroll the window manually to simulate
// the LF.
if err := h.Flush(); err != nil {
return false, err
}
h.logf("Simulating LF inside scroll region")
if err := h.scrollUp(1); err != nil {
return false, err
}
if includeCR {
pos.X = 0
if err := SetConsoleCursorPosition(h.fd, pos); err != nil {
return false, err
}
}
return true, nil
} else if pos.Y < info.Window.Bottom {
// Let Windows handle the LF.
pos.Y++
if includeCR {
pos.X = 0
}
h.updatePos(pos)
return false, nil
} else {
// The cursor is at the bottom of the screen but outside the scroll
// region. Skip the LF.
h.logf("Simulating LF outside scroll region")
if includeCR {
if err := h.Flush(); err != nil {
return false, err
}
pos.X = 0
if err := SetConsoleCursorPosition(h.fd, pos); err != nil {
return false, err
}
}
return true, nil
}
}
// executeLF executes a LF without a CR.
func (h *windowsAnsiEventHandler) executeLF() error {
handled, err := h.simulateLF(false)
if err != nil {
return err
}
if !handled {
// Windows LF will reset the cursor column position. Write the LF
// and restore the cursor position.
pos, _, err := h.getCurrentInfo()
if err != nil {
return err
}
h.buffer.WriteByte(ansiterm.ANSI_LINE_FEED)
if pos.X != 0 {
if err := h.Flush(); err != nil {
return err
}
h.logf("Resetting cursor position for LF without CR")
if err := SetConsoleCursorPosition(h.fd, pos); err != nil {
return err
}
}
}
return nil
}
func (h *windowsAnsiEventHandler) Print(b byte) error {
if h.wrapNext {
h.buffer.WriteByte(h.marginByte)
h.clearWrap()
if _, err := h.simulateLF(true); err != nil {
return err
}
}
pos, info, err := h.getCurrentInfo()
if err != nil {
return err
}
if pos.X == info.Size.X-1 {
h.wrapNext = true
h.marginByte = b
} else {
pos.X++
h.updatePos(pos)
h.buffer.WriteByte(b)
}
return nil
}
func (h *windowsAnsiEventHandler) Execute(b byte) error {
switch b {
case ansiterm.ANSI_TAB:
h.logf("Execute(TAB)")
// Move to the next tab stop, but preserve auto-wrap if already set.
if !h.wrapNext {
pos, info, err := h.getCurrentInfo()
if err != nil {
return err
}
pos.X = (pos.X + 8) - pos.X%8
if pos.X >= info.Size.X {
pos.X = info.Size.X - 1
}
if err := h.Flush(); err != nil {
return err
}
if err := SetConsoleCursorPosition(h.fd, pos); err != nil {
return err
}
}
return nil
case ansiterm.ANSI_BEL:
h.buffer.WriteByte(ansiterm.ANSI_BEL)
return nil
case ansiterm.ANSI_BACKSPACE:
if h.wrapNext {
if err := h.Flush(); err != nil {
return err
}
h.clearWrap()
}
pos, _, err := h.getCurrentInfo()
if err != nil {
return err
}
if pos.X > 0 {
pos.X--
h.updatePos(pos)
h.buffer.WriteByte(ansiterm.ANSI_BACKSPACE)
}
return nil
case ansiterm.ANSI_VERTICAL_TAB, ansiterm.ANSI_FORM_FEED:
// Treat as true LF.
return h.executeLF()
case ansiterm.ANSI_LINE_FEED:
// Simulate a CR and LF for now since there is no way in go-ansiterm
// to tell if the LF should include CR (and more things break when it's
// missing than when it's incorrectly added).
handled, err := h.simulateLF(true)
if handled || err != nil {
return err
}
return h.buffer.WriteByte(ansiterm.ANSI_LINE_FEED)
case ansiterm.ANSI_CARRIAGE_RETURN:
if h.wrapNext {
if err := h.Flush(); err != nil {
return err
}
h.clearWrap()
}
pos, _, err := h.getCurrentInfo()
if err != nil {
return err
}
if pos.X != 0 {
pos.X = 0
h.updatePos(pos)
h.buffer.WriteByte(ansiterm.ANSI_CARRIAGE_RETURN)
}
return nil
default:
return nil
}
}
func (h *windowsAnsiEventHandler) CUU(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CUU: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorVertical(-param)
}
func (h *windowsAnsiEventHandler) CUD(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CUD: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorVertical(param)
}
func (h *windowsAnsiEventHandler) CUF(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CUF: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorHorizontal(param)
}
func (h *windowsAnsiEventHandler) CUB(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CUB: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorHorizontal(-param)
}
func (h *windowsAnsiEventHandler) CNL(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CNL: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorLine(param)
}
func (h *windowsAnsiEventHandler) CPL(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CPL: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorLine(-param)
}
func (h *windowsAnsiEventHandler) CHA(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CHA: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.moveCursorColumn(param)
}
func (h *windowsAnsiEventHandler) VPA(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("VPA: [[%d]]", param)
h.clearWrap()
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
window := h.getCursorWindow(info)
position := info.CursorPosition
position.Y = window.Top + int16(param) - 1
return h.setCursorPosition(position, window)
}
func (h *windowsAnsiEventHandler) CUP(row int, col int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("CUP: [[%d %d]]", row, col)
h.clearWrap()
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
window := h.getCursorWindow(info)
position := COORD{window.Left + int16(col) - 1, window.Top + int16(row) - 1}
return h.setCursorPosition(position, window)
}
func (h *windowsAnsiEventHandler) HVP(row int, col int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("HVP: [[%d %d]]", row, col)
h.clearWrap()
return h.CUP(row, col)
}
func (h *windowsAnsiEventHandler) DECTCEM(visible bool) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DECTCEM: [%v]", []string{strconv.FormatBool(visible)})
h.clearWrap()
return nil
}
func (h *windowsAnsiEventHandler) DECOM(enable bool) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DECOM: [%v]", []string{strconv.FormatBool(enable)})
h.clearWrap()
h.originMode = enable
return h.CUP(1, 1)
}
func (h *windowsAnsiEventHandler) DECCOLM(use132 bool) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DECCOLM: [%v]", []string{strconv.FormatBool(use132)})
h.clearWrap()
if err := h.ED(2); err != nil {
return err
}
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
targetWidth := int16(80)
if use132 {
targetWidth = 132
}
if info.Size.X < targetWidth {
if err := SetConsoleScreenBufferSize(h.fd, COORD{targetWidth, info.Size.Y}); err != nil {
h.logf("set buffer failed: %v", err)
return err
}
}
window := info.Window
window.Left = 0
window.Right = targetWidth - 1
if err := SetConsoleWindowInfo(h.fd, true, window); err != nil {
h.logf("set window failed: %v", err)
return err
}
if info.Size.X > targetWidth {
if err := SetConsoleScreenBufferSize(h.fd, COORD{targetWidth, info.Size.Y}); err != nil {
h.logf("set buffer failed: %v", err)
return err
}
}
return SetConsoleCursorPosition(h.fd, COORD{0, 0})
}
func (h *windowsAnsiEventHandler) ED(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("ED: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
// [J -- Erases from the cursor to the end of the screen, including the cursor position.
// [1J -- Erases from the beginning of the screen to the cursor, including the cursor position.
// [2J -- Erases the complete display. The cursor does not move.
// Notes:
// -- Clearing the entire buffer, versus just the Window, works best for Windows Consoles
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
var start COORD
var end COORD
switch param {
case 0:
start = info.CursorPosition
end = COORD{info.Size.X - 1, info.Size.Y - 1}
case 1:
start = COORD{0, 0}
end = info.CursorPosition
case 2:
start = COORD{0, 0}
end = COORD{info.Size.X - 1, info.Size.Y - 1}
}
err = h.clearRange(h.attributes, start, end)
if err != nil {
return err
}
// If the whole buffer was cleared, move the window to the top while preserving
// the window-relative cursor position.
if param == 2 {
pos := info.CursorPosition
window := info.Window
pos.Y -= window.Top
window.Bottom -= window.Top
window.Top = 0
if err := SetConsoleCursorPosition(h.fd, pos); err != nil {
return err
}
if err := SetConsoleWindowInfo(h.fd, true, window); err != nil {
return err
}
}
return nil
}
func (h *windowsAnsiEventHandler) EL(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("EL: [%v]", strconv.Itoa(param))
h.clearWrap()
// [K -- Erases from the cursor to the end of the line, including the cursor position.
// [1K -- Erases from the beginning of the line to the cursor, including the cursor position.
// [2K -- Erases the complete line.
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
var start COORD
var end COORD
switch param {
case 0:
start = info.CursorPosition
end = COORD{info.Size.X, info.CursorPosition.Y}
case 1:
start = COORD{0, info.CursorPosition.Y}
end = info.CursorPosition
case 2:
start = COORD{0, info.CursorPosition.Y}
end = COORD{info.Size.X, info.CursorPosition.Y}
}
err = h.clearRange(h.attributes, start, end)
if err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) IL(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("IL: [%v]", strconv.Itoa(param))
h.clearWrap()
return h.insertLines(param)
}
func (h *windowsAnsiEventHandler) DL(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DL: [%v]", strconv.Itoa(param))
h.clearWrap()
return h.deleteLines(param)
}
func (h *windowsAnsiEventHandler) ICH(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("ICH: [%v]", strconv.Itoa(param))
h.clearWrap()
return h.insertCharacters(param)
}
func (h *windowsAnsiEventHandler) DCH(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DCH: [%v]", strconv.Itoa(param))
h.clearWrap()
return h.deleteCharacters(param)
}
func (h *windowsAnsiEventHandler) SGR(params []int) error {
if err := h.Flush(); err != nil {
return err
}
strings := []string{}
for _, v := range params {
strings = append(strings, strconv.Itoa(v))
}
h.logf("SGR: [%v]", strings)
if len(params) <= 0 {
h.attributes = h.infoReset.Attributes
h.inverted = false
} else {
for _, attr := range params {
if attr == ansiterm.ANSI_SGR_RESET {
h.attributes = h.infoReset.Attributes
h.inverted = false
continue
}
h.attributes, h.inverted = collectAnsiIntoWindowsAttributes(h.attributes, h.inverted, h.infoReset.Attributes, int16(attr))
}
}
attributes := h.attributes
if h.inverted {
attributes = invertAttributes(attributes)
}
err := SetConsoleTextAttribute(h.fd, attributes)
if err != nil {
return err
}
return nil
}
func (h *windowsAnsiEventHandler) SU(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("SU: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.scrollUp(param)
}
func (h *windowsAnsiEventHandler) SD(param int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("SD: [%v]", []string{strconv.Itoa(param)})
h.clearWrap()
return h.scrollDown(param)
}
func (h *windowsAnsiEventHandler) DA(params []string) error {
h.logf("DA: [%v]", params)
// DA cannot be implemented because it must send data on the VT100 input stream,
// which is not available to go-ansiterm.
return nil
}
func (h *windowsAnsiEventHandler) DECSTBM(top int, bottom int) error {
if err := h.Flush(); err != nil {
return err
}
h.logf("DECSTBM: [%d, %d]", top, bottom)
// Windows is 0 indexed, Linux is 1 indexed
h.sr.top = int16(top - 1)
h.sr.bottom = int16(bottom - 1)
// This command also moves the cursor to the origin.
h.clearWrap()
return h.CUP(1, 1)
}
func (h *windowsAnsiEventHandler) RI() error {
if err := h.Flush(); err != nil {
return err
}
h.logf("RI: []")
h.clearWrap()
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
sr := h.effectiveSr(info.Window)
if info.CursorPosition.Y == sr.top {
return h.scrollDown(1)
}
return h.moveCursorVertical(-1)
}
func (h *windowsAnsiEventHandler) IND() error {
h.logf("IND: []")
return h.executeLF()
}
func (h *windowsAnsiEventHandler) Flush() error {
h.curInfo = nil
if h.buffer.Len() > 0 {
h.logf("Flush: [%s]", h.buffer.Bytes())
if _, err := h.buffer.WriteTo(h.file); err != nil {
return err
}
}
if h.wrapNext && !h.drewMarginByte {
h.logf("Flush: drawing margin byte '%c'", h.marginByte)
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return err
}
charInfo := []CHAR_INFO{{UnicodeChar: uint16(h.marginByte), Attributes: info.Attributes}}
size := COORD{1, 1}
position := COORD{0, 0}
region := SMALL_RECT{Left: info.CursorPosition.X, Top: info.CursorPosition.Y, Right: info.CursorPosition.X, Bottom: info.CursorPosition.Y}
if err := WriteConsoleOutput(h.fd, charInfo, size, position, &region); err != nil {
return err
}
h.drewMarginByte = true
}
return nil
}
// cacheConsoleInfo ensures that the current console screen information has been queried
// since the last call to Flush(). It must be called before accessing h.curInfo or h.curPos.
func (h *windowsAnsiEventHandler) getCurrentInfo() (COORD, *CONSOLE_SCREEN_BUFFER_INFO, error) {
if h.curInfo == nil {
info, err := GetConsoleScreenBufferInfo(h.fd)
if err != nil {
return COORD{}, nil, err
}
h.curInfo = info
h.curPos = info.CursorPosition
}
return h.curPos, h.curInfo, nil
}
func (h *windowsAnsiEventHandler) updatePos(pos COORD) {
if h.curInfo == nil {
panic("failed to call getCurrentInfo before calling updatePos")
}
h.curPos = pos
}
// clearWrap clears the state where the cursor is in the margin
// waiting for the next character before wrapping the line. This must
// be done before most operations that act on the cursor.
func (h *windowsAnsiEventHandler) clearWrap() {
h.wrapNext = false
h.drewMarginByte = false
}

View File

@@ -1 +0,0 @@
*.exe

View File

@@ -1,22 +0,0 @@
The MIT License (MIT)
Copyright (c) 2015 Microsoft
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,22 +0,0 @@
# go-winio
This repository contains utilities for efficiently performing Win32 IO operations in
Go. Currently, this is focused on accessing named pipes and other file handles, and
for using named pipes as a net transport.
This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
newer operating systems. This is similar to the implementation of network sockets in Go's net
package.
Please see the LICENSE file for licensing information.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
see the [Code of Conduct
FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional
questions or comments.
Thanks to natefinch for the inspiration for this library. See https://github.com/natefinch/npipe
for another named pipe implementation.

View File

@@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,344 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tar implements access to tar archives.
// It aims to cover most of the variations, including those produced
// by GNU and BSD tars.
//
// References:
// http://www.freebsd.org/cgi/man.cgi?query=tar&sektion=5
// http://www.gnu.org/software/tar/manual/html_node/Standard.html
// http://pubs.opengroup.org/onlinepubs/9699919799/utilities/pax.html
package tar
import (
"bytes"
"errors"
"fmt"
"os"
"path"
"time"
)
const (
blockSize = 512
// Types
TypeReg = '0' // regular file
TypeRegA = '\x00' // regular file
TypeLink = '1' // hard link
TypeSymlink = '2' // symbolic link
TypeChar = '3' // character device node
TypeBlock = '4' // block device node
TypeDir = '5' // directory
TypeFifo = '6' // fifo node
TypeCont = '7' // reserved
TypeXHeader = 'x' // extended header
TypeXGlobalHeader = 'g' // global extended header
TypeGNULongName = 'L' // Next file has a long name
TypeGNULongLink = 'K' // Next file symlinks to a file w/ a long name
TypeGNUSparse = 'S' // sparse file
)
// A Header represents a single header in a tar archive.
// Some fields may not be populated.
type Header struct {
Name string // name of header file entry
Mode int64 // permission and mode bits
Uid int // user id of owner
Gid int // group id of owner
Size int64 // length in bytes
ModTime time.Time // modified time
Typeflag byte // type of header entry
Linkname string // target name of link
Uname string // user name of owner
Gname string // group name of owner
Devmajor int64 // major number of character or block device
Devminor int64 // minor number of character or block device
AccessTime time.Time // access time
ChangeTime time.Time // status change time
CreationTime time.Time // creation time
Xattrs map[string]string
Winheaders map[string]string
}
// File name constants from the tar spec.
const (
fileNameSize = 100 // Maximum number of bytes in a standard tar name.
fileNamePrefixSize = 155 // Maximum number of ustar extension bytes.
)
// FileInfo returns an os.FileInfo for the Header.
func (h *Header) FileInfo() os.FileInfo {
return headerFileInfo{h}
}
// headerFileInfo implements os.FileInfo.
type headerFileInfo struct {
h *Header
}
func (fi headerFileInfo) Size() int64 { return fi.h.Size }
func (fi headerFileInfo) IsDir() bool { return fi.Mode().IsDir() }
func (fi headerFileInfo) ModTime() time.Time { return fi.h.ModTime }
func (fi headerFileInfo) Sys() interface{} { return fi.h }
// Name returns the base name of the file.
func (fi headerFileInfo) Name() string {
if fi.IsDir() {
return path.Base(path.Clean(fi.h.Name))
}
return path.Base(fi.h.Name)
}
// Mode returns the permission and mode bits for the headerFileInfo.
func (fi headerFileInfo) Mode() (mode os.FileMode) {
// Set file permission bits.
mode = os.FileMode(fi.h.Mode).Perm()
// Set setuid, setgid and sticky bits.
if fi.h.Mode&c_ISUID != 0 {
// setuid
mode |= os.ModeSetuid
}
if fi.h.Mode&c_ISGID != 0 {
// setgid
mode |= os.ModeSetgid
}
if fi.h.Mode&c_ISVTX != 0 {
// sticky
mode |= os.ModeSticky
}
// Set file mode bits.
// clear perm, setuid, setgid and sticky bits.
m := os.FileMode(fi.h.Mode) &^ 07777
if m == c_ISDIR {
// directory
mode |= os.ModeDir
}
if m == c_ISFIFO {
// named pipe (FIFO)
mode |= os.ModeNamedPipe
}
if m == c_ISLNK {
// symbolic link
mode |= os.ModeSymlink
}
if m == c_ISBLK {
// device file
mode |= os.ModeDevice
}
if m == c_ISCHR {
// Unix character device
mode |= os.ModeDevice
mode |= os.ModeCharDevice
}
if m == c_ISSOCK {
// Unix domain socket
mode |= os.ModeSocket
}
switch fi.h.Typeflag {
case TypeSymlink:
// symbolic link
mode |= os.ModeSymlink
case TypeChar:
// character device node
mode |= os.ModeDevice
mode |= os.ModeCharDevice
case TypeBlock:
// block device node
mode |= os.ModeDevice
case TypeDir:
// directory
mode |= os.ModeDir
case TypeFifo:
// fifo node
mode |= os.ModeNamedPipe
}
return mode
}
// sysStat, if non-nil, populates h from system-dependent fields of fi.
var sysStat func(fi os.FileInfo, h *Header) error
// Mode constants from the tar spec.
const (
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
// Keywords for the PAX Extended Header
const (
paxAtime = "atime"
paxCharset = "charset"
paxComment = "comment"
paxCtime = "ctime" // please note that ctime is not a valid pax header.
paxCreationTime = "LIBARCHIVE.creationtime"
paxGid = "gid"
paxGname = "gname"
paxLinkpath = "linkpath"
paxMtime = "mtime"
paxPath = "path"
paxSize = "size"
paxUid = "uid"
paxUname = "uname"
paxXattr = "SCHILY.xattr."
paxWindows = "MSWINDOWS."
paxNone = ""
)
// FileInfoHeader creates a partially-populated Header from fi.
// If fi describes a symlink, FileInfoHeader records link as the link target.
// If fi describes a directory, a slash is appended to the name.
// Because os.FileInfo's Name method returns only the base name of
// the file it describes, it may be necessary to modify the Name field
// of the returned header to provide the full path name of the file.
func FileInfoHeader(fi os.FileInfo, link string) (*Header, error) {
if fi == nil {
return nil, errors.New("tar: FileInfo is nil")
}
fm := fi.Mode()
h := &Header{
Name: fi.Name(),
ModTime: fi.ModTime(),
Mode: int64(fm.Perm()), // or'd with c_IS* constants later
}
switch {
case fm.IsRegular():
h.Mode |= c_ISREG
h.Typeflag = TypeReg
h.Size = fi.Size()
case fi.IsDir():
h.Typeflag = TypeDir
h.Mode |= c_ISDIR
h.Name += "/"
case fm&os.ModeSymlink != 0:
h.Typeflag = TypeSymlink
h.Mode |= c_ISLNK
h.Linkname = link
case fm&os.ModeDevice != 0:
if fm&os.ModeCharDevice != 0 {
h.Mode |= c_ISCHR
h.Typeflag = TypeChar
} else {
h.Mode |= c_ISBLK
h.Typeflag = TypeBlock
}
case fm&os.ModeNamedPipe != 0:
h.Typeflag = TypeFifo
h.Mode |= c_ISFIFO
case fm&os.ModeSocket != 0:
h.Mode |= c_ISSOCK
default:
return nil, fmt.Errorf("archive/tar: unknown file mode %v", fm)
}
if fm&os.ModeSetuid != 0 {
h.Mode |= c_ISUID
}
if fm&os.ModeSetgid != 0 {
h.Mode |= c_ISGID
}
if fm&os.ModeSticky != 0 {
h.Mode |= c_ISVTX
}
// If possible, populate additional fields from OS-specific
// FileInfo fields.
if sys, ok := fi.Sys().(*Header); ok {
// This FileInfo came from a Header (not the OS). Use the
// original Header to populate all remaining fields.
h.Uid = sys.Uid
h.Gid = sys.Gid
h.Uname = sys.Uname
h.Gname = sys.Gname
h.AccessTime = sys.AccessTime
h.ChangeTime = sys.ChangeTime
if sys.Xattrs != nil {
h.Xattrs = make(map[string]string)
for k, v := range sys.Xattrs {
h.Xattrs[k] = v
}
}
if sys.Typeflag == TypeLink {
// hard link
h.Typeflag = TypeLink
h.Size = 0
h.Linkname = sys.Linkname
}
}
if sysStat != nil {
return h, sysStat(fi, h)
}
return h, nil
}
var zeroBlock = make([]byte, blockSize)
// POSIX specifies a sum of the unsigned byte values, but the Sun tar uses signed byte values.
// We compute and return both.
func checksum(header []byte) (unsigned int64, signed int64) {
for i := 0; i < len(header); i++ {
if i == 148 {
// The chksum field (header[148:156]) is special: it should be treated as space bytes.
unsigned += ' ' * 8
signed += ' ' * 8
i += 7
continue
}
unsigned += int64(header[i])
signed += int64(int8(header[i]))
}
return
}
type slicer []byte
func (sp *slicer) next(n int) (b []byte) {
s := *sp
b, *sp = s[0:n], s[n:]
return
}
func isASCII(s string) bool {
for _, c := range s {
if c >= 0x80 {
return false
}
}
return true
}
func toASCII(s string) string {
if isASCII(s) {
return s
}
var buf bytes.Buffer
for _, c := range s {
if c < 0x80 {
buf.WriteByte(byte(c))
}
}
return buf.String()
}
// isHeaderOnlyType checks if the given type flag is of the type that has no
// data section even if a size is specified.
func isHeaderOnlyType(flag byte) bool {
switch flag {
case TypeLink, TypeSymlink, TypeChar, TypeBlock, TypeDir, TypeFifo:
return true
default:
return false
}
}

View File

@@ -1,80 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar_test
import (
"archive/tar"
"bytes"
"fmt"
"io"
"log"
"os"
)
func Example() {
// Create a buffer to write our archive to.
buf := new(bytes.Buffer)
// Create a new tar archive.
tw := tar.NewWriter(buf)
// Add some files to the archive.
var files = []struct {
Name, Body string
}{
{"readme.txt", "This archive contains some text files."},
{"gopher.txt", "Gopher names:\nGeorge\nGeoffrey\nGonzo"},
{"todo.txt", "Get animal handling license."},
}
for _, file := range files {
hdr := &tar.Header{
Name: file.Name,
Mode: 0600,
Size: int64(len(file.Body)),
}
if err := tw.WriteHeader(hdr); err != nil {
log.Fatalln(err)
}
if _, err := tw.Write([]byte(file.Body)); err != nil {
log.Fatalln(err)
}
}
// Make sure to check the error on Close.
if err := tw.Close(); err != nil {
log.Fatalln(err)
}
// Open the tar archive for reading.
r := bytes.NewReader(buf.Bytes())
tr := tar.NewReader(r)
// Iterate through the files in the archive.
for {
hdr, err := tr.Next()
if err == io.EOF {
// end of tar archive
break
}
if err != nil {
log.Fatalln(err)
}
fmt.Printf("Contents of %s:\n", hdr.Name)
if _, err := io.Copy(os.Stdout, tr); err != nil {
log.Fatalln(err)
}
fmt.Println()
}
// Output:
// Contents of readme.txt:
// This archive contains some text files.
// Contents of gopher.txt:
// Gopher names:
// George
// Geoffrey
// Gonzo
// Contents of todo.txt:
// Get animal handling license.
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,20 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux dragonfly openbsd solaris
package tar
import (
"syscall"
"time"
)
func statAtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Atim.Unix())
}
func statCtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Ctim.Unix())
}

View File

@@ -1,20 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd netbsd
package tar
import (
"syscall"
"time"
)
func statAtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Atimespec.Unix())
}
func statCtime(st *syscall.Stat_t) time.Time {
return time.Unix(st.Ctimespec.Unix())
}

View File

@@ -1,32 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux darwin dragonfly freebsd openbsd netbsd solaris
package tar
import (
"os"
"syscall"
)
func init() {
sysStat = statUnix
}
func statUnix(fi os.FileInfo, h *Header) error {
sys, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
return nil
}
h.Uid = int(sys.Uid)
h.Gid = int(sys.Gid)
// TODO(bradfitz): populate username & group. os/user
// doesn't cache LookupId lookups, and lacks group
// lookup functions.
h.AccessTime = statAtime(sys)
h.ChangeTime = statCtime(sys)
// TODO(bradfitz): major/minor device numbers?
return nil
}

View File

@@ -1,325 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"bytes"
"io/ioutil"
"os"
"path"
"reflect"
"strings"
"testing"
"time"
)
func TestFileInfoHeader(t *testing.T) {
fi, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
h, err := FileInfoHeader(fi, "")
if err != nil {
t.Fatalf("FileInfoHeader: %v", err)
}
if g, e := h.Name, "small.txt"; g != e {
t.Errorf("Name = %q; want %q", g, e)
}
if g, e := h.Mode, int64(fi.Mode().Perm())|c_ISREG; g != e {
t.Errorf("Mode = %#o; want %#o", g, e)
}
if g, e := h.Size, int64(5); g != e {
t.Errorf("Size = %v; want %v", g, e)
}
if g, e := h.ModTime, fi.ModTime(); !g.Equal(e) {
t.Errorf("ModTime = %v; want %v", g, e)
}
// FileInfoHeader should error when passing nil FileInfo
if _, err := FileInfoHeader(nil, ""); err == nil {
t.Fatalf("Expected error when passing nil to FileInfoHeader")
}
}
func TestFileInfoHeaderDir(t *testing.T) {
fi, err := os.Stat("testdata")
if err != nil {
t.Fatal(err)
}
h, err := FileInfoHeader(fi, "")
if err != nil {
t.Fatalf("FileInfoHeader: %v", err)
}
if g, e := h.Name, "testdata/"; g != e {
t.Errorf("Name = %q; want %q", g, e)
}
// Ignoring c_ISGID for golang.org/issue/4867
if g, e := h.Mode&^c_ISGID, int64(fi.Mode().Perm())|c_ISDIR; g != e {
t.Errorf("Mode = %#o; want %#o", g, e)
}
if g, e := h.Size, int64(0); g != e {
t.Errorf("Size = %v; want %v", g, e)
}
if g, e := h.ModTime, fi.ModTime(); !g.Equal(e) {
t.Errorf("ModTime = %v; want %v", g, e)
}
}
func TestFileInfoHeaderSymlink(t *testing.T) {
h, err := FileInfoHeader(symlink{}, "some-target")
if err != nil {
t.Fatal(err)
}
if g, e := h.Name, "some-symlink"; g != e {
t.Errorf("Name = %q; want %q", g, e)
}
if g, e := h.Linkname, "some-target"; g != e {
t.Errorf("Linkname = %q; want %q", g, e)
}
}
type symlink struct{}
func (symlink) Name() string { return "some-symlink" }
func (symlink) Size() int64 { return 0 }
func (symlink) Mode() os.FileMode { return os.ModeSymlink }
func (symlink) ModTime() time.Time { return time.Time{} }
func (symlink) IsDir() bool { return false }
func (symlink) Sys() interface{} { return nil }
func TestRoundTrip(t *testing.T) {
data := []byte("some file contents")
var b bytes.Buffer
tw := NewWriter(&b)
hdr := &Header{
Name: "file.txt",
Uid: 1 << 21, // too big for 8 octal digits
Size: int64(len(data)),
ModTime: time.Now(),
}
// tar only supports second precision.
hdr.ModTime = hdr.ModTime.Add(-time.Duration(hdr.ModTime.Nanosecond()) * time.Nanosecond)
if err := tw.WriteHeader(hdr); err != nil {
t.Fatalf("tw.WriteHeader: %v", err)
}
if _, err := tw.Write(data); err != nil {
t.Fatalf("tw.Write: %v", err)
}
if err := tw.Close(); err != nil {
t.Fatalf("tw.Close: %v", err)
}
// Read it back.
tr := NewReader(&b)
rHdr, err := tr.Next()
if err != nil {
t.Fatalf("tr.Next: %v", err)
}
if !reflect.DeepEqual(rHdr, hdr) {
t.Errorf("Header mismatch.\n got %+v\nwant %+v", rHdr, hdr)
}
rData, err := ioutil.ReadAll(tr)
if err != nil {
t.Fatalf("Read: %v", err)
}
if !bytes.Equal(rData, data) {
t.Errorf("Data mismatch.\n got %q\nwant %q", rData, data)
}
}
type headerRoundTripTest struct {
h *Header
fm os.FileMode
}
func TestHeaderRoundTrip(t *testing.T) {
golden := []headerRoundTripTest{
// regular file.
{
h: &Header{
Name: "test.txt",
Mode: 0644 | c_ISREG,
Size: 12,
ModTime: time.Unix(1360600916, 0),
Typeflag: TypeReg,
},
fm: 0644,
},
// symbolic link.
{
h: &Header{
Name: "link.txt",
Mode: 0777 | c_ISLNK,
Size: 0,
ModTime: time.Unix(1360600852, 0),
Typeflag: TypeSymlink,
},
fm: 0777 | os.ModeSymlink,
},
// character device node.
{
h: &Header{
Name: "dev/null",
Mode: 0666 | c_ISCHR,
Size: 0,
ModTime: time.Unix(1360578951, 0),
Typeflag: TypeChar,
},
fm: 0666 | os.ModeDevice | os.ModeCharDevice,
},
// block device node.
{
h: &Header{
Name: "dev/sda",
Mode: 0660 | c_ISBLK,
Size: 0,
ModTime: time.Unix(1360578954, 0),
Typeflag: TypeBlock,
},
fm: 0660 | os.ModeDevice,
},
// directory.
{
h: &Header{
Name: "dir/",
Mode: 0755 | c_ISDIR,
Size: 0,
ModTime: time.Unix(1360601116, 0),
Typeflag: TypeDir,
},
fm: 0755 | os.ModeDir,
},
// fifo node.
{
h: &Header{
Name: "dev/initctl",
Mode: 0600 | c_ISFIFO,
Size: 0,
ModTime: time.Unix(1360578949, 0),
Typeflag: TypeFifo,
},
fm: 0600 | os.ModeNamedPipe,
},
// setuid.
{
h: &Header{
Name: "bin/su",
Mode: 0755 | c_ISREG | c_ISUID,
Size: 23232,
ModTime: time.Unix(1355405093, 0),
Typeflag: TypeReg,
},
fm: 0755 | os.ModeSetuid,
},
// setguid.
{
h: &Header{
Name: "group.txt",
Mode: 0750 | c_ISREG | c_ISGID,
Size: 0,
ModTime: time.Unix(1360602346, 0),
Typeflag: TypeReg,
},
fm: 0750 | os.ModeSetgid,
},
// sticky.
{
h: &Header{
Name: "sticky.txt",
Mode: 0600 | c_ISREG | c_ISVTX,
Size: 7,
ModTime: time.Unix(1360602540, 0),
Typeflag: TypeReg,
},
fm: 0600 | os.ModeSticky,
},
// hard link.
{
h: &Header{
Name: "hard.txt",
Mode: 0644 | c_ISREG,
Size: 0,
Linkname: "file.txt",
ModTime: time.Unix(1360600916, 0),
Typeflag: TypeLink,
},
fm: 0644,
},
// More information.
{
h: &Header{
Name: "info.txt",
Mode: 0600 | c_ISREG,
Size: 0,
Uid: 1000,
Gid: 1000,
ModTime: time.Unix(1360602540, 0),
Uname: "slartibartfast",
Gname: "users",
Typeflag: TypeReg,
},
fm: 0600,
},
}
for i, g := range golden {
fi := g.h.FileInfo()
h2, err := FileInfoHeader(fi, "")
if err != nil {
t.Error(err)
continue
}
if strings.Contains(fi.Name(), "/") {
t.Errorf("FileInfo of %q contains slash: %q", g.h.Name, fi.Name())
}
name := path.Base(g.h.Name)
if fi.IsDir() {
name += "/"
}
if got, want := h2.Name, name; got != want {
t.Errorf("i=%d: Name: got %v, want %v", i, got, want)
}
if got, want := h2.Size, g.h.Size; got != want {
t.Errorf("i=%d: Size: got %v, want %v", i, got, want)
}
if got, want := h2.Uid, g.h.Uid; got != want {
t.Errorf("i=%d: Uid: got %d, want %d", i, got, want)
}
if got, want := h2.Gid, g.h.Gid; got != want {
t.Errorf("i=%d: Gid: got %d, want %d", i, got, want)
}
if got, want := h2.Uname, g.h.Uname; got != want {
t.Errorf("i=%d: Uname: got %q, want %q", i, got, want)
}
if got, want := h2.Gname, g.h.Gname; got != want {
t.Errorf("i=%d: Gname: got %q, want %q", i, got, want)
}
if got, want := h2.Linkname, g.h.Linkname; got != want {
t.Errorf("i=%d: Linkname: got %v, want %v", i, got, want)
}
if got, want := h2.Typeflag, g.h.Typeflag; got != want {
t.Logf("%#v %#v", g.h, fi.Sys())
t.Errorf("i=%d: Typeflag: got %q, want %q", i, got, want)
}
if got, want := h2.Mode, g.h.Mode; got != want {
t.Errorf("i=%d: Mode: got %o, want %o", i, got, want)
}
if got, want := fi.Mode(), g.fm; got != want {
t.Errorf("i=%d: fi.Mode: got %o, want %o", i, got, want)
}
if got, want := h2.AccessTime, g.h.AccessTime; got != want {
t.Errorf("i=%d: AccessTime: got %v, want %v", i, got, want)
}
if got, want := h2.ChangeTime, g.h.ChangeTime; got != want {
t.Errorf("i=%d: ChangeTime: got %v, want %v", i, got, want)
}
if got, want := h2.ModTime, g.h.ModTime; got != want {
t.Errorf("i=%d: ModTime: got %v, want %v", i, got, want)
}
if sysh, ok := fi.Sys().(*Header); !ok || sysh != g.h {
t.Errorf("i=%d: Sys didn't return original *Header", i)
}
}
}

View File

@@ -1 +0,0 @@
Kilts

View File

@@ -1 +0,0 @@
Google.com

Binary file not shown.

View File

@@ -1,444 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
// TODO(dsymonds):
// - catch more errors (no first header, etc.)
import (
"bytes"
"errors"
"fmt"
"io"
"path"
"sort"
"strconv"
"strings"
"time"
)
var (
ErrWriteTooLong = errors.New("archive/tar: write too long")
ErrFieldTooLong = errors.New("archive/tar: header field too long")
ErrWriteAfterClose = errors.New("archive/tar: write after close")
errInvalidHeader = errors.New("archive/tar: header field too long or contains invalid values")
)
// A Writer provides sequential writing of a tar archive in POSIX.1 format.
// A tar archive consists of a sequence of files.
// Call WriteHeader to begin a new file, and then call Write to supply that file's data,
// writing at most hdr.Size bytes in total.
type Writer struct {
w io.Writer
err error
nb int64 // number of unwritten bytes for current file entry
pad int64 // amount of padding to write after current file entry
closed bool
usedBinary bool // whether the binary numeric field extension was used
preferPax bool // use pax header instead of binary numeric header
hdrBuff [blockSize]byte // buffer to use in writeHeader when writing a regular header
paxHdrBuff [blockSize]byte // buffer to use in writeHeader when writing a pax header
}
type formatter struct {
err error // Last error seen
}
// NewWriter creates a new Writer writing to w.
func NewWriter(w io.Writer) *Writer { return &Writer{w: w, preferPax: true} }
// Flush finishes writing the current file (optional).
func (tw *Writer) Flush() error {
if tw.nb > 0 {
tw.err = fmt.Errorf("archive/tar: missed writing %d bytes", tw.nb)
return tw.err
}
n := tw.nb + tw.pad
for n > 0 && tw.err == nil {
nr := n
if nr > blockSize {
nr = blockSize
}
var nw int
nw, tw.err = tw.w.Write(zeroBlock[0:nr])
n -= int64(nw)
}
tw.nb = 0
tw.pad = 0
return tw.err
}
// Write s into b, terminating it with a NUL if there is room.
func (f *formatter) formatString(b []byte, s string) {
if len(s) > len(b) {
f.err = ErrFieldTooLong
return
}
ascii := toASCII(s)
copy(b, ascii)
if len(ascii) < len(b) {
b[len(ascii)] = 0
}
}
// Encode x as an octal ASCII string and write it into b with leading zeros.
func (f *formatter) formatOctal(b []byte, x int64) {
s := strconv.FormatInt(x, 8)
// leading zeros, but leave room for a NUL.
for len(s)+1 < len(b) {
s = "0" + s
}
f.formatString(b, s)
}
// fitsInBase256 reports whether x can be encoded into n bytes using base-256
// encoding. Unlike octal encoding, base-256 encoding does not require that the
// string ends with a NUL character. Thus, all n bytes are available for output.
//
// If operating in binary mode, this assumes strict GNU binary mode; which means
// that the first byte can only be either 0x80 or 0xff. Thus, the first byte is
// equivalent to the sign bit in two's complement form.
func fitsInBase256(n int, x int64) bool {
var binBits = uint(n-1) * 8
return n >= 9 || (x >= -1<<binBits && x < 1<<binBits)
}
// Write x into b, as binary (GNUtar/star extension).
func (f *formatter) formatNumeric(b []byte, x int64) {
if fitsInBase256(len(b), x) {
for i := len(b) - 1; i >= 0; i-- {
b[i] = byte(x)
x >>= 8
}
b[0] |= 0x80 // Highest bit indicates binary format
return
}
f.formatOctal(b, 0) // Last resort, just write zero
f.err = ErrFieldTooLong
}
var (
minTime = time.Unix(0, 0)
// There is room for 11 octal digits (33 bits) of mtime.
maxTime = minTime.Add((1<<33 - 1) * time.Second)
)
// WriteHeader writes hdr and prepares to accept the file's contents.
// WriteHeader calls Flush if it is not the first header.
// Calling after a Close will return ErrWriteAfterClose.
func (tw *Writer) WriteHeader(hdr *Header) error {
return tw.writeHeader(hdr, true)
}
// WriteHeader writes hdr and prepares to accept the file's contents.
// WriteHeader calls Flush if it is not the first header.
// Calling after a Close will return ErrWriteAfterClose.
// As this method is called internally by writePax header to allow it to
// suppress writing the pax header.
func (tw *Writer) writeHeader(hdr *Header, allowPax bool) error {
if tw.closed {
return ErrWriteAfterClose
}
if tw.err == nil {
tw.Flush()
}
if tw.err != nil {
return tw.err
}
// a map to hold pax header records, if any are needed
paxHeaders := make(map[string]string)
// TODO(shanemhansen): we might want to use PAX headers for
// subsecond time resolution, but for now let's just capture
// too long fields or non ascii characters
var f formatter
var header []byte
// We need to select which scratch buffer to use carefully,
// since this method is called recursively to write PAX headers.
// If allowPax is true, this is the non-recursive call, and we will use hdrBuff.
// If allowPax is false, we are being called by writePAXHeader, and hdrBuff is
// already being used by the non-recursive call, so we must use paxHdrBuff.
header = tw.hdrBuff[:]
if !allowPax {
header = tw.paxHdrBuff[:]
}
copy(header, zeroBlock)
s := slicer(header)
// Wrappers around formatter that automatically sets paxHeaders if the
// argument extends beyond the capacity of the input byte slice.
var formatString = func(b []byte, s string, paxKeyword string) {
needsPaxHeader := paxKeyword != paxNone && len(s) > len(b) || !isASCII(s)
if needsPaxHeader {
paxHeaders[paxKeyword] = s
return
}
f.formatString(b, s)
}
var formatNumeric = func(b []byte, x int64, paxKeyword string) {
// Try octal first.
s := strconv.FormatInt(x, 8)
if len(s) < len(b) {
f.formatOctal(b, x)
return
}
// If it is too long for octal, and PAX is preferred, use a PAX header.
if paxKeyword != paxNone && tw.preferPax {
f.formatOctal(b, 0)
s := strconv.FormatInt(x, 10)
paxHeaders[paxKeyword] = s
return
}
tw.usedBinary = true
f.formatNumeric(b, x)
}
var formatTime = func(b []byte, t time.Time, paxKeyword string) {
var unixTime int64
if !t.Before(minTime) && !t.After(maxTime) {
unixTime = t.Unix()
}
formatNumeric(b, unixTime, paxNone)
// Write a PAX header if the time didn't fit precisely.
if paxKeyword != "" && tw.preferPax && allowPax && (t.Nanosecond() != 0 || !t.Before(minTime) || !t.After(maxTime)) {
paxHeaders[paxKeyword] = formatPAXTime(t)
}
}
// keep a reference to the filename to allow to overwrite it later if we detect that we can use ustar longnames instead of pax
pathHeaderBytes := s.next(fileNameSize)
formatString(pathHeaderBytes, hdr.Name, paxPath)
f.formatOctal(s.next(8), hdr.Mode) // 100:108
formatNumeric(s.next(8), int64(hdr.Uid), paxUid) // 108:116
formatNumeric(s.next(8), int64(hdr.Gid), paxGid) // 116:124
formatNumeric(s.next(12), hdr.Size, paxSize) // 124:136
formatTime(s.next(12), hdr.ModTime, paxMtime) // 136:148
s.next(8) // chksum (148:156)
s.next(1)[0] = hdr.Typeflag // 156:157
formatString(s.next(100), hdr.Linkname, paxLinkpath)
copy(s.next(8), []byte("ustar\x0000")) // 257:265
formatString(s.next(32), hdr.Uname, paxUname) // 265:297
formatString(s.next(32), hdr.Gname, paxGname) // 297:329
formatNumeric(s.next(8), hdr.Devmajor, paxNone) // 329:337
formatNumeric(s.next(8), hdr.Devminor, paxNone) // 337:345
// keep a reference to the prefix to allow to overwrite it later if we detect that we can use ustar longnames instead of pax
prefixHeaderBytes := s.next(155)
formatString(prefixHeaderBytes, "", paxNone) // 345:500 prefix
// Use the GNU magic instead of POSIX magic if we used any GNU extensions.
if tw.usedBinary {
copy(header[257:265], []byte("ustar \x00"))
}
_, paxPathUsed := paxHeaders[paxPath]
// try to use a ustar header when only the name is too long
if !tw.preferPax && len(paxHeaders) == 1 && paxPathUsed {
prefix, suffix, ok := splitUSTARPath(hdr.Name)
if ok {
// Since we can encode in USTAR format, disable PAX header.
delete(paxHeaders, paxPath)
// Update the path fields
formatString(pathHeaderBytes, suffix, paxNone)
formatString(prefixHeaderBytes, prefix, paxNone)
}
}
// The chksum field is terminated by a NUL and a space.
// This is different from the other octal fields.
chksum, _ := checksum(header)
f.formatOctal(header[148:155], chksum) // Never fails
header[155] = ' '
// Check if there were any formatting errors.
if f.err != nil {
tw.err = f.err
return tw.err
}
if allowPax {
if !hdr.AccessTime.IsZero() {
paxHeaders[paxAtime] = formatPAXTime(hdr.AccessTime)
}
if !hdr.ChangeTime.IsZero() {
paxHeaders[paxCtime] = formatPAXTime(hdr.ChangeTime)
}
if !hdr.CreationTime.IsZero() {
paxHeaders[paxCreationTime] = formatPAXTime(hdr.CreationTime)
}
for k, v := range hdr.Xattrs {
paxHeaders[paxXattr+k] = v
}
for k, v := range hdr.Winheaders {
paxHeaders[paxWindows+k] = v
}
}
if len(paxHeaders) > 0 {
if !allowPax {
return errInvalidHeader
}
if err := tw.writePAXHeader(hdr, paxHeaders); err != nil {
return err
}
}
tw.nb = int64(hdr.Size)
tw.pad = (blockSize - (tw.nb % blockSize)) % blockSize
_, tw.err = tw.w.Write(header)
return tw.err
}
func formatPAXTime(t time.Time) string {
sec := t.Unix()
usec := t.Nanosecond()
s := strconv.FormatInt(sec, 10)
if usec != 0 {
s = fmt.Sprintf("%s.%09d", s, usec)
}
return s
}
// splitUSTARPath splits a path according to USTAR prefix and suffix rules.
// If the path is not splittable, then it will return ("", "", false).
func splitUSTARPath(name string) (prefix, suffix string, ok bool) {
length := len(name)
if length <= fileNameSize || !isASCII(name) {
return "", "", false
} else if length > fileNamePrefixSize+1 {
length = fileNamePrefixSize + 1
} else if name[length-1] == '/' {
length--
}
i := strings.LastIndex(name[:length], "/")
nlen := len(name) - i - 1 // nlen is length of suffix
plen := i // plen is length of prefix
if i <= 0 || nlen > fileNameSize || nlen == 0 || plen > fileNamePrefixSize {
return "", "", false
}
return name[:i], name[i+1:], true
}
// writePaxHeader writes an extended pax header to the
// archive.
func (tw *Writer) writePAXHeader(hdr *Header, paxHeaders map[string]string) error {
// Prepare extended header
ext := new(Header)
ext.Typeflag = TypeXHeader
// Setting ModTime is required for reader parsing to
// succeed, and seems harmless enough.
ext.ModTime = hdr.ModTime
// The spec asks that we namespace our pseudo files
// with the current pid. However, this results in differing outputs
// for identical inputs. As such, the constant 0 is now used instead.
// golang.org/issue/12358
dir, file := path.Split(hdr.Name)
fullName := path.Join(dir, "PaxHeaders.0", file)
ascii := toASCII(fullName)
if len(ascii) > 100 {
ascii = ascii[:100]
}
ext.Name = ascii
// Construct the body
var buf bytes.Buffer
// Keys are sorted before writing to body to allow deterministic output.
var keys []string
for k := range paxHeaders {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprint(&buf, formatPAXRecord(k, paxHeaders[k]))
}
ext.Size = int64(len(buf.Bytes()))
if err := tw.writeHeader(ext, false); err != nil {
return err
}
if _, err := tw.Write(buf.Bytes()); err != nil {
return err
}
if err := tw.Flush(); err != nil {
return err
}
return nil
}
// formatPAXRecord formats a single PAX record, prefixing it with the
// appropriate length.
func formatPAXRecord(k, v string) string {
const padding = 3 // Extra padding for ' ', '=', and '\n'
size := len(k) + len(v) + padding
size += len(strconv.Itoa(size))
record := fmt.Sprintf("%d %s=%s\n", size, k, v)
// Final adjustment if adding size field increased the record size.
if len(record) != size {
size = len(record)
record = fmt.Sprintf("%d %s=%s\n", size, k, v)
}
return record
}
// Write writes to the current entry in the tar archive.
// Write returns the error ErrWriteTooLong if more than
// hdr.Size bytes are written after WriteHeader.
func (tw *Writer) Write(b []byte) (n int, err error) {
if tw.closed {
err = ErrWriteAfterClose
return
}
overwrite := false
if int64(len(b)) > tw.nb {
b = b[0:tw.nb]
overwrite = true
}
n, err = tw.w.Write(b)
tw.nb -= int64(n)
if err == nil && overwrite {
err = ErrWriteTooLong
return
}
tw.err = err
return
}
// Close closes the tar archive, flushing any unwritten
// data to the underlying writer.
func (tw *Writer) Close() error {
if tw.err != nil || tw.closed {
return tw.err
}
tw.Flush()
tw.closed = true
if tw.err != nil {
return tw.err
}
// trailer: two zero blocks
for i := 0; i < 2; i++ {
_, tw.err = tw.w.Write(zeroBlock)
if tw.err != nil {
break
}
}
return tw.err
}

View File

@@ -1,739 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tar
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math"
"os"
"reflect"
"sort"
"strings"
"testing"
"testing/iotest"
"time"
)
type writerTestEntry struct {
header *Header
contents string
}
type writerTest struct {
file string // filename of expected output
entries []*writerTestEntry
}
var writerTests = []*writerTest{
// The writer test file was produced with this command:
// tar (GNU tar) 1.26
// ln -s small.txt link.txt
// tar -b 1 --format=ustar -c -f writer.tar small.txt small2.txt link.txt
{
file: "testdata/writer.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: "small.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 5,
ModTime: time.Unix(1246508266, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
},
contents: "Kilts",
},
{
header: &Header{
Name: "small2.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 11,
ModTime: time.Unix(1245217492, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
},
contents: "Google.com\n",
},
{
header: &Header{
Name: "link.txt",
Mode: 0777,
Uid: 1000,
Gid: 1000,
Size: 0,
ModTime: time.Unix(1314603082, 0),
Typeflag: '2',
Linkname: "small.txt",
Uname: "strings",
Gname: "strings",
},
// no contents
},
},
},
// The truncated test file was produced using these commands:
// dd if=/dev/zero bs=1048576 count=16384 > /tmp/16gig.txt
// tar -b 1 -c -f- /tmp/16gig.txt | dd bs=512 count=8 > writer-big.tar
{
file: "testdata/writer-big.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: "tmp/16gig.txt",
Mode: 0640,
Uid: 73025,
Gid: 5000,
Size: 16 << 30,
ModTime: time.Unix(1254699560, 0),
Typeflag: '0',
Uname: "dsymonds",
Gname: "eng",
},
// fake contents
contents: strings.Repeat("\x00", 4<<10),
},
},
},
// The truncated test file was produced using these commands:
// dd if=/dev/zero bs=1048576 count=16384 > (longname/)*15 /16gig.txt
// tar -b 1 -c -f- (longname/)*15 /16gig.txt | dd bs=512 count=8 > writer-big-long.tar
{
file: "testdata/writer-big-long.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: strings.Repeat("longname/", 15) + "16gig.txt",
Mode: 0644,
Uid: 1000,
Gid: 1000,
Size: 16 << 30,
ModTime: time.Unix(1399583047, 0),
Typeflag: '0',
Uname: "guillaume",
Gname: "guillaume",
},
// fake contents
contents: strings.Repeat("\x00", 4<<10),
},
},
},
// This file was produced using gnu tar 1.17
// gnutar -b 4 --format=ustar (longname/)*15 + file.txt
{
file: "testdata/ustar.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: strings.Repeat("longname/", 15) + "file.txt",
Mode: 0644,
Uid: 0765,
Gid: 024,
Size: 06,
ModTime: time.Unix(1360135598, 0),
Typeflag: '0',
Uname: "shane",
Gname: "staff",
},
contents: "hello\n",
},
},
},
// This file was produced using gnu tar 1.26
// echo "Slartibartfast" > file.txt
// ln file.txt hard.txt
// tar -b 1 --format=ustar -c -f hardlink.tar file.txt hard.txt
{
file: "testdata/hardlink.tar",
entries: []*writerTestEntry{
{
header: &Header{
Name: "file.txt",
Mode: 0644,
Uid: 1000,
Gid: 100,
Size: 15,
ModTime: time.Unix(1425484303, 0),
Typeflag: '0',
Uname: "vbatts",
Gname: "users",
},
contents: "Slartibartfast\n",
},
{
header: &Header{
Name: "hard.txt",
Mode: 0644,
Uid: 1000,
Gid: 100,
Size: 0,
ModTime: time.Unix(1425484303, 0),
Typeflag: '1',
Linkname: "file.txt",
Uname: "vbatts",
Gname: "users",
},
// no contents
},
},
},
}
// Render byte array in a two-character hexadecimal string, spaced for easy visual inspection.
func bytestr(offset int, b []byte) string {
const rowLen = 32
s := fmt.Sprintf("%04x ", offset)
for _, ch := range b {
switch {
case '0' <= ch && ch <= '9', 'A' <= ch && ch <= 'Z', 'a' <= ch && ch <= 'z':
s += fmt.Sprintf(" %c", ch)
default:
s += fmt.Sprintf(" %02x", ch)
}
}
return s
}
// Render a pseudo-diff between two blocks of bytes.
func bytediff(a []byte, b []byte) string {
const rowLen = 32
s := fmt.Sprintf("(%d bytes vs. %d bytes)\n", len(a), len(b))
for offset := 0; len(a)+len(b) > 0; offset += rowLen {
na, nb := rowLen, rowLen
if na > len(a) {
na = len(a)
}
if nb > len(b) {
nb = len(b)
}
sa := bytestr(offset, a[0:na])
sb := bytestr(offset, b[0:nb])
if sa != sb {
s += fmt.Sprintf("-%v\n+%v\n", sa, sb)
}
a = a[na:]
b = b[nb:]
}
return s
}
func TestWriter(t *testing.T) {
testLoop:
for i, test := range writerTests {
expected, err := ioutil.ReadFile(test.file)
if err != nil {
t.Errorf("test %d: Unexpected error: %v", i, err)
continue
}
buf := new(bytes.Buffer)
tw := NewWriter(iotest.TruncateWriter(buf, 4<<10)) // only catch the first 4 KB
big := false
for j, entry := range test.entries {
big = big || entry.header.Size > 1<<10
if err := tw.WriteHeader(entry.header); err != nil {
t.Errorf("test %d, entry %d: Failed writing header: %v", i, j, err)
continue testLoop
}
if _, err := io.WriteString(tw, entry.contents); err != nil {
t.Errorf("test %d, entry %d: Failed writing contents: %v", i, j, err)
continue testLoop
}
}
// Only interested in Close failures for the small tests.
if err := tw.Close(); err != nil && !big {
t.Errorf("test %d: Failed closing archive: %v", i, err)
continue testLoop
}
actual := buf.Bytes()
if !bytes.Equal(expected, actual) {
t.Errorf("test %d: Incorrect result: (-=expected, +=actual)\n%v",
i, bytediff(expected, actual))
}
if testing.Short() { // The second test is expensive.
break
}
}
}
func TestPax(t *testing.T) {
// Create an archive with a large name
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
// Force a PAX long name to be written
longName := strings.Repeat("ab", 100)
contents := strings.Repeat(" ", int(hdr.Size))
hdr.Name = longName
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.0")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != longName {
t.Fatal("Couldn't recover long file name")
}
}
func TestPaxSymlink(t *testing.T) {
// Create an archive with a large linkname
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
hdr.Typeflag = TypeSymlink
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// Force a PAX long linkname to be written
longLinkname := strings.Repeat("1234567890/1234567890", 10)
hdr.Linkname = longLinkname
hdr.Size = 0
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.0")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Linkname != longLinkname {
t.Fatal("Couldn't recover long link name")
}
}
func TestPaxNonAscii(t *testing.T) {
// Create an archive with non ascii. These should trigger a pax header
// because pax headers have a defined utf-8 encoding.
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// some sample data
chineseFilename := "文件名"
chineseGroupname := "組"
chineseUsername := "用戶名"
hdr.Name = chineseFilename
hdr.Gname = chineseGroupname
hdr.Uname = chineseUsername
contents := strings.Repeat(" ", int(hdr.Size))
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.0")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != chineseFilename {
t.Fatal("Couldn't recover unicode name")
}
if hdr.Gname != chineseGroupname {
t.Fatal("Couldn't recover unicode group")
}
if hdr.Uname != chineseUsername {
t.Fatal("Couldn't recover unicode user")
}
}
func TestPaxXattrs(t *testing.T) {
xattrs := map[string]string{
"user.key": "value",
}
// Create an archive with an xattr
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
contents := "Kilts"
hdr.Xattrs = xattrs
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Test that we can get the xattrs back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(hdr.Xattrs, xattrs) {
t.Fatalf("xattrs did not survive round trip: got %+v, want %+v",
hdr.Xattrs, xattrs)
}
}
func TestPaxHeadersSorted(t *testing.T) {
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
if err != nil {
t.Fatalf("os.Stat: %v", err)
}
contents := strings.Repeat(" ", int(hdr.Size))
hdr.Xattrs = map[string]string{
"foo": "foo",
"bar": "bar",
"baz": "baz",
"qux": "qux",
}
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if _, err = writer.Write([]byte(contents)); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Simple test to make sure PAX extensions are in effect
if !bytes.Contains(buf.Bytes(), []byte("PaxHeaders.0")) {
t.Fatal("Expected at least one PAX header to be written.")
}
// xattr bar should always appear before others
indices := []int{
bytes.Index(buf.Bytes(), []byte("bar=bar")),
bytes.Index(buf.Bytes(), []byte("baz=baz")),
bytes.Index(buf.Bytes(), []byte("foo=foo")),
bytes.Index(buf.Bytes(), []byte("qux=qux")),
}
if !sort.IntsAreSorted(indices) {
t.Fatal("PAX headers are not sorted")
}
}
func TestUSTARLongName(t *testing.T) {
// Create an archive with a path that failed to split with USTAR extension in previous versions.
fileinfo, err := os.Stat("testdata/small.txt")
if err != nil {
t.Fatal(err)
}
hdr, err := FileInfoHeader(fileinfo, "")
hdr.Typeflag = TypeDir
if err != nil {
t.Fatalf("os.Stat:1 %v", err)
}
// Force a PAX long name to be written. The name was taken from a practical example
// that fails and replaced ever char through numbers to anonymize the sample.
longName := "/0000_0000000/00000-000000000/0000_0000000/00000-0000000000000/0000_0000000/00000-0000000-00000000/0000_0000000/00000000/0000_0000000/000/0000_0000000/00000000v00/0000_0000000/000000/0000_0000000/0000000/0000_0000000/00000y-00/0000/0000/00000000/0x000000/"
hdr.Name = longName
hdr.Size = 0
var buf bytes.Buffer
writer := NewWriter(&buf)
if err := writer.WriteHeader(hdr); err != nil {
t.Fatal(err)
}
if err := writer.Close(); err != nil {
t.Fatal(err)
}
// Test that we can get a long name back out of the archive.
reader := NewReader(&buf)
hdr, err = reader.Next()
if err != nil {
t.Fatal(err)
}
if hdr.Name != longName {
t.Fatal("Couldn't recover long name")
}
}
func TestValidTypeflagWithPAXHeader(t *testing.T) {
var buffer bytes.Buffer
tw := NewWriter(&buffer)
fileName := strings.Repeat("ab", 100)
hdr := &Header{
Name: fileName,
Size: 4,
Typeflag: 0,
}
if err := tw.WriteHeader(hdr); err != nil {
t.Fatalf("Failed to write header: %s", err)
}
if _, err := tw.Write([]byte("fooo")); err != nil {
t.Fatalf("Failed to write the file's data: %s", err)
}
tw.Close()
tr := NewReader(&buffer)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Failed to read header: %s", err)
}
if header.Typeflag != 0 {
t.Fatalf("Typeflag should've been 0, found %d", header.Typeflag)
}
}
}
func TestWriteAfterClose(t *testing.T) {
var buffer bytes.Buffer
tw := NewWriter(&buffer)
hdr := &Header{
Name: "small.txt",
Size: 5,
}
if err := tw.WriteHeader(hdr); err != nil {
t.Fatalf("Failed to write header: %s", err)
}
tw.Close()
if _, err := tw.Write([]byte("Kilts")); err != ErrWriteAfterClose {
t.Fatalf("Write: got %v; want ErrWriteAfterClose", err)
}
}
func TestSplitUSTARPath(t *testing.T) {
var sr = strings.Repeat
var vectors = []struct {
input string // Input path
prefix string // Expected output prefix
suffix string // Expected output suffix
ok bool // Split success?
}{
{"", "", "", false},
{"abc", "", "", false},
{"用戶名", "", "", false},
{sr("a", fileNameSize), "", "", false},
{sr("a", fileNameSize) + "/", "", "", false},
{sr("a", fileNameSize) + "/a", sr("a", fileNameSize), "a", true},
{sr("a", fileNamePrefixSize) + "/", "", "", false},
{sr("a", fileNamePrefixSize) + "/a", sr("a", fileNamePrefixSize), "a", true},
{sr("a", fileNameSize+1), "", "", false},
{sr("/", fileNameSize+1), sr("/", fileNameSize-1), "/", true},
{sr("a", fileNamePrefixSize) + "/" + sr("b", fileNameSize),
sr("a", fileNamePrefixSize), sr("b", fileNameSize), true},
{sr("a", fileNamePrefixSize) + "//" + sr("b", fileNameSize), "", "", false},
{sr("a/", fileNameSize), sr("a/", 77) + "a", sr("a/", 22), true},
}
for _, v := range vectors {
prefix, suffix, ok := splitUSTARPath(v.input)
if prefix != v.prefix || suffix != v.suffix || ok != v.ok {
t.Errorf("splitUSTARPath(%q):\ngot (%q, %q, %v)\nwant (%q, %q, %v)",
v.input, prefix, suffix, ok, v.prefix, v.suffix, v.ok)
}
}
}
func TestFormatPAXRecord(t *testing.T) {
var medName = strings.Repeat("CD", 50)
var longName = strings.Repeat("AB", 100)
var vectors = []struct {
inputKey string
inputVal string
output string
}{
{"k", "v", "6 k=v\n"},
{"path", "/etc/hosts", "19 path=/etc/hosts\n"},
{"path", longName, "210 path=" + longName + "\n"},
{"path", medName, "110 path=" + medName + "\n"},
{"foo", "ba", "9 foo=ba\n"},
{"foo", "bar", "11 foo=bar\n"},
{"foo", "b=\nar=\n==\x00", "18 foo=b=\nar=\n==\x00\n"},
{"foo", "hello9 foo=ba\nworld", "27 foo=hello9 foo=ba\nworld\n"},
{"☺☻☹", "日a本b語ç", "27 ☺☻☹=日a本b語ç\n"},
{"\x00hello", "\x00world", "17 \x00hello=\x00world\n"},
}
for _, v := range vectors {
output := formatPAXRecord(v.inputKey, v.inputVal)
if output != v.output {
t.Errorf("formatPAXRecord(%q, %q): got %q, want %q",
v.inputKey, v.inputVal, output, v.output)
}
}
}
func TestFitsInBase256(t *testing.T) {
var vectors = []struct {
input int64
width int
ok bool
}{
{+1, 8, true},
{0, 8, true},
{-1, 8, true},
{1 << 56, 8, false},
{(1 << 56) - 1, 8, true},
{-1 << 56, 8, true},
{(-1 << 56) - 1, 8, false},
{121654, 8, true},
{-9849849, 8, true},
{math.MaxInt64, 9, true},
{0, 9, true},
{math.MinInt64, 9, true},
{math.MaxInt64, 12, true},
{0, 12, true},
{math.MinInt64, 12, true},
}
for _, v := range vectors {
ok := fitsInBase256(v.width, v.input)
if ok != v.ok {
t.Errorf("checkNumeric(%d, %d): got %v, want %v", v.input, v.width, ok, v.ok)
}
}
}
func TestFormatNumeric(t *testing.T) {
var vectors = []struct {
input int64
output string
ok bool
}{
// Test base-256 (binary) encoded values.
{-1, "\xff", true},
{-1, "\xff\xff", true},
{-1, "\xff\xff\xff", true},
{(1 << 0), "0", false},
{(1 << 8) - 1, "\x80\xff", true},
{(1 << 8), "0\x00", false},
{(1 << 16) - 1, "\x80\xff\xff", true},
{(1 << 16), "00\x00", false},
{-1 * (1 << 0), "\xff", true},
{-1*(1<<0) - 1, "0", false},
{-1 * (1 << 8), "\xff\x00", true},
{-1*(1<<8) - 1, "0\x00", false},
{-1 * (1 << 16), "\xff\x00\x00", true},
{-1*(1<<16) - 1, "00\x00", false},
{537795476381659745, "0000000\x00", false},
{537795476381659745, "\x80\x00\x00\x00\x07\x76\xa2\x22\xeb\x8a\x72\x61", true},
{-615126028225187231, "0000000\x00", false},
{-615126028225187231, "\xff\xff\xff\xff\xf7\x76\xa2\x22\xeb\x8a\x72\x61", true},
{math.MaxInt64, "0000000\x00", false},
{math.MaxInt64, "\x80\x00\x00\x00\x7f\xff\xff\xff\xff\xff\xff\xff", true},
{math.MinInt64, "0000000\x00", false},
{math.MinInt64, "\xff\xff\xff\xff\x80\x00\x00\x00\x00\x00\x00\x00", true},
{math.MaxInt64, "\x80\x7f\xff\xff\xff\xff\xff\xff\xff", true},
{math.MinInt64, "\xff\x80\x00\x00\x00\x00\x00\x00\x00", true},
}
for _, v := range vectors {
var f formatter
output := make([]byte, len(v.output))
f.formatNumeric(output, v.input)
ok := (f.err == nil)
if ok != v.ok {
if v.ok {
t.Errorf("formatNumeric(%d): got formatting failure, want success", v.input)
} else {
t.Errorf("formatNumeric(%d): got formatting success, want failure", v.input)
}
}
if string(output) != v.output {
t.Errorf("formatNumeric(%d): got %q, want %q", v.input, output, v.output)
}
}
}
func TestFormatPAXTime(t *testing.T) {
t1 := time.Date(2000, 1, 1, 11, 0, 0, 0, time.UTC)
t2 := time.Date(2000, 1, 1, 11, 0, 0, 100, time.UTC)
t3 := time.Date(1960, 1, 1, 11, 0, 0, 0, time.UTC)
t4 := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
verify := func(time time.Time, s string) {
p := formatPAXTime(time)
if p != s {
t.Errorf("for %v, expected %s, got %s", time, s, p)
}
}
verify(t1, "946724400")
verify(t2, "946724400.000000100")
verify(t3, "-315579600")
verify(t4, "0")
}

View File

@@ -1,280 +0,0 @@
// +build windows
package winio
import (
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"runtime"
"syscall"
"unicode/utf16"
)
//sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
//sys backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
const (
BackupData = uint32(iota + 1)
BackupEaData
BackupSecurity
BackupAlternateData
BackupLink
BackupPropertyData
BackupObjectId
BackupReparseData
BackupSparseBlock
BackupTxfsData
)
const (
StreamSparseAttributes = uint32(8)
)
const (
WRITE_DAC = 0x40000
WRITE_OWNER = 0x80000
ACCESS_SYSTEM_SECURITY = 0x1000000
)
// BackupHeader represents a backup stream of a file.
type BackupHeader struct {
Id uint32 // The backup stream ID
Attributes uint32 // Stream attributes
Size int64 // The size of the stream in bytes
Name string // The name of the stream (for BackupAlternateData only).
Offset int64 // The offset of the stream in the file (for BackupSparseBlock only).
}
type win32StreamId struct {
StreamId uint32
Attributes uint32
Size uint64
NameSize uint32
}
// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series
// of BackupHeader values.
type BackupStreamReader struct {
r io.Reader
bytesLeft int64
}
// NewBackupStreamReader produces a BackupStreamReader from any io.Reader.
func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
return &BackupStreamReader{r, 0}
}
// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
// it was not completely read.
func (r *BackupStreamReader) Next() (*BackupHeader, error) {
if r.bytesLeft > 0 {
if s, ok := r.r.(io.Seeker); ok {
// Make sure Seek on io.SeekCurrent sometimes succeeds
// before trying the actual seek.
if _, err := s.Seek(0, io.SeekCurrent); err == nil {
if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil {
return nil, err
}
r.bytesLeft = 0
}
}
if _, err := io.Copy(ioutil.Discard, r); err != nil {
return nil, err
}
}
var wsi win32StreamId
if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
return nil, err
}
hdr := &BackupHeader{
Id: wsi.StreamId,
Attributes: wsi.Attributes,
Size: int64(wsi.Size),
}
if wsi.NameSize != 0 {
name := make([]uint16, int(wsi.NameSize/2))
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
return nil, err
}
hdr.Name = syscall.UTF16ToString(name)
}
if wsi.StreamId == BackupSparseBlock {
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
return nil, err
}
hdr.Size -= 8
}
r.bytesLeft = hdr.Size
return hdr, nil
}
// Read reads from the current backup stream.
func (r *BackupStreamReader) Read(b []byte) (int, error) {
if r.bytesLeft == 0 {
return 0, io.EOF
}
if int64(len(b)) > r.bytesLeft {
b = b[:r.bytesLeft]
}
n, err := r.r.Read(b)
r.bytesLeft -= int64(n)
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if r.bytesLeft == 0 && err == nil {
err = io.EOF
}
return n, err
}
// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API.
type BackupStreamWriter struct {
w io.Writer
bytesLeft int64
}
// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer.
func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter {
return &BackupStreamWriter{w, 0}
}
// WriteHeader writes the next backup stream header and prepares for calls to Write().
func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
if w.bytesLeft != 0 {
return fmt.Errorf("missing %d bytes", w.bytesLeft)
}
name := utf16.Encode([]rune(hdr.Name))
wsi := win32StreamId{
StreamId: hdr.Id,
Attributes: hdr.Attributes,
Size: uint64(hdr.Size),
NameSize: uint32(len(name) * 2),
}
if hdr.Id == BackupSparseBlock {
// Include space for the int64 block offset
wsi.Size += 8
}
if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil {
return err
}
if len(name) != 0 {
if err := binary.Write(w.w, binary.LittleEndian, name); err != nil {
return err
}
}
if hdr.Id == BackupSparseBlock {
if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil {
return err
}
}
w.bytesLeft = hdr.Size
return nil
}
// Write writes to the current backup stream.
func (w *BackupStreamWriter) Write(b []byte) (int, error) {
if w.bytesLeft < int64(len(b)) {
return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft)
}
n, err := w.w.Write(b)
w.bytesLeft -= int64(n)
return n, err
}
// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API.
type BackupFileReader struct {
f *os.File
includeSecurity bool
ctx uintptr
}
// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true,
// Read will attempt to read the security descriptor of the file.
func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
r := &BackupFileReader{f, includeSecurity, 0}
return r
}
// Read reads a backup stream from the file by calling the Win32 API BackupRead().
func (r *BackupFileReader) Read(b []byte) (int, error) {
var bytesRead uint32
err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
if err != nil {
return 0, &os.PathError{"BackupRead", r.f.Name(), err}
}
runtime.KeepAlive(r.f)
if bytesRead == 0 {
return 0, io.EOF
}
return int(bytesRead), nil
}
// Close frees Win32 resources associated with the BackupFileReader. It does not close
// the underlying file.
func (r *BackupFileReader) Close() error {
if r.ctx != 0 {
backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
runtime.KeepAlive(r.f)
r.ctx = 0
}
return nil
}
// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API.
type BackupFileWriter struct {
f *os.File
includeSecurity bool
ctx uintptr
}
// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true,
// Write() will attempt to restore the security descriptor from the stream.
func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
w := &BackupFileWriter{f, includeSecurity, 0}
return w
}
// Write restores a portion of the file using the provided backup stream.
func (w *BackupFileWriter) Write(b []byte) (int, error) {
var bytesWritten uint32
err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
if err != nil {
return 0, &os.PathError{"BackupWrite", w.f.Name(), err}
}
runtime.KeepAlive(w.f)
if int(bytesWritten) != len(b) {
return int(bytesWritten), errors.New("not all bytes could be written")
}
return len(b), nil
}
// Close frees Win32 resources associated with the BackupFileWriter. It does not
// close the underlying file.
func (w *BackupFileWriter) Close() error {
if w.ctx != 0 {
backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
runtime.KeepAlive(w.f)
w.ctx = 0
}
return nil
}
// OpenForBackup opens a file or directory, potentially skipping access checks if the backup
// or restore privileges have been acquired.
//
// If the file opened was a directory, it cannot be used with Readdir().
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
winPath, err := syscall.UTF16FromString(path)
if err != nil {
return nil, err
}
h, err := syscall.CreateFile(&winPath[0], access, share, nil, createmode, syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT, 0)
if err != nil {
err = &os.PathError{Op: "open", Path: path, Err: err}
return nil, err
}
return os.NewFile(uintptr(h), path), nil
}

View File

@@ -1,255 +0,0 @@
package winio
import (
"io"
"io/ioutil"
"os"
"syscall"
"testing"
)
var testFileName string
func TestMain(m *testing.M) {
f, err := ioutil.TempFile("", "tmp")
if err != nil {
panic(err)
}
testFileName = f.Name()
f.Close()
defer os.Remove(testFileName)
os.Exit(m.Run())
}
func makeTestFile(makeADS bool) error {
os.Remove(testFileName)
f, err := os.Create(testFileName)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write([]byte("testing 1 2 3\n"))
if err != nil {
return err
}
if makeADS {
a, err := os.Create(testFileName + ":ads.txt")
if err != nil {
return err
}
defer a.Close()
_, err = a.Write([]byte("alternate data stream\n"))
if err != nil {
return err
}
}
return nil
}
func TestBackupRead(t *testing.T) {
err := makeTestFile(true)
if err != nil {
t.Fatal(err)
}
f, err := os.Open(testFileName)
if err != nil {
t.Fatal(err)
}
defer f.Close()
r := NewBackupFileReader(f, false)
defer r.Close()
b, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
if len(b) == 0 {
t.Fatal("no data")
}
}
func TestBackupStreamRead(t *testing.T) {
err := makeTestFile(true)
if err != nil {
t.Fatal(err)
}
f, err := os.Open(testFileName)
if err != nil {
t.Fatal(err)
}
defer f.Close()
r := NewBackupFileReader(f, false)
defer r.Close()
br := NewBackupStreamReader(r)
gotData := false
gotAltData := false
for {
hdr, err := br.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
switch hdr.Id {
case BackupData:
if gotData {
t.Fatal("duplicate data")
}
if hdr.Name != "" {
t.Fatalf("unexpected name %s", hdr.Name)
}
b, err := ioutil.ReadAll(br)
if err != nil {
t.Fatal(err)
}
if string(b) != "testing 1 2 3\n" {
t.Fatalf("incorrect data %v", b)
}
gotData = true
case BackupAlternateData:
if gotAltData {
t.Fatal("duplicate alt data")
}
if hdr.Name != ":ads.txt:$DATA" {
t.Fatalf("incorrect name %s", hdr.Name)
}
b, err := ioutil.ReadAll(br)
if err != nil {
t.Fatal(err)
}
if string(b) != "alternate data stream\n" {
t.Fatalf("incorrect data %v", b)
}
gotAltData = true
default:
t.Fatalf("unknown stream ID %d", hdr.Id)
}
}
if !gotData || !gotAltData {
t.Fatal("missing stream")
}
}
func TestBackupStreamWrite(t *testing.T) {
f, err := os.Create(testFileName)
if err != nil {
t.Fatal(err)
}
defer f.Close()
w := NewBackupFileWriter(f, false)
defer w.Close()
data := "testing 1 2 3\n"
altData := "alternate stream\n"
br := NewBackupStreamWriter(w)
err = br.WriteHeader(&BackupHeader{Id: BackupData, Size: int64(len(data))})
if err != nil {
t.Fatal(err)
}
n, err := br.Write([]byte(data))
if err != nil {
t.Fatal(err)
}
if n != len(data) {
t.Fatal("short write")
}
err = br.WriteHeader(&BackupHeader{Id: BackupAlternateData, Size: int64(len(altData)), Name: ":ads.txt:$DATA"})
if err != nil {
t.Fatal(err)
}
n, err = br.Write([]byte(altData))
if err != nil {
t.Fatal(err)
}
if n != len(altData) {
t.Fatal("short write")
}
f.Close()
b, err := ioutil.ReadFile(testFileName)
if err != nil {
t.Fatal(err)
}
if string(b) != data {
t.Fatalf("wrong data %v", b)
}
b, err = ioutil.ReadFile(testFileName + ":ads.txt")
if err != nil {
t.Fatal(err)
}
if string(b) != altData {
t.Fatalf("wrong data %v", b)
}
}
func makeSparseFile() error {
os.Remove(testFileName)
f, err := os.Create(testFileName)
if err != nil {
return err
}
defer f.Close()
const (
FSCTL_SET_SPARSE = 0x000900c4
FSCTL_SET_ZERO_DATA = 0x000980c8
)
err = syscall.DeviceIoControl(syscall.Handle(f.Fd()), FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil)
if err != nil {
return err
}
_, err = f.Write([]byte("testing 1 2 3\n"))
if err != nil {
return err
}
_, err = f.Seek(1000000, 0)
if err != nil {
return err
}
_, err = f.Write([]byte("more data later\n"))
if err != nil {
return err
}
return nil
}
func TestBackupSparseFile(t *testing.T) {
err := makeSparseFile()
if err != nil {
t.Fatal(err)
}
f, err := os.Open(testFileName)
if err != nil {
t.Fatal(err)
}
defer f.Close()
r := NewBackupFileReader(f, false)
defer r.Close()
br := NewBackupStreamReader(r)
for {
hdr, err := br.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
t.Log(hdr)
}
}

View File

@@ -1,4 +0,0 @@
// +build !windows
// This file only exists to allow go get on non-Windows platforms.
package backuptar

View File

@@ -1,439 +0,0 @@
// +build windows
package backuptar
import (
"encoding/base64"
"errors"
"fmt"
"io"
"io/ioutil"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/Microsoft/go-winio"
"github.com/Microsoft/go-winio/archive/tar" // until archive/tar supports pax extensions in its interface
)
const (
c_ISUID = 04000 // Set uid
c_ISGID = 02000 // Set gid
c_ISVTX = 01000 // Save text (sticky bit)
c_ISDIR = 040000 // Directory
c_ISFIFO = 010000 // FIFO
c_ISREG = 0100000 // Regular file
c_ISLNK = 0120000 // Symbolic link
c_ISBLK = 060000 // Block special file
c_ISCHR = 020000 // Character special file
c_ISSOCK = 0140000 // Socket
)
const (
hdrFileAttributes = "fileattr"
hdrSecurityDescriptor = "sd"
hdrRawSecurityDescriptor = "rawsd"
hdrMountPoint = "mountpoint"
hdrEaPrefix = "xattr."
)
func writeZeroes(w io.Writer, count int64) error {
buf := make([]byte, 8192)
c := len(buf)
for i := int64(0); i < count; i += int64(c) {
if int64(c) > count-i {
c = int(count - i)
}
_, err := w.Write(buf[:c])
if err != nil {
return err
}
}
return nil
}
func copySparse(t *tar.Writer, br *winio.BackupStreamReader) error {
curOffset := int64(0)
for {
bhdr, err := br.Next()
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != nil {
return err
}
if bhdr.Id != winio.BackupSparseBlock {
return fmt.Errorf("unexpected stream %d", bhdr.Id)
}
// archive/tar does not support writing sparse files
// so just write zeroes to catch up to the current offset.
err = writeZeroes(t, bhdr.Offset-curOffset)
if bhdr.Size == 0 {
break
}
n, err := io.Copy(t, br)
if err != nil {
return err
}
curOffset = bhdr.Offset + n
}
return nil
}
// BasicInfoHeader creates a tar header from basic file information.
func BasicInfoHeader(name string, size int64, fileInfo *winio.FileBasicInfo) *tar.Header {
hdr := &tar.Header{
Name: filepath.ToSlash(name),
Size: size,
Typeflag: tar.TypeReg,
ModTime: time.Unix(0, fileInfo.LastWriteTime.Nanoseconds()),
ChangeTime: time.Unix(0, fileInfo.ChangeTime.Nanoseconds()),
AccessTime: time.Unix(0, fileInfo.LastAccessTime.Nanoseconds()),
CreationTime: time.Unix(0, fileInfo.CreationTime.Nanoseconds()),
Winheaders: make(map[string]string),
}
hdr.Winheaders[hdrFileAttributes] = fmt.Sprintf("%d", fileInfo.FileAttributes)
if (fileInfo.FileAttributes & syscall.FILE_ATTRIBUTE_DIRECTORY) != 0 {
hdr.Mode |= c_ISDIR
hdr.Size = 0
hdr.Typeflag = tar.TypeDir
}
return hdr
}
// WriteTarFileFromBackupStream writes a file to a tar writer using data from a Win32 backup stream.
//
// This encodes Win32 metadata as tar pax vendor extensions starting with MSWINDOWS.
//
// The additional Win32 metadata is:
//
// MSWINDOWS.fileattr: The Win32 file attributes, as a decimal value
//
// MSWINDOWS.rawsd: The Win32 security descriptor, in raw binary format
//
// MSWINDOWS.mountpoint: If present, this is a mount point and not a symlink, even though the type is '2' (symlink)
func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size int64, fileInfo *winio.FileBasicInfo) error {
name = filepath.ToSlash(name)
hdr := BasicInfoHeader(name, size, fileInfo)
// If r can be seeked, then this function is two-pass: pass 1 collects the
// tar header data, and pass 2 copies the data stream. If r cannot be
// seeked, then some header data (in particular EAs) will be silently lost.
var (
restartPos int64
err error
)
sr, readTwice := r.(io.Seeker)
if readTwice {
if restartPos, err = sr.Seek(0, io.SeekCurrent); err != nil {
readTwice = false
}
}
br := winio.NewBackupStreamReader(r)
var dataHdr *winio.BackupHeader
for dataHdr == nil {
bhdr, err := br.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
switch bhdr.Id {
case winio.BackupData:
hdr.Mode |= c_ISREG
if !readTwice {
dataHdr = bhdr
}
case winio.BackupSecurity:
sd, err := ioutil.ReadAll(br)
if err != nil {
return err
}
hdr.Winheaders[hdrRawSecurityDescriptor] = base64.StdEncoding.EncodeToString(sd)
case winio.BackupReparseData:
hdr.Mode |= c_ISLNK
hdr.Typeflag = tar.TypeSymlink
reparseBuffer, err := ioutil.ReadAll(br)
rp, err := winio.DecodeReparsePoint(reparseBuffer)
if err != nil {
return err
}
if rp.IsMountPoint {
hdr.Winheaders[hdrMountPoint] = "1"
}
hdr.Linkname = rp.Target
case winio.BackupEaData:
eab, err := ioutil.ReadAll(br)
if err != nil {
return err
}
eas, err := winio.DecodeExtendedAttributes(eab)
if err != nil {
return err
}
for _, ea := range eas {
// Use base64 encoding for the binary value. Note that there
// is no way to encode the EA's flags, since their use doesn't
// make any sense for persisted EAs.
hdr.Winheaders[hdrEaPrefix+ea.Name] = base64.StdEncoding.EncodeToString(ea.Value)
}
case winio.BackupAlternateData, winio.BackupLink, winio.BackupPropertyData, winio.BackupObjectId, winio.BackupTxfsData:
// ignore these streams
default:
return fmt.Errorf("%s: unknown stream ID %d", name, bhdr.Id)
}
}
err = t.WriteHeader(hdr)
if err != nil {
return err
}
if readTwice {
// Get back to the data stream.
if _, err = sr.Seek(restartPos, io.SeekStart); err != nil {
return err
}
for dataHdr == nil {
bhdr, err := br.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
if bhdr.Id == winio.BackupData {
dataHdr = bhdr
}
}
}
if dataHdr != nil {
// A data stream was found. Copy the data.
if (dataHdr.Attributes & winio.StreamSparseAttributes) == 0 {
if size != dataHdr.Size {
return fmt.Errorf("%s: mismatch between file size %d and header size %d", name, size, dataHdr.Size)
}
_, err = io.Copy(t, br)
if err != nil {
return err
}
} else {
err = copySparse(t, br)
if err != nil {
return err
}
}
}
// Look for streams after the data stream. The only ones we handle are alternate data streams.
// Other streams may have metadata that could be serialized, but the tar header has already
// been written. In practice, this means that we don't get EA or TXF metadata.
for {
bhdr, err := br.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
switch bhdr.Id {
case winio.BackupAlternateData:
altName := bhdr.Name
if strings.HasSuffix(altName, ":$DATA") {
altName = altName[:len(altName)-len(":$DATA")]
}
if (bhdr.Attributes & winio.StreamSparseAttributes) == 0 {
hdr = &tar.Header{
Name: name + altName,
Mode: hdr.Mode,
Typeflag: tar.TypeReg,
Size: bhdr.Size,
ModTime: hdr.ModTime,
AccessTime: hdr.AccessTime,
ChangeTime: hdr.ChangeTime,
}
err = t.WriteHeader(hdr)
if err != nil {
return err
}
_, err = io.Copy(t, br)
if err != nil {
return err
}
} else {
// Unsupported for now, since the size of the alternate stream is not present
// in the backup stream until after the data has been read.
return errors.New("tar of sparse alternate data streams is unsupported")
}
case winio.BackupEaData, winio.BackupLink, winio.BackupPropertyData, winio.BackupObjectId, winio.BackupTxfsData:
// ignore these streams
default:
return fmt.Errorf("%s: unknown stream ID %d after data", name, bhdr.Id)
}
}
return nil
}
// FileInfoFromHeader retrieves basic Win32 file information from a tar header, using the additional metadata written by
// WriteTarFileFromBackupStream.
func FileInfoFromHeader(hdr *tar.Header) (name string, size int64, fileInfo *winio.FileBasicInfo, err error) {
name = hdr.Name
if hdr.Typeflag == tar.TypeReg || hdr.Typeflag == tar.TypeRegA {
size = hdr.Size
}
fileInfo = &winio.FileBasicInfo{
LastAccessTime: syscall.NsecToFiletime(hdr.AccessTime.UnixNano()),
LastWriteTime: syscall.NsecToFiletime(hdr.ModTime.UnixNano()),
ChangeTime: syscall.NsecToFiletime(hdr.ChangeTime.UnixNano()),
CreationTime: syscall.NsecToFiletime(hdr.CreationTime.UnixNano()),
}
if attrStr, ok := hdr.Winheaders[hdrFileAttributes]; ok {
attr, err := strconv.ParseUint(attrStr, 10, 32)
if err != nil {
return "", 0, nil, err
}
fileInfo.FileAttributes = uintptr(attr)
} else {
if hdr.Typeflag == tar.TypeDir {
fileInfo.FileAttributes |= syscall.FILE_ATTRIBUTE_DIRECTORY
}
}
return
}
// WriteBackupStreamFromTarFile writes a Win32 backup stream from the current tar file. Since this function may process multiple
// tar file entries in order to collect all the alternate data streams for the file, it returns the next
// tar file that was not processed, or io.EOF is there are no more.
func WriteBackupStreamFromTarFile(w io.Writer, t *tar.Reader, hdr *tar.Header) (*tar.Header, error) {
bw := winio.NewBackupStreamWriter(w)
var sd []byte
var err error
// Maintaining old SDDL-based behavior for backward compatibility. All new tar headers written
// by this library will have raw binary for the security descriptor.
if sddl, ok := hdr.Winheaders[hdrSecurityDescriptor]; ok {
sd, err = winio.SddlToSecurityDescriptor(sddl)
if err != nil {
return nil, err
}
}
if sdraw, ok := hdr.Winheaders[hdrRawSecurityDescriptor]; ok {
sd, err = base64.StdEncoding.DecodeString(sdraw)
if err != nil {
return nil, err
}
}
if len(sd) != 0 {
bhdr := winio.BackupHeader{
Id: winio.BackupSecurity,
Size: int64(len(sd)),
}
err := bw.WriteHeader(&bhdr)
if err != nil {
return nil, err
}
_, err = bw.Write(sd)
if err != nil {
return nil, err
}
}
var eas []winio.ExtendedAttribute
for k, v := range hdr.Winheaders {
if !strings.HasPrefix(k, hdrEaPrefix) {
continue
}
data, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return nil, err
}
eas = append(eas, winio.ExtendedAttribute{
Name: k[len(hdrEaPrefix):],
Value: data,
})
}
if len(eas) != 0 {
eadata, err := winio.EncodeExtendedAttributes(eas)
if err != nil {
return nil, err
}
bhdr := winio.BackupHeader{
Id: winio.BackupEaData,
Size: int64(len(eadata)),
}
err = bw.WriteHeader(&bhdr)
if err != nil {
return nil, err
}
_, err = bw.Write(eadata)
if err != nil {
return nil, err
}
}
if hdr.Typeflag == tar.TypeSymlink {
_, isMountPoint := hdr.Winheaders[hdrMountPoint]
rp := winio.ReparsePoint{
Target: filepath.FromSlash(hdr.Linkname),
IsMountPoint: isMountPoint,
}
reparse := winio.EncodeReparsePoint(&rp)
bhdr := winio.BackupHeader{
Id: winio.BackupReparseData,
Size: int64(len(reparse)),
}
err := bw.WriteHeader(&bhdr)
if err != nil {
return nil, err
}
_, err = bw.Write(reparse)
if err != nil {
return nil, err
}
}
if hdr.Typeflag == tar.TypeReg || hdr.Typeflag == tar.TypeRegA {
bhdr := winio.BackupHeader{
Id: winio.BackupData,
Size: hdr.Size,
}
err := bw.WriteHeader(&bhdr)
if err != nil {
return nil, err
}
_, err = io.Copy(bw, t)
if err != nil {
return nil, err
}
}
// Copy all the alternate data streams and return the next non-ADS header.
for {
ahdr, err := t.Next()
if err != nil {
return nil, err
}
if ahdr.Typeflag != tar.TypeReg || !strings.HasPrefix(ahdr.Name, hdr.Name+":") {
return ahdr, nil
}
bhdr := winio.BackupHeader{
Id: winio.BackupAlternateData,
Size: ahdr.Size,
Name: ahdr.Name[len(hdr.Name):] + ":$DATA",
}
err = bw.WriteHeader(&bhdr)
if err != nil {
return nil, err
}
_, err = io.Copy(bw, t)
if err != nil {
return nil, err
}
}
}

View File

@@ -1,84 +0,0 @@
package backuptar
import (
"bytes"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/Microsoft/go-winio"
"github.com/Microsoft/go-winio/archive/tar"
)
func ensurePresent(t *testing.T, m map[string]string, keys ...string) {
for _, k := range keys {
if _, ok := m[k]; !ok {
t.Error(k, "not present in tar header")
}
}
}
func TestRoundTrip(t *testing.T) {
f, err := ioutil.TempFile("", "tst")
if err != nil {
t.Fatal(err)
}
defer f.Close()
defer os.Remove(f.Name())
if _, err = f.Write([]byte("testing 1 2 3\n")); err != nil {
t.Fatal(err)
}
if _, err = f.Seek(0, 0); err != nil {
t.Fatal(err)
}
fi, err := f.Stat()
if err != nil {
t.Fatal(err)
}
bi, err := winio.GetFileBasicInfo(f)
if err != nil {
t.Fatal(err)
}
br := winio.NewBackupFileReader(f, true)
defer br.Close()
var buf bytes.Buffer
tw := tar.NewWriter(&buf)
err = WriteTarFileFromBackupStream(tw, br, f.Name(), fi.Size(), bi)
if err != nil {
t.Fatal(err)
}
tr := tar.NewReader(&buf)
hdr, err := tr.Next()
if err != nil {
t.Fatal(err)
}
name, size, bi2, err := FileInfoFromHeader(hdr)
if err != nil {
t.Fatal(err)
}
if name != filepath.ToSlash(f.Name()) {
t.Errorf("got name %s, expected %s", name, filepath.ToSlash(f.Name()))
}
if size != fi.Size() {
t.Errorf("got size %d, expected %d", size, fi.Size())
}
if !reflect.DeepEqual(*bi, *bi2) {
t.Errorf("got %#v, expected %#v", *bi, *bi2)
}
ensurePresent(t, hdr.Winheaders, "fileattr", "rawsd")
}

View File

@@ -1,137 +0,0 @@
package winio
import (
"bytes"
"encoding/binary"
"errors"
)
type fileFullEaInformation struct {
NextEntryOffset uint32
Flags uint8
NameLength uint8
ValueLength uint16
}
var (
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
errEaNameTooLarge = errors.New("extended attribute name too large")
errEaValueTooLarge = errors.New("extended attribute value too large")
)
// ExtendedAttribute represents a single Windows EA.
type ExtendedAttribute struct {
Name string
Value []byte
Flags uint8
}
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
var info fileFullEaInformation
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
if err != nil {
err = errInvalidEaBuffer
return
}
nameOffset := fileFullEaInformationSize
nameLen := int(info.NameLength)
valueOffset := nameOffset + int(info.NameLength) + 1
valueLen := int(info.ValueLength)
nextOffset := int(info.NextEntryOffset)
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
err = errInvalidEaBuffer
return
}
ea.Name = string(b[nameOffset : nameOffset+nameLen])
ea.Value = b[valueOffset : valueOffset+valueLen]
ea.Flags = info.Flags
if info.NextEntryOffset != 0 {
nb = b[info.NextEntryOffset:]
}
return
}
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
for len(b) != 0 {
ea, nb, err := parseEa(b)
if err != nil {
return nil, err
}
eas = append(eas, ea)
b = nb
}
return
}
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
if int(uint8(len(ea.Name))) != len(ea.Name) {
return errEaNameTooLarge
}
if int(uint16(len(ea.Value))) != len(ea.Value) {
return errEaValueTooLarge
}
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
withPadding := (entrySize + 3) &^ 3
nextOffset := uint32(0)
if !last {
nextOffset = withPadding
}
info := fileFullEaInformation{
NextEntryOffset: nextOffset,
Flags: ea.Flags,
NameLength: uint8(len(ea.Name)),
ValueLength: uint16(len(ea.Value)),
}
err := binary.Write(buf, binary.LittleEndian, &info)
if err != nil {
return err
}
_, err = buf.Write([]byte(ea.Name))
if err != nil {
return err
}
err = buf.WriteByte(0)
if err != nil {
return err
}
_, err = buf.Write(ea.Value)
if err != nil {
return err
}
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
if err != nil {
return err
}
return nil
}
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
// buffer for use with BackupWrite, ZwSetEaFile, etc.
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
var buf bytes.Buffer
for i := range eas {
last := false
if i == len(eas)-1 {
last = true
}
err := writeEa(&buf, &eas[i], last)
if err != nil {
return nil, err
}
}
return buf.Bytes(), nil
}

View File

@@ -1,89 +0,0 @@
package winio
import (
"io/ioutil"
"os"
"reflect"
"syscall"
"testing"
"unsafe"
)
var (
testEas = []ExtendedAttribute{
{Name: "foo", Value: []byte("bar")},
{Name: "fizz", Value: []byte("buzz")},
}
testEasEncoded = []byte{16, 0, 0, 0, 0, 3, 3, 0, 102, 111, 111, 0, 98, 97, 114, 0, 0, 0, 0, 0, 0, 4, 4, 0, 102, 105, 122, 122, 0, 98, 117, 122, 122, 0, 0, 0}
testEasNotPadded = testEasEncoded[0 : len(testEasEncoded)-3]
testEasTruncated = testEasEncoded[0:20]
)
func Test_RoundTripEas(t *testing.T) {
b, err := EncodeExtendedAttributes(testEas)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(testEasEncoded, b) {
t.Fatalf("encoded mismatch %v %v", testEasEncoded, b)
}
eas, err := DecodeExtendedAttributes(b)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(testEas, eas) {
t.Fatalf("mismatch %+v %+v", testEas, eas)
}
}
func Test_EasDontNeedPaddingAtEnd(t *testing.T) {
eas, err := DecodeExtendedAttributes(testEasNotPadded)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(testEas, eas) {
t.Fatalf("mismatch %+v %+v", testEas, eas)
}
}
func Test_TruncatedEasFailCorrectly(t *testing.T) {
_, err := DecodeExtendedAttributes(testEasTruncated)
if err == nil {
t.Fatal("expected error")
}
}
func Test_NilEasEncodeAndDecodeAsNil(t *testing.T) {
b, err := EncodeExtendedAttributes(nil)
if err != nil {
t.Fatal(err)
}
if len(b) != 0 {
t.Fatal("expected empty")
}
eas, err := DecodeExtendedAttributes(nil)
if err != nil {
t.Fatal(err)
}
if len(eas) != 0 {
t.Fatal("expected empty")
}
}
// Test_SetFileEa makes sure that the test buffer is actually parsable by NtSetEaFile.
func Test_SetFileEa(t *testing.T) {
f, err := ioutil.TempFile("", "winio")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
ntdll := syscall.MustLoadDLL("ntdll.dll")
ntSetEaFile := ntdll.MustFindProc("NtSetEaFile")
var iosb [2]uintptr
r, _, _ := ntSetEaFile.Call(f.Fd(), uintptr(unsafe.Pointer(&iosb[0])), uintptr(unsafe.Pointer(&testEasEncoded[0])), uintptr(len(testEasEncoded)))
if r != 0 {
t.Fatalf("NtSetEaFile failed with %08x", r)
}
}

View File

@@ -1,307 +0,0 @@
// +build windows
package winio
import (
"errors"
"io"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
)
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) swap(new bool) bool {
var newInt int32
if new {
newInt = 1
}
return atomic.SwapInt32((*int32)(b), newInt) == 1
}
const (
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
)
var (
ErrFileClosed = errors.New("file has already been closed")
ErrTimeout = &timeoutError{}
)
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{}
var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
err error
}
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o syscall.Overlapped
ch chan ioResult
}
func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
if err != nil {
panic(err)
}
ioCompletionPort = h
go ioCompletionProcessor(h)
}
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct {
handle syscall.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
type deadlineHandler struct {
setLock sync.Mutex
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomicBool
}
// makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) {
f := &win32File{handle: h}
ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
if err != nil {
return nil, err
}
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
f.readDeadline.channel = make(timeoutChan)
f.writeDeadline.channel = make(timeoutChan)
return f, nil
}
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
}
// closeHandle closes the resources associated with a Win32 handle
func (f *win32File) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
cancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
syscall.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
}
}
// Close closes a win32File.
func (f *win32File) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.isSet() {
f.wgLock.RUnlock()
return nil, ErrFileClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
c := &ioOperation{}
c.ch = make(chan ioResult)
return c, nil
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
if op == nil {
panic(err)
}
op.ch <- ioResult{bytes, err}
}
}
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING {
return int(bytes), err
}
if f.closing.isSet() {
cancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}
var r ioResult
select {
case r = <-c.ch:
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
if f.closing.isSet() {
err = ErrFileClosed
}
}
case <-timeout:
cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
err = ErrTimeout
}
}
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive
// until the channel read is complete.
runtime.KeepAlive(c)
return int(r.bytes), err
}
// Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.readDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
}
}
// Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.writeDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
func (f *win32File) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()
if d.timer != nil {
if !d.timer.Stop() {
<-d.channel
}
d.timer = nil
}
d.timedout.setFalse()
select {
case <-d.channel:
d.channelLock.Lock()
d.channel = make(chan struct{})
d.channelLock.Unlock()
default:
}
if deadline.IsZero() {
return nil
}
timeoutIO := func() {
d.timedout.setTrue()
close(d.channel)
}
now := time.Now()
duration := deadline.Sub(now)
if deadline.After(now) {
// Deadline is in the future, set a timer to wait
d.timer = time.AfterFunc(duration, timeoutIO)
} else {
// Deadline is in the past. Cancel all pending IO now.
timeoutIO()
}
return nil
}

View File

@@ -1,60 +0,0 @@
// +build windows
package winio
import (
"os"
"runtime"
"syscall"
"unsafe"
)
//sys getFileInformationByHandleEx(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) = GetFileInformationByHandleEx
//sys setFileInformationByHandle(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) = SetFileInformationByHandle
const (
fileBasicInfo = 0
fileIDInfo = 0x12
)
// FileBasicInfo contains file access time and file attributes information.
type FileBasicInfo struct {
CreationTime, LastAccessTime, LastWriteTime, ChangeTime syscall.Filetime
FileAttributes uintptr // includes padding
}
// GetFileBasicInfo retrieves times and attributes for a file.
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
bi := &FileBasicInfo{}
if err := getFileInformationByHandleEx(syscall.Handle(f.Fd()), fileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
return bi, nil
}
// SetFileBasicInfo sets times and attributes for a file.
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
if err := setFileInformationByHandle(syscall.Handle(f.Fd()), fileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
return nil
}
// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
// unique on a system.
type FileIDInfo struct {
VolumeSerialNumber uint64
FileID [16]byte
}
// GetFileID retrieves the unique (volume, file ID) pair for a file.
func GetFileID(f *os.File) (*FileIDInfo, error) {
fileID := &FileIDInfo{}
if err := getFileInformationByHandleEx(syscall.Handle(f.Fd()), fileIDInfo, (*byte)(unsafe.Pointer(fileID)), uint32(unsafe.Sizeof(*fileID))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
return fileID, nil
}

Some files were not shown because too many files have changed in this diff Show More