diff --git a/internal/unarchiver/testdata/testfile-escapes.tar.gz b/internal/unarchiver/testdata/testfile-escapes.tar.gz new file mode 100644 index 0000000000..c69a0ef234 Binary files /dev/null and b/internal/unarchiver/testdata/testfile-escapes.tar.gz differ diff --git a/internal/unarchiver/testdata/testfile.tar.gz b/internal/unarchiver/testdata/testfile.tar.gz index c69a0ef234..03ee3fe41b 100644 Binary files a/internal/unarchiver/testdata/testfile.tar.gz and b/internal/unarchiver/testdata/testfile.tar.gz differ diff --git a/internal/unarchiver/unarchiver.go b/internal/unarchiver/unarchiver.go index 4b5aed2c5e..cce34446d1 100644 --- a/internal/unarchiver/unarchiver.go +++ b/internal/unarchiver/unarchiver.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "runtime" + "strings" "github.com/mholt/archives" @@ -16,19 +17,42 @@ import ( type Unarchiver struct { archives.Extraction + + untrusted bool +} + +// Option configures an Unarchiver. +type Option func(*Unarchiver) + +// WithUntrustedSource marks the archive as coming from an untrusted source, so +// every extracted path, symlink target, and hardlink target is confined under +// the destination root and anything that would escape aborts extraction. Use it +// for untrusted archives such as private ingredient wheels. +// +// It is off by default: trusted Platform artifacts may legitimately contain +// absolute symlinks (for example into /usr/share), which would otherwise be +// rejected. +func WithUntrustedSource() Option { + return func(ua *Unarchiver) { ua.untrusted = true } } -func NewTarGz() Unarchiver { - return Unarchiver{archives.CompressedArchive{ +func NewTarGz(opts ...Option) Unarchiver { + return newUnarchiver(archives.CompressedArchive{ Compression: archives.Gz{}, Extraction: archives.Tar{}, - }} + }, opts) } -func NewZip() Unarchiver { - return Unarchiver{ - archives.Zip{}, +func NewZip(opts ...Option) Unarchiver { + return newUnarchiver(archives.Zip{}, opts) +} + +func newUnarchiver(extraction archives.Extraction, opts []Option) Unarchiver { + ua := Unarchiver{Extraction: extraction} + for _, opt := range opts { + opt(&ua) } + return ua } // PrepareUnpacking prepares the destination directory and the archive for unpacking @@ -49,31 +73,62 @@ func (ua *Unarchiver) PrepareUnpacking(source, destination string) (archiveFile return archiveFile, nil } -// Unarchive unarchives an archive file and unpacks it in `destination` +// Unarchive unarchives an archive file and unpacks it in `destination`. For an +// archive from an untrusted source (see WithUntrustedSource), every entry path, +// symlink target, and hardlink target is confined under destination and anything +// that would escape aborts extraction; otherwise paths are trusted as-is. func (ua *Unarchiver) Unarchive(archiveStream io.Reader, destination string) error { + root := filepath.Clean(destination) ctx := context.Background() err := ua.Extract(ctx, archiveStream, func(_ context.Context, file archives.FileInfo) error { - path := filepath.Join(destination, file.NameInArchive) + path := filepath.Join(root, file.NameInArchive) + if ua.untrusted && !isContainedPath(root, path) { + return errs.New("entry %q escapes the extraction root", file.NameInArchive) + } if file.IsDir() { - return mkdir(path) + if err := mkdir(path); err != nil { + return errs.Wrap(err, "could not create directory") + } + return nil } if file.LinkTarget != "" { if file.Mode()&os.ModeSymlink != 0 { - return writeNewSymbolicLink(path, file.LinkTarget) + if ua.untrusted { + if filepath.IsAbs(file.LinkTarget) { + return errs.New("symlink target %q is absolute", file.LinkTarget) + } + resolved := filepath.Join(filepath.Dir(path), file.LinkTarget) + if !isContainedPath(root, resolved) { + return errs.New("symlink target %q escapes the extraction root", file.LinkTarget) + } + } + if err := writeNewSymbolicLink(path, file.LinkTarget); err != nil { + return errs.Wrap(err, "could not write symlink") + } + return nil + } + target := filepath.Join(root, file.LinkTarget) + if ua.untrusted && !isContainedPath(root, target) { + return errs.New("hardlink target %q escapes the extraction root", file.LinkTarget) + } + if err := writeNewHardLink(path, target); err != nil { + return errs.Wrap(err, "could not write hardlink") } - target := filepath.Join(destination, file.LinkTarget) - return writeNewHardLink(path, target) + return nil } f, err := file.Open() if err != nil { - return err + return errs.Wrap(err, "could not open archived file") } defer f.Close() - return writeNewFile(path, f, file.Mode()) + if err := writeNewFile(path, f, file.Mode()); err != nil { + return errs.Wrap(err, "could not write file") + } + return nil }) if err != nil { return errs.Wrap(err, "Unable to extract files") @@ -82,6 +137,12 @@ func (ua *Unarchiver) Unarchive(archiveStream io.Reader, destination string) err return nil } +// isContainedPath reports whether path is at or under root. Both are expected to +// be cleaned (filepath.Join cleans its result). +func isContainedPath(root, path string) bool { + return path == root || strings.HasPrefix(path, root+string(os.PathSeparator)) +} + // the following files are just copied from the ActiveState/archiver repository // so we can use them in our extensions diff --git a/internal/unarchiver/unarchiver_test.go b/internal/unarchiver/unarchiver_test.go index 670a66189c..99baacdc89 100644 --- a/internal/unarchiver/unarchiver_test.go +++ b/internal/unarchiver/unarchiver_test.go @@ -21,26 +21,47 @@ func (suite *UnarchiverTestSuite) TestUnarchiveWithProgress() { func (suite *UnarchiverTestSuite) TestUnarchive() { cases := []struct { - name string - ua unarchiver.Unarchiver - testfile string - prep func(destination string) + name string + ua unarchiver.Unarchiver + testfile string + wantErr bool + wantFiles int }{ { - "successful unpacking targz", + // testfile.tar.gz is fully contained. + "successful tar.gz unpacking", unarchiver.NewTarGz(), - "testfile.tar.gz", func(destination string) { - err := os.WriteFile(destination, []byte{}, 0666) - suite.Require().NoError(err) - }, + "testfile.tar.gz", + false, + 11, }, { - "successful unpacking zip", + // testfile-escapes.tar.gz has a root-level symlink (symlink-to-file3 -> + // ../b/c/file3) whose target resolves outside the destination, so it is + // rejected when treated as untrusted. + "escaping tar.gz rejected when untrusted", + unarchiver.NewTarGz(unarchiver.WithUntrustedSource()), + "testfile-escapes.tar.gz", + true, + 0, + }, + { + // When trusted (the default), the same archive extracts as before + // (Platform artifacts may legitimately link outside the destination). + "escaping tar.gz extracts when trusted", + unarchiver.NewTarGz(), + "testfile-escapes.tar.gz", + false, + 12, + }, + { + // The zip fixture stores its symlinks as ordinary files, so every entry is + // contained and extraction succeeds. + "successful zip unpacking", unarchiver.NewZip(), - "testfile.zip", func(destination string) { - err := os.WriteFile(destination, []byte{}, 0666) - suite.Require().NoError(err) - }, + "testfile.zip", + false, + 12, }, } @@ -54,10 +75,14 @@ func (suite *UnarchiverTestSuite) TestUnarchive() { destination := filepath.Join(tempDir, "destination") f, err := tc.ua.PrepareUnpacking(testfile, destination) - suite.Assert().NoError(err) - suite.Assert().NotNil(f) + suite.Require().NoError(err) + suite.Require().NotNil(f) err = tc.ua.Unarchive(f, destination) + if tc.wantErr { + suite.Assert().Error(err) + return + } suite.Assert().NoError(err) installedFiles, err := listFilesRecursively(destination) @@ -65,7 +90,7 @@ func (suite *UnarchiverTestSuite) TestUnarchive() { sort.Strings(installedFiles) - suite.Assert().Len(installedFiles, 12) + suite.Assert().Len(installedFiles, tc.wantFiles) }) } } diff --git a/internal/unarchiver/untrusted_test.go b/internal/unarchiver/untrusted_test.go new file mode 100644 index 0000000000..3bf0120548 --- /dev/null +++ b/internal/unarchiver/untrusted_test.go @@ -0,0 +1,150 @@ +package unarchiver + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "os" + "path/filepath" + "runtime" + "testing" +) + +type tarEntry struct { + name string + typeflag byte + linkname string + body string +} + +// makeTarGz builds an in-memory tar.gz from the given entries, including +// deliberately malicious ones (the tar writer does not sanitize names). +func makeTarGz(t *testing.T, entries []tarEntry) []byte { + t.Helper() + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + for _, e := range entries { + hdr := &tar.Header{Name: e.name, Typeflag: e.typeflag, Linkname: e.linkname, Mode: 0644} + if e.typeflag == tar.TypeReg { + hdr.Size = int64(len(e.body)) + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("write header %q: %v", e.name, err) + } + if e.typeflag == tar.TypeReg { + if _, err := tw.Write([]byte(e.body)); err != nil { + t.Fatalf("write body %q: %v", e.name, err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatal(err) + } + if err := gz.Close(); err != nil { + t.Fatal(err) + } + return buf.Bytes() +} + +func TestUntrustedSourceRejectsEscapes(t *testing.T) { + tests := []struct { + name string + entries []tarEntry + }{ + {"path traversal", []tarEntry{{name: "../escape.txt", typeflag: tar.TypeReg, body: "x"}}}, + {"absolute symlink", []tarEntry{{name: "link", typeflag: tar.TypeSymlink, linkname: "/etc/passwd"}}}, + {"symlink escapes root", []tarEntry{{name: "sub/link", typeflag: tar.TypeSymlink, linkname: "../../outside"}}}, + {"hardlink escapes root", []tarEntry{{name: "link", typeflag: tar.TypeLink, linkname: "../outside"}}}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + ua := NewTarGz(WithUntrustedSource()) + err := ua.Unarchive(bytes.NewReader(makeTarGz(t, tc.entries)), dest) + if err == nil { + t.Fatal("expected rejection, got nil") + } + // Nothing should have been written outside dest. + if entries, _ := os.ReadDir(root); len(entries) > 1 { + t.Errorf("unexpected files written outside dest: %v", entries) + } + }) + } +} + +func TestTrustedSourceAllowsEscape(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + + // Without WithUntrustedSource, an escaping entry extracts as before (the + // entry lands at root/escape.txt, still inside the test's sandbox). + ua := NewTarGz() + err := ua.Unarchive(bytes.NewReader(makeTarGz(t, []tarEntry{ + {name: "../escape.txt", typeflag: tar.TypeReg, body: "trusted"}, + })), dest) + if err != nil { + t.Fatalf("default (trusted) extraction should not error: %v", err) + } + if _, err := os.Stat(filepath.Join(root, "escape.txt")); err != nil { + t.Errorf("escaping entry should have been extracted: %v", err) + } +} + +func TestUntrustedSourceHappyPath(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + ua := NewTarGz(WithUntrustedSource()) + err := ua.Unarchive(bytes.NewReader(makeTarGz(t, []tarEntry{ + {name: "dir/", typeflag: tar.TypeDir}, + {name: "dir/file.txt", typeflag: tar.TypeReg, body: "hello"}, + })), dest) + if err != nil { + t.Fatalf("happy path failed: %v", err) + } + got, err := os.ReadFile(filepath.Join(dest, "dir", "file.txt")) + if err != nil || string(got) != "hello" { + t.Fatalf("file not extracted correctly: got %q err %v", got, err) + } +} + +func TestUntrustedSourceAllowsContainedLinks(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink creation requires privileges on Windows") + } + dest := filepath.Join(t.TempDir(), "dest") + ua := NewTarGz(WithUntrustedSource()) + err := ua.Unarchive(bytes.NewReader(makeTarGz(t, []tarEntry{ + {name: "file.txt", typeflag: tar.TypeReg, body: "data"}, + {name: "sym", typeflag: tar.TypeSymlink, linkname: "file.txt"}, // contained sibling + {name: "hard", typeflag: tar.TypeLink, linkname: "file.txt"}, // contained target + })), dest) + if err != nil { + t.Fatalf("contained links should extract: %v", err) + } + for _, name := range []string{"file.txt", "sym", "hard"} { + if _, err := os.Lstat(filepath.Join(dest, name)); err != nil { + t.Errorf("expected %q to be extracted: %v", name, err) + } + } +} + +func TestIsContainedPath(t *testing.T) { + root := filepath.Clean(t.TempDir()) + tests := []struct { + name string + contained bool + }{ + {"file.txt", true}, + {"sub/file.txt", true}, + {"../escape.txt", false}, + {"sub/../../escape.txt", false}, + } + for _, tc := range tests { + path := filepath.Join(root, tc.name) + if got := isContainedPath(root, path); got != tc.contained { + t.Errorf("isContainedPath(root, join(%q)) = %v, want %v", tc.name, got, tc.contained) + } + } +}