From eb36a174060ff54197cc1c9826fbaf5bbc47a2cd Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Wed, 5 Jan 2022 20:53:30 +0000 Subject: [PATCH] jar: share decoding buffer across files in checkJAR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The purpose of this change is to reduce allocations in checkJAR by sharing a zip decoding buffer across files in a JAR instead of fully decoding each file into its own byte slice. A consequence of this change is that the buffer is fixed size, so the matches* functions cannot operate on byte slices as they once did. Instead, they need to operate on an io.Reader. So, this change also refactors class file scanning into its own method, and rewrites the match functionality in terms of rsc.io/binaryregexp. The performance delta of this change on BenchmarkParse is: name old time/op new time/op delta Parse-16 156µs ± 1% 38µs ± 1% -75.77% (p=0.000 n=9+10) ParseParallel-16 49.7µs ± 8% 14.1µs ± 8% -71.68% (p=0.000 n=10+10) name old alloc/op new alloc/op delta Parse-16 52.2kB ± 0% 18.1kB ± 0% -65.31% (p=0.000 n=10+10) ParseParallel-16 52.8kB ± 0% 18.4kB ± 0% -65.04% (p=0.000 n=10+10) name old allocs/op new allocs/op delta Parse-16 112 ± 0% 54 ± 0% -51.79% (p=0.000 n=10+10) ParseParallel-16 112 ± 0% 54 ± 0% -51.79% (p=0.000 n=10+10) --- go.mod | 1 + go.sum | 2 + jar/jar.go | 183 ++++++++++++++++++++++++++++++++++------------- jar/jar_test.go | 186 +++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 298 insertions(+), 74 deletions(-) diff --git a/go.mod b/go.mod index d5f4902..5ce1d8d 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.17 require ( github.com/google/go-cmp v0.5.6 golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e + rsc.io/binaryregexp v0.2.0 ) diff --git a/go.sum b/go.sum index 3b904ea..247412a 100644 --- a/go.sum +++ b/go.sum @@ -4,3 +4,5 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+R golang.org/x/sys 
v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/jar/jar.go b/jar/jar.go index 08ca77e..a1793c1 100644 --- a/jar/jar.go +++ b/jar/jar.go @@ -26,8 +26,10 @@ import ( "os" "path" "strings" + "sync" zipfork "github.com/google/log4jscanner/third_party/zip" + "rsc.io/binaryregexp" ) const ( @@ -171,6 +173,14 @@ func walkZIP(r *zip.Reader, fn func(f *zip.File) error) error { return nil } +const bufSize = 4 << 10 // 4 KiB + +var bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, bufSize) + }, +} + func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { if depth > maxZipDepth { return fmt.Errorf("reached max zip depth of %d", maxZipDepth) @@ -203,31 +213,15 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { defer f.Close() info := zf.FileInfo() - var r io.Reader = f - if fsize := info.Size(); fsize > 0 { - if fsize+size > maxZipSize { - return fmt.Errorf("reading %s would exceed memory limit: %v", p, err) - } - r = io.LimitReader(f, fsize) - } - - content, err := io.ReadAll(r) if err != nil { - return fmt.Errorf("reading file %s: %v", p, err) - } - if !c.hasLookupClass { - if strings.Contains(p, "JndiLookup.class") { - c.hasLookupClass = true - } + return fmt.Errorf("stat file %s: %v", p, err) } - if !c.hasOldJndiManagerConstructor { - c.hasOldJndiManagerConstructor = strings.Contains(p, "JndiManager") && matchesLog4JYARARule(content) + if fsize := info.Size(); fsize+size > maxZipSize { + return fmt.Errorf("reading %s would exceed memory limit: %v", p, err) } - if strings.Contains(p, 
"JndiManager.class") { - c.seenJndiManagerClass = true - c.isAtLeastTwoDotSixteen = matchesTwoSixteen(content) - } - return nil + buf := bufPool.Get().([]byte) + defer bufPool.Put(buf) + return c.checkClass(p, f, buf) } if p == "META-INF/MANIFEST.MF" { mf, err := zf.Open() @@ -235,7 +229,12 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { return fmt.Errorf("opening manifest file %s: %v", p, err) } defer mf.Close() + + buf := bufPool.Get().([]byte) + defer bufPool.Put(buf) + s := bufio.NewScanner(mf) + s.Buffer(buf, bufio.MaxScanTokenSize) for s.Scan() { // Use s.Bytes instead of s.Text to avoid a string allocation. b := s.Bytes() @@ -300,7 +299,7 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { return err } -var ( +const ( // Replicate YARA rule: // // strings: @@ -311,15 +310,9 @@ var ( // } // // https://github.com/darkarnium/Log4j-CVE-Detect/blob/main/rules/vulnerability/log4j/CVE-2021-44228.yar - log4JYARAPrefix = []byte{0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e} - log4JYARASuffix = []byte{ - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, - } + + log4jYARARulePrefix = "\x3c\x69\x6e\x69\x74\x3e" + log4jYARARuleSuffix = "\x28\x4c\x6a\x61\x76\x61\x2f\x6c\x61\x6e\x67\x2f\x53\x74\x72\x69\x6e\x67\x3b\x4c\x6a\x61\x76\x61\x78\x2f\x6e\x61\x6d\x69\x6e\x67\x2f\x43\x6f\x6e\x74\x65\x78\x74\x3b\x29\x56" // Relevant commit: https://github.com/apache/logging-log4j2/commit/44569090f1cf1e92c711fb96dfd18cd7dccc72ea // In 2.16 the JndiManager class added the method `isJndiEnabled`. This was @@ -334,31 +327,121 @@ var ( // // Since this is so brittle, we're keeping the above rule that can reliably and // non-brittle-ey detect <2.15 as a back up. 
- log4j216Detector = []byte("isJndiEnabled") + log4j216Pattern = "isJndiEnabled" ) -func matchesLog4JYARARule(b []byte) bool { - start := 0 - for { - i := bytes.Index(b[start:], log4JYARAPrefix) - if i < 0 { - return false +// log4jPattern is a byte-matching regular expression that checks for two +// conditions in a Java class file: +// 1. Does the YARA rule match? +// 2. Have we found the 2.16 pattern? +var log4jPattern *binaryregexp.Regexp + +func init() { + // Since this means we want to check two patterns in parallel we create all + // 4 combinations of how the patterns may appear, given that they do not + // share a matching prefix or a suffix (which they do not). + // + // The four combinations are: + // 1. [216Pattern] + // 2. [YARARulePattern] + // 3. [216Pattern.*YARARulePattern] + // 4. [YARARulePattern.*216Pattern] + // + // By creating submatches for each of these cases, we can identify which + // patterns are actually present. Also, in order to ensure (1) and (2) + // do not shadow (3) and (4), we need to look for the longest match. 
+ yaraRule := binaryregexp.QuoteMeta(log4jYARARulePrefix) + + ".{0,3}" + binaryregexp.QuoteMeta(log4jYARARuleSuffix) + log4jPattern = binaryregexp.MustCompile( + fmt.Sprintf("(?P<216>%s)|(?P%s)|(?P<216First>%s.*%s)|(?P%s.*%s)", + log4j216Pattern, + yaraRule, + log4j216Pattern, yaraRule, + yaraRule, log4j216Pattern, + ), + ) + log4jPattern.Longest() +} + +func (c *checker) checkClass(filename string, r io.Reader, buf []byte) error { + if !c.hasLookupClass && strings.Contains(filename, "JndiLookup.class") { + c.hasLookupClass = true + } + checkForOldJndiManagerConstructor := !c.hasOldJndiManagerConstructor && strings.Contains(filename, "JndiManager") + checkJndiManagerVersion := strings.Contains(filename, "JndiManager.class") + if !checkForOldJndiManagerConstructor && !checkJndiManagerVersion { + return nil + } + if checkJndiManagerVersion { + c.seenJndiManagerClass = true + } + + br := newByteReader(r, buf) + matches := log4jPattern.FindReaderSubmatchIndex(br) + + // Error reading. + if err := br.Err(); err != nil && err != io.EOF { + return err + } + + // No match. + if matches == nil { + return nil + } + + // We have a match! + switch { + case matches[2] > 0: + // 1. [216Pattern] + if checkJndiManagerVersion { + c.isAtLeastTwoDotSixteen = true + } + case matches[4] > 0: + // 2. [YARARulePattern] + if checkForOldJndiManagerConstructor { + c.hasOldJndiManagerConstructor = true } - n := i + len(log4JYARAPrefix) - if len(b) <= n { - return false + case matches[6] > 0: + // 3. [216Pattern.*YARARulePattern] + fallthrough + case matches[8] > 0: + // 4. 
[YARARulePattern.*216Pattern] + if checkJndiManagerVersion { + c.isAtLeastTwoDotSixteen = true } - j := bytes.Index(b[n:], log4JYARASuffix) - if j < 0 { - return false + if checkForOldJndiManagerConstructor { + c.hasOldJndiManagerConstructor = true } - if j <= 3 { - return true + } + return nil +} + +type byteReader struct { + r io.Reader + buf []byte + off int + err error +} + +func newByteReader(r io.Reader, buf []byte) *byteReader { + return &byteReader{r: r, buf: buf[:0]} +} + +func (b *byteReader) ReadByte() (byte, error) { + for b.off == len(b.buf) { + if b.err != nil { + return 0, b.err } - start = i + len(log4JYARAPrefix) + n, err := b.r.Read(b.buf[:cap(b.buf)]) + b.err = err + b.buf = b.buf[:n] + b.off = 0 } + result := b.buf[b.off] + b.off++ + return result, nil } -func matchesTwoSixteen(b []byte) bool { - return bytes.Contains(b, log4j216Detector) +func (b *byteReader) Err() error { + return b.err } diff --git a/jar/jar_test.go b/jar/jar_test.go index 7271ed9..f9b4b20 100644 --- a/jar/jar_test.go +++ b/jar/jar_test.go @@ -15,6 +15,9 @@ package jar import ( + "bytes" + "fmt" + "io" "path/filepath" "testing" ) @@ -121,31 +124,166 @@ func BenchmarkParseParallel(b *testing.B) { }) } -func TestYARARule(t *testing.T) { - data := []byte{ - 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, - 0x00, 0x00, 0x00, - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, +func TestLog4jPattern(t *testing.T) { + tests := []struct { + input []byte + matchType int + }{ + {append([]byte{0x0, 0x1}, []byte("isJndiEnabled")...), 0}, + {[]byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 
0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + }, 1}, + {append(make([]byte, 1000), []byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + }...), -1}, + {append([]byte("isJndiEnabled"), []byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + // Some random bytes. + 0xff, 0xff, 0xff, + }...), 2}, + {append([]byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + // Some random bytes. 
+ 0x15, 0x7f, 0xa5, + }, []byte("isJndiEnabled")...), 3}, + } + for _, test := range tests { + br := newByteReader(bytes.NewReader(test.input), make([]byte, 16)) + matches := log4jPattern.FindReaderSubmatchIndex(br) + if matches == nil && test.matchType >= 0 { + t.Error("expected match") + continue + } + switch test.matchType { + case 0: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of 2.16 only") + } + case 1: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of YARA rule only") + } + case 2: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of 2.16 then YARA rule") + } + case 3: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of YARA rule then then 2.16") + } + default: + if matches != nil { + t.Error("unexpected match") + } + } } - if !matchesLog4JYARARule(data) { - t.Errorf("expected to match YARA rule") +} + +func TestByteReader(t *testing.T) { + check := func(buf []byte, f func() io.Reader, expect []byte, expectErr error) { + t.Helper() + + br := newByteReader(f(), buf) + i := 0 + for { + b, err := br.ReadByte() + if err != nil { + if err != expectErr { + t.Errorf("expected error %v, got %v", expectErr, err) + } + if br.Err() != err { + t.Errorf("Err method result %v didn't match final error %v", br.Err(), err) + } + break + } + if b != expect[i] { + t.Errorf("read unexpected value %d at index %d", b, i) + break + } + i++ + } + if i != len(expect) { + t.Errorf("expected to read %d bytes, read %d bytes instead", len(expect), i) + } + } + // Intentionally reuse a buffer to see how it deals with + // a dirty buffer. 
+ buf := make([]byte, 8192) + + small := []byte("hello world") + newSmallReader := func() io.Reader { + return bytes.NewReader(small) + } + check(buf[:5], newSmallReader, small, io.EOF) + check(buf[:1], newSmallReader, small, io.EOF) + check(buf[:103], newSmallReader, small, io.EOF) + + large := bytes.Repeat(small, 1001) + newLargeReader := func() io.Reader { + return bytes.NewReader(large) + } + check(buf[:1], newLargeReader, large, io.EOF) + check(buf[:1041], newLargeReader, large, io.EOF) + check(buf[:], newLargeReader, large, io.EOF) + + const failAfter = 105 + bad := fmt.Errorf("this is bad") + newBadReader := func() io.Reader { + return newFaultReader(bytes.NewReader(large), bad, failAfter) + } + check(buf[:4], newBadReader, large[:failAfter], bad) + check(buf[:1], newBadReader, large[:failAfter], bad) + check(buf[:971], newBadReader, large[:failAfter], bad) +} + +type faultReader struct { + io.Reader + fault error + after int + + read int +} + +func newFaultReader(r io.Reader, fault error, after int) *faultReader { + return &faultReader{r, fault, after, 0} +} + +func (f *faultReader) Read(b []byte) (int, error) { + if f.read >= f.after { + return 0, f.fault } - data2 := append(make([]byte, 1000), []byte{ - 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, - }...) - if matchesLog4JYARARule(data2) { - t.Errorf("unexpected match on YARA rule") + n, err := f.Reader.Read(b) + f.read += n + if f.read >= f.after { + return f.after - (f.read - n), f.fault } + return n, err }