diff --git a/api/agent/agent.go b/api/agent/agent.go index 48a372cf7..e88c0e4e2 100644 --- a/api/agent/agent.go +++ b/api/agent/agent.go @@ -25,6 +25,7 @@ import ( "github.com/sirupsen/logrus" "go.opencensus.io/stats" "go.opencensus.io/trace" + "os" ) // TODO we should prob store async calls in db immediately since we're returning id (will 404 until post-execution) @@ -1011,6 +1012,7 @@ func (a *agent) runHot(ctx context.Context, call *call, tok ResourceToken, state call.slots.queueSlot(&hotSlot{done: make(chan struct{}), fatalErr: err}) return } + case <-ctx.Done(): call.slots.queueSlot(&hotSlot{done: make(chan struct{}), fatalErr: ctx.Err()}) return @@ -1055,9 +1057,40 @@ func (a *agent) runHot(ctx context.Context, call *call, tok ResourceToken, state logger.WithError(res.Error()).Info("hot function terminated") } +//checkSocketDestination verifies that the socket file created by the FDK is valid and permitted - notably verifying that any symlinks are relative to the socket dir +func checkSocketDestination(filename string) error { + finfo, err := os.Lstat(filename) + if err != nil { + return fmt.Errorf("error statting unix socket link file %s", err) + } + + if (finfo.Mode() & os.ModeSymlink) > 0 { + linkDest, err := os.Readlink(filename) + if err != nil { + return fmt.Errorf("error reading unix socket symlink destination %s", err) + } + if filepath.Dir(linkDest) != "." { + return fmt.Errorf("invalid unix socket symlink, symlinks must be relative within the unix socket directory") + } + } + + // stat the absolute path and check it is a socket + absInfo, err := os.Stat(filename) + if err != nil { + return fmt.Errorf("unable to stat unix socket file %s", err) + } + if absInfo.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("listener file is not a socket") + } + + return nil +} func inotifyUDS(ctx context.Context, iofsDir string, awaitUDS chan<- error) { // XXX(reed): I forgot how to plumb channels temporarily forgive me for this sin (inotify will timeout, this is just bad programming) err := inotifyAwait(ctx, iofsDir) + if err == nil { + err = checkSocketDestination(filepath.Join(iofsDir, udsFilename)) + } select { case awaitUDS <- err: case <-ctx.Done(): @@ -1089,6 +1122,7 @@ func inotifyAwait(ctx context.Context, iofsDir string) error { case event := <-fsWatcher.Events: common.Logger(ctx).WithField("event", event).Debug("fsnotify event") if event.Op&fsnotify.Create == fsnotify.Create && event.Name == filepath.Join(iofsDir, udsFilename) { + // wait until the socket file is created by the container return nil } diff --git a/api/agent/agent_test.go b/api/agent/agent_test.go index ecd24b4b1..2f5af9ce8 100644 --- a/api/agent/agent_test.go +++ b/api/agent/agent_test.go @@ -22,6 +22,10 @@ import ( "github.com/fnproject/fn/api/models" "github.com/fnproject/fn/api/mqs" "github.com/sirupsen/logrus" + "io/ioutil" + "net" + "os" + "path/filepath" ) func init() { @@ -1162,3 +1166,65 @@ func TestDockerAuthExtn(t *testing.T) { t.Fatalf("unexpected registry token %s", da.RegistryToken) } } + +func TestCheckSocketDestination(t *testing.T) { + tmpDir, err := ioutil.TempDir(os.TempDir(), "testSocketPerms") + if err != nil { + t.Fatal("failed to create temp tmpDir", err) + } + defer os.RemoveAll(tmpDir) + + goodSock := filepath.Join(tmpDir, "fn.sock") + s, err := net.Listen("unix", goodSock) + if err != nil { + t.Fatal("failed to create socket", err) + } + defer s.Close() + + err = os.Chmod(goodSock, 0666) + if err != nil { + t.Fatal("failed to change perms", err) + } + notASocket := filepath.Join(tmpDir, "notasock.sock") + + err = ioutil.WriteFile(notASocket, []byte{0}, 0666) + if err != nil { + t.Fatalf("Failed to create empty sock") + } + + goodSymlink := filepath.Join(tmpDir, "goodlink.sock") + err = os.Symlink("fn.sock", goodSymlink) + if err != nil { + t.Fatalf("Failed to create symlink") + } + + badLinkNonExistant := filepath.Join(tmpDir, "badlinknonExist.sock") + err = os.Symlink("noxexistatnt.sock", badLinkNonExistant) + if err != nil { + t.Fatalf("Failed to create symlink") + } + + badLinkOutOfPath := filepath.Join(tmpDir, "badlinkoutofpath.sock") + err = os.Symlink(filepath.Join("..", filepath.Base(tmpDir), "fn.sock"), badLinkOutOfPath) + if err != nil { + t.Fatalf("Failed to create symlink") + } + + for _, good := range []string{goodSock, goodSymlink} { + t.Run(filepath.Base(good), func(t *testing.T) { + err := checkSocketDestination(good) + if err != nil { + t.Errorf("Expected no error got, %s", err) + } + }) + } + for _, bad := range []string{notASocket, badLinkNonExistant, badLinkOutOfPath, filepath.Join(tmpDir, "notAFile"), tmpDir} { + t.Run(filepath.Base(bad), func(t *testing.T) { + err := checkSocketDestination(bad) + if err == nil { + t.Errorf("Expected an error but got none") + } + }) + } + +}