diff --git a/go.mod b/go.mod index d5f4902..5ce1d8d 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.17 require ( github.com/google/go-cmp v0.5.6 golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e + rsc.io/binaryregexp v0.2.0 ) diff --git a/go.sum b/go.sum index 3b904ea..247412a 100644 --- a/go.sum +++ b/go.sum @@ -4,3 +4,5 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+R golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/jar/jar.go b/jar/jar.go index 08ca77e..a1793c1 100644 --- a/jar/jar.go +++ b/jar/jar.go @@ -26,8 +26,10 @@ import ( "os" "path" "strings" + "sync" zipfork "github.com/google/log4jscanner/third_party/zip" + "rsc.io/binaryregexp" ) const ( @@ -171,6 +173,14 @@ func walkZIP(r *zip.Reader, fn func(f *zip.File) error) error { return nil } +const bufSize = 4 << 10 // 4 KiB + +var bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, bufSize) + }, +} + func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { if depth > maxZipDepth { return fmt.Errorf("reached max zip depth of %d", maxZipDepth) @@ -203,31 +213,15 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { defer f.Close() info := zf.FileInfo() - var r io.Reader = f - if fsize := info.Size(); fsize > 0 { - if fsize+size > maxZipSize { - return fmt.Errorf("reading %s would exceed memory limit: %v", p, err) - } - r = io.LimitReader(f, fsize) - } - - content, err := io.ReadAll(r) if err != nil { - return fmt.Errorf("reading file %s: %v", p, err) - } - if !c.hasLookupClass { - if strings.Contains(p, "JndiLookup.class") { - c.hasLookupClass = true - } + return fmt.Errorf("stat file %s: %v", p, err) } - if !c.hasOldJndiManagerConstructor { - c.hasOldJndiManagerConstructor = strings.Contains(p, "JndiManager") && matchesLog4JYARARule(content) + if fsize := info.Size(); fsize+size > maxZipSize { + return fmt.Errorf("reading %s would exceed memory limit: %v", p, err) } - if strings.Contains(p, "JndiManager.class") { - c.seenJndiManagerClass = true - c.isAtLeastTwoDotSixteen = matchesTwoSixteen(content) - } - return nil + buf := bufPool.Get().([]byte) + defer bufPool.Put(buf) + return c.checkClass(p, f, buf) } if p == "META-INF/MANIFEST.MF" { mf, err := zf.Open() @@ -235,7 +229,12 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { return fmt.Errorf("opening manifest file %s: %v", p, err) } defer mf.Close() + + buf := bufPool.Get().([]byte) + defer bufPool.Put(buf) + s := bufio.NewScanner(mf) + s.Buffer(buf, bufio.MaxScanTokenSize) for s.Scan() { // Use s.Bytes instead of s.Text to avoid a string allocation. b := s.Bytes() @@ -300,7 +299,7 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error { return err } -var ( +const ( // Replicate YARA rule: // // strings: @@ -311,15 +310,9 @@ var ( // } // // https://github.com/darkarnium/Log4j-CVE-Detect/blob/main/rules/vulnerability/log4j/CVE-2021-44228.yar - log4JYARAPrefix = []byte{0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e} - log4JYARASuffix = []byte{ - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, - } + + log4jYARARulePrefix = "\x3c\x69\x6e\x69\x74\x3e" + log4jYARARuleSuffix = "\x28\x4c\x6a\x61\x76\x61\x2f\x6c\x61\x6e\x67\x2f\x53\x74\x72\x69\x6e\x67\x3b\x4c\x6a\x61\x76\x61\x78\x2f\x6e\x61\x6d\x69\x6e\x67\x2f\x43\x6f\x6e\x74\x65\x78\x74\x3b\x29\x56" // Relevant commit: https://github.com/apache/logging-log4j2/commit/44569090f1cf1e92c711fb96dfd18cd7dccc72ea // In 2.16 the JndiManager class added the method `isJndiEnabled`. This was @@ -334,31 +327,121 @@ var ( // // Since this is so brittle, we're keeping the above rule that can reliably and // non-brittle-ey detect <2.15 as a back up. - log4j216Detector = []byte("isJndiEnabled") + log4j216Pattern = "isJndiEnabled" ) -func matchesLog4JYARARule(b []byte) bool { - start := 0 - for { - i := bytes.Index(b[start:], log4JYARAPrefix) - if i < 0 { - return false +// log4jPattern is a byte-matching regular expression that checks for two +// conditions in a Java class file: +// 1. Does the YARA rule match? +// 2. Have we found the 2.16 pattern? +var log4jPattern *binaryregexp.Regexp + +func init() { + // Since this means we want to check two patterns in parallel we create all + // 4 combinations of how the patterns may appear, given that they do not + // share a matching prefix or a suffix (which they do not). + // + // The four combinations are: + // 1. [216Pattern] + // 2. [YARARulePattern] + // 3. [216Pattern.*YARARulePattern] + // 4. [YARARulePattern.*216Pattern] + // + // By creating submatches for each of these cases, we can identify which + // patterns are actually present. Also, in order to ensure (1) and (2) + // do not shadow (3) and (4), we need to look for the longest match. + yaraRule := binaryregexp.QuoteMeta(log4jYARARulePrefix) + + ".{0,3}" + binaryregexp.QuoteMeta(log4jYARARuleSuffix) + log4jPattern = binaryregexp.MustCompile( + fmt.Sprintf("(?P<216>%s)|(?P%s)|(?P<216First>%s.*%s)|(?P%s.*%s)", + log4j216Pattern, + yaraRule, + log4j216Pattern, yaraRule, + yaraRule, log4j216Pattern, + ), + ) + log4jPattern.Longest() +} + +func (c *checker) checkClass(filename string, r io.Reader, buf []byte) error { + if !c.hasLookupClass && strings.Contains(filename, "JndiLookup.class") { + c.hasLookupClass = true + } + checkForOldJndiManagerConstructor := !c.hasOldJndiManagerConstructor && strings.Contains(filename, "JndiManager") + checkJndiManagerVersion := strings.Contains(filename, "JndiManager.class") + if !checkForOldJndiManagerConstructor && !checkJndiManagerVersion { + return nil + } + if checkJndiManagerVersion { + c.seenJndiManagerClass = true + } + + br := newByteReader(r, buf) + matches := log4jPattern.FindReaderSubmatchIndex(br) + + // Error reading. + if err := br.Err(); err != nil && err != io.EOF { + return err + } + + // No match. + if matches == nil { + return nil + } + + // We have a match! + switch { + case matches[2] > 0: + // 1. [216Pattern] + if checkJndiManagerVersion { + c.isAtLeastTwoDotSixteen = true + } + case matches[4] > 0: + // 2. [YARARulePattern] + if checkForOldJndiManagerConstructor { + c.hasOldJndiManagerConstructor = true } - n := i + len(log4JYARAPrefix) - if len(b) <= n { - return false + case matches[6] > 0: + // 3. [216Pattern.*YARARulePattern] + fallthrough + case matches[8] > 0: + // 4. [YARARulePattern.*216Pattern] + if checkJndiManagerVersion { + c.isAtLeastTwoDotSixteen = true } - j := bytes.Index(b[n:], log4JYARASuffix) - if j < 0 { - return false + if checkForOldJndiManagerConstructor { + c.hasOldJndiManagerConstructor = true } - if j <= 3 { - return true + } + return nil +} + +type byteReader struct { + r io.Reader + buf []byte + off int + err error +} + +func newByteReader(r io.Reader, buf []byte) *byteReader { + return &byteReader{r: r, buf: buf[:0]} +} + +func (b *byteReader) ReadByte() (byte, error) { + for b.off == len(b.buf) { + if b.err != nil { + return 0, b.err } - start = i + len(log4JYARAPrefix) + n, err := b.r.Read(b.buf[:cap(b.buf)]) + b.err = err + b.buf = b.buf[:n] + b.off = 0 } + result := b.buf[b.off] + b.off++ + return result, nil } -func matchesTwoSixteen(b []byte) bool { - return bytes.Contains(b, log4j216Detector) +func (b *byteReader) Err() error { + return b.err } diff --git a/jar/jar_test.go b/jar/jar_test.go index 7271ed9..f9b4b20 100644 --- a/jar/jar_test.go +++ b/jar/jar_test.go @@ -15,6 +15,9 @@ package jar import ( + "bytes" + "fmt" + "io" "path/filepath" "testing" ) @@ -121,31 +124,166 @@ func BenchmarkParseParallel(b *testing.B) { }) } -func TestYARARule(t *testing.T) { - data := []byte{ - 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, - 0x00, 0x00, 0x00, - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, +func TestLog4jPattern(t *testing.T) { + tests := []struct { + input []byte + matchType int + }{ + {append([]byte{0x0, 0x1}, []byte("isJndiEnabled")...), 0}, + {[]byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + }, 1}, + {append(make([]byte, 1000), []byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + }...), -1}, + {append([]byte("isJndiEnabled"), []byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + // Some random bytes. + 0xff, 0xff, 0xff, + }...), 2}, + {append([]byte{ + 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, + 0x00, 0x00, 0x00, + 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, + 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, + 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, + 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x3b, 0x29, 0x56, + // Some random bytes. + 0x15, 0x7f, 0xa5, + }, []byte("isJndiEnabled")...), 3}, + } + for _, test := range tests { + br := newByteReader(bytes.NewReader(test.input), make([]byte, 16)) + matches := log4jPattern.FindReaderSubmatchIndex(br) + if matches == nil && test.matchType >= 0 { + t.Error("expected match") + continue + } + switch test.matchType { + case 0: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of 2.16 only") + } + case 1: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of YARA rule only") + } + case 2: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of 2.16 then YARA rule") + } + case 3: + if matches[(test.matchType+1)*2] < 0 { + t.Error("expected match of YARA rule then then 2.16") + } + default: + if matches != nil { + t.Error("unexpected match") + } + } } - if !matchesLog4JYARARule(data) { - t.Errorf("expected to match YARA rule") +} + +func TestByteReader(t *testing.T) { + check := func(buf []byte, f func() io.Reader, expect []byte, expectErr error) { + t.Helper() + + br := newByteReader(f(), buf) + i := 0 + for { + b, err := br.ReadByte() + if err != nil { + if err != expectErr { + t.Errorf("expected error %v, got %v", expectErr, err) + } + if br.Err() != err { + t.Errorf("Err method result %v didn't match final error %v", br.Err(), err) + } + break + } + if b != expect[i] { + t.Errorf("read unexpected value %d at index %d", b, i) + break + } + i++ + } + if i != len(expect) { + t.Errorf("expected to read %d bytes, read %d bytes instead", len(expect), i) + } + } + // Intentionally reuse a buffer to see how it deals with + // a dirty buffer. + buf := make([]byte, 8192) + + small := []byte("hello world") + newSmallReader := func() io.Reader { + return bytes.NewReader(small) + } + check(buf[:5], newSmallReader, small, io.EOF) + check(buf[:1], newSmallReader, small, io.EOF) + check(buf[:103], newSmallReader, small, io.EOF) + + large := bytes.Repeat(small, 1001) + newLargeReader := func() io.Reader { + return bytes.NewReader(large) + } + check(buf[:1], newLargeReader, large, io.EOF) + check(buf[:1041], newLargeReader, large, io.EOF) + check(buf[:], newLargeReader, large, io.EOF) + + const failAfter = 105 + bad := fmt.Errorf("this is bad") + newBadReader := func() io.Reader { + return newFaultReader(bytes.NewReader(large), bad, failAfter) + } + check(buf[:4], newBadReader, large[:failAfter], bad) + check(buf[:1], newBadReader, large[:failAfter], bad) + check(buf[:971], newBadReader, large[:failAfter], bad) +} + +type faultReader struct { + io.Reader + fault error + after int + + read int +} + +func newFaultReader(r io.Reader, fault error, after int) *faultReader { + return &faultReader{r, fault, after, 0} +} + +func (f *faultReader) Read(b []byte) (int, error) { + if f.read >= f.after { + return 0, f.fault } - data2 := append(make([]byte, 1000), []byte{ - 0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c, - 0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61, - 0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67, - 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, - 0x3b, 0x29, 0x56, - }...) - if matchesLog4JYARARule(data2) { - t.Errorf("unexpected match on YARA rule") + n, err := f.Reader.Read(b) + f.read += n + if f.read >= f.after { + return f.after - (f.read - n), f.fault } + return n, err }