Skip to content
This repository was archived by the owner on Jun 30, 2023. It is now read-only.

Commit

Permalink
jar: share decoding buffer across files in checkJAR
Browse files Browse the repository at this point in the history
This purpose of this change is to reduce allocations in checkJAR by
sharing a zip decoding buffer across files in a JAR instead of fully
decoding each file into its own byte slice. A consequence of this change
is that the buffer is fixed size, so the matches* functions cannot
operate on byte slices as they once did. Instead, they need to operate
on an io.Reader.

So, this change also refactors class file scanning into its own method,
and rewrites the match functionality in terms of rsc.io/binaryregexp.

The performance delta of this change on BenchmarkParse is:

name              old time/op    new time/op    delta
Parse-16             156µs ± 1%      38µs ± 1%  -75.77%  (p=0.000 n=9+10)
ParseParallel-16    49.7µs ± 8%    14.1µs ± 8%  -71.68%  (p=0.000 n=10+10)

name              old alloc/op   new alloc/op   delta
Parse-16            52.2kB ± 0%    18.1kB ± 0%  -65.31%  (p=0.000 n=10+10)
ParseParallel-16    52.8kB ± 0%    18.4kB ± 0%  -65.04%  (p=0.000 n=10+10)

name              old allocs/op  new allocs/op  delta
Parse-16               112 ± 0%        54 ± 0%  -51.79%  (p=0.000 n=10+10)
ParseParallel-16       112 ± 0%        54 ± 0%  -51.79%  (p=0.000 n=10+10)
  • Loading branch information
mknyszek authored and ericchiang committed Jan 6, 2022
1 parent 96b3273 commit eb36a17
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 74 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ go 1.17
require (
github.com/google/go-cmp v0.5.6
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e
rsc.io/binaryregexp v0.2.0
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+R
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
rsc.io/binaryregexp v0.2.0 h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
183 changes: 133 additions & 50 deletions jar/jar.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ import (
"os"
"path"
"strings"
"sync"

zipfork "github.com/google/log4jscanner/third_party/zip"
"rsc.io/binaryregexp"
)

const (
Expand Down Expand Up @@ -171,6 +173,14 @@ func walkZIP(r *zip.Reader, fn func(f *zip.File) error) error {
return nil
}

const bufSize = 4 << 10 // 4 KiB

var bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, bufSize)
},
}

func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error {
if depth > maxZipDepth {
return fmt.Errorf("reached max zip depth of %d", maxZipDepth)
Expand Down Expand Up @@ -203,39 +213,28 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error {
defer f.Close()

info := zf.FileInfo()
var r io.Reader = f
if fsize := info.Size(); fsize > 0 {
if fsize+size > maxZipSize {
return fmt.Errorf("reading %s would exceed memory limit: %v", p, err)
}
r = io.LimitReader(f, fsize)
}

content, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("reading file %s: %v", p, err)
}
if !c.hasLookupClass {
if strings.Contains(p, "JndiLookup.class") {
c.hasLookupClass = true
}
return fmt.Errorf("stat file %s: %v", p, err)
}
if !c.hasOldJndiManagerConstructor {
c.hasOldJndiManagerConstructor = strings.Contains(p, "JndiManager") && matchesLog4JYARARule(content)
if fsize := info.Size(); fsize+size > maxZipSize {
return fmt.Errorf("reading %s would exceed memory limit: %v", p, err)
}
if strings.Contains(p, "JndiManager.class") {
c.seenJndiManagerClass = true
c.isAtLeastTwoDotSixteen = matchesTwoSixteen(content)
}
return nil
buf := bufPool.Get().([]byte)
defer bufPool.Put(buf)
return c.checkClass(p, f, buf)
}
if p == "META-INF/MANIFEST.MF" {
mf, err := zf.Open()
if err != nil {
return fmt.Errorf("opening manifest file %s: %v", p, err)
}
defer mf.Close()

buf := bufPool.Get().([]byte)
defer bufPool.Put(buf)

s := bufio.NewScanner(mf)
s.Buffer(buf, bufio.MaxScanTokenSize)
for s.Scan() {
// Use s.Bytes instead of s.Text to avoid a string allocation.
b := s.Bytes()
Expand Down Expand Up @@ -300,7 +299,7 @@ func (c *checker) checkJAR(r *zip.Reader, depth int, size int64) error {
return err
}

var (
const (
// Replicate YARA rule:
//
// strings:
Expand All @@ -311,15 +310,9 @@ var (
// }
//
// https://github.com/darkarnium/Log4j-CVE-Detect/blob/main/rules/vulnerability/log4j/CVE-2021-44228.yar
log4JYARAPrefix = []byte{0x3c, 0x69, 0x6e, 0x69, 0x74, 0x3e}
log4JYARASuffix = []byte{
0x28, 0x4c, 0x6a, 0x61, 0x76, 0x61, 0x2f, 0x6c,
0x61, 0x6e, 0x67, 0x2f, 0x53, 0x74, 0x72, 0x69,
0x6e, 0x67, 0x3b, 0x4c, 0x6a, 0x61, 0x76, 0x61,
0x78, 0x2f, 0x6e, 0x61, 0x6d, 0x69, 0x6e, 0x67,
0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74,
0x3b, 0x29, 0x56,
}

log4jYARARulePrefix = "\x3c\x69\x6e\x69\x74\x3e"
log4jYARARuleSuffix = "\x28\x4c\x6a\x61\x76\x61\x2f\x6c\x61\x6e\x67\x2f\x53\x74\x72\x69\x6e\x67\x3b\x4c\x6a\x61\x76\x61\x78\x2f\x6e\x61\x6d\x69\x6e\x67\x2f\x43\x6f\x6e\x74\x65\x78\x74\x3b\x29\x56"

// Relevant commit: https://github.com/apache/logging-log4j2/commit/44569090f1cf1e92c711fb96dfd18cd7dccc72ea
// In 2.16 the JndiManager class added the method `isJndiEnabled`. This was
Expand All @@ -334,31 +327,121 @@ var (
//
// Since this is so brittle, we're keeping the above rule that can reliably and
// non-brittle-ey detect <2.15 as a back up.
log4j216Detector = []byte("isJndiEnabled")
log4j216Pattern = "isJndiEnabled"
)

func matchesLog4JYARARule(b []byte) bool {
start := 0
for {
i := bytes.Index(b[start:], log4JYARAPrefix)
if i < 0 {
return false
// log4jPattern is a byte-matching regular expression that checks for two
// conditions in a Java class file:
// 1. Does the YARA rule match?
// 2. Have we found the 2.16 pattern?
var log4jPattern *binaryregexp.Regexp

func init() {
// Since this means we want to check two patterns in parallel we create all
// 4 combinations of how the patterns may appear, given that they do not
// share a matching prefix or a suffix (which they do not).
//
// The four combinations are:
// 1. [216Pattern]
// 2. [YARARulePattern]
// 3. [216Pattern.*YARARulePattern]
// 4. [YARARulePattern.*216Pattern]
//
// By creating submatches for each of these cases, we can identify which
// patterns are actually present. Also, in order to ensure (1) and (2)
// do not shadow (3) and (4), we need to look for the longest match.
yaraRule := binaryregexp.QuoteMeta(log4jYARARulePrefix) +
".{0,3}" + binaryregexp.QuoteMeta(log4jYARARuleSuffix)
log4jPattern = binaryregexp.MustCompile(
fmt.Sprintf("(?P<216>%s)|(?P<YARA>%s)|(?P<216First>%s.*%s)|(?P<YARAFirst>%s.*%s)",
log4j216Pattern,
yaraRule,
log4j216Pattern, yaraRule,
yaraRule, log4j216Pattern,
),
)
log4jPattern.Longest()
}

func (c *checker) checkClass(filename string, r io.Reader, buf []byte) error {
if !c.hasLookupClass && strings.Contains(filename, "JndiLookup.class") {
c.hasLookupClass = true
}
checkForOldJndiManagerConstructor := !c.hasOldJndiManagerConstructor && strings.Contains(filename, "JndiManager")
checkJndiManagerVersion := strings.Contains(filename, "JndiManager.class")
if !checkForOldJndiManagerConstructor && !checkJndiManagerVersion {
return nil
}
if checkJndiManagerVersion {
c.seenJndiManagerClass = true
}

br := newByteReader(r, buf)
matches := log4jPattern.FindReaderSubmatchIndex(br)

// Error reading.
if err := br.Err(); err != nil && err != io.EOF {
return err
}

// No match.
if matches == nil {
return nil
}

// We have a match!
switch {
case matches[2] > 0:
// 1. [216Pattern]
if checkJndiManagerVersion {
c.isAtLeastTwoDotSixteen = true
}
case matches[4] > 0:
// 2. [YARARulePattern]
if checkForOldJndiManagerConstructor {
c.hasOldJndiManagerConstructor = true
}
n := i + len(log4JYARAPrefix)
if len(b) <= n {
return false
case matches[6] > 0:
// 3. [216Pattern.*YARARulePattern]
fallthrough
case matches[8] > 0:
// 4. [YARARulePattern.*216Pattern]
if checkJndiManagerVersion {
c.isAtLeastTwoDotSixteen = true
}
j := bytes.Index(b[n:], log4JYARASuffix)
if j < 0 {
return false
if checkForOldJndiManagerConstructor {
c.hasOldJndiManagerConstructor = true
}
if j <= 3 {
return true
}
return nil
}

type byteReader struct {
r io.Reader
buf []byte
off int
err error
}

func newByteReader(r io.Reader, buf []byte) *byteReader {
return &byteReader{r: r, buf: buf[:0]}
}

func (b *byteReader) ReadByte() (byte, error) {
for b.off == len(b.buf) {
if b.err != nil {
return 0, b.err
}
start = i + len(log4JYARAPrefix)
n, err := b.r.Read(b.buf[:cap(b.buf)])
b.err = err
b.buf = b.buf[:n]
b.off = 0
}
result := b.buf[b.off]
b.off++
return result, nil
}

func matchesTwoSixteen(b []byte) bool {
return bytes.Contains(b, log4j216Detector)
func (b *byteReader) Err() error {
return b.err
}
Loading

0 comments on commit eb36a17

Please sign in to comment.