diff --git a/blankParam_checker.go b/blankParam_checker.go new file mode 100644 index 0000000..76b5a35 --- /dev/null +++ b/blankParam_checker.go @@ -0,0 +1,71 @@ +package checkers + +import ( + "go/ast" + + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "blankParam" + info.Tags = []string{"style", "opinionated", "experimental"} + info.Summary = "Detects unused params and suggests to name them as `_` (blank)" + info.Before = `func f(a int, b float64) // b isn't used inside function body` + info.After = `func f(a int, _ float64) // everything is cool` + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForFuncDecl(&blankParamChecker{ctx: ctx}) + }) +} + +type blankParamChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext +} + +func (c *blankParamChecker) VisitFuncDecl(decl *ast.FuncDecl) { + params := decl.Type.Params + if decl.Body == nil || params == nil || params.NumFields() == 0 { + return + } + + // collect all params to map + objToIdent := make(map[*ast.Object]*ast.Ident) + for _, p := range params.List { + if len(p.Names) == 0 { + c.warnUnnamed(p) + return + } + for _, id := range p.Names { + if id.Name != "_" { + objToIdent[id.Obj] = id + } + } + } + + // remove used params + // TODO(cristaloleg): we might want to have less iterations here. + for id := range c.ctx.TypesInfo.Uses { + if _, ok := objToIdent[id.Obj]; ok { + delete(objToIdent, id.Obj) + if len(objToIdent) == 0 { + return + } + } + } + + // all params that are left are unused + for _, id := range objToIdent { + c.warn(id) + } +} + +func (c *blankParamChecker) warn(param *ast.Ident) { + c.ctx.Warn(param, "rename `%s` to `_`", param) +} + +func (c *blankParamChecker) warnUnnamed(n ast.Node) { + c.ctx.Warn(n, "consider to name parameters as `_`") +} diff --git a/boolFuncPrefix_checker.go b/boolFuncPrefix_checker.go new file mode 100644 index 0000000..7e6f5a4 --- /dev/null +++ b/boolFuncPrefix_checker.go @@ -0,0 +1,68 @@ +package checkers + +import ( + "go/ast" + "go/types" + "strings" + + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "boolFuncPrefix" + info.Tags = []string{"style", "experimental", "opinionated"} + info.Summary = "Detects function returning only bool and suggests to add Is/Has/Contains prefix to it's name" + info.Before = "func Enabled() bool" + info.After = "func IsEnabled() bool" + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForFuncDecl(&boolFuncPrefixChecker{ctx: ctx}) + }) +} + +type boolFuncPrefixChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext +} + +func (c *boolFuncPrefixChecker) VisitFuncDecl(decl *ast.FuncDecl) { + params := decl.Type.Params + results := decl.Type.Results + + if params.NumFields() > 0 || + results.NumFields() != 1 || + !c.isBoolType(results.List[0].Type) || + c.hasProperPrefix(decl.Name.Name) { + return + } + c.warn(decl) +} + +func (c *boolFuncPrefixChecker) warn(fn *ast.FuncDecl) { + c.ctx.Warn(fn, "consider to add Is/Has/Contains prefix to function name") +} + +func (c *boolFuncPrefixChecker) isBoolType(expr ast.Expr) bool { + typ, ok := c.ctx.TypesInfo.TypeOf(expr).(*types.Basic) + return ok && typ.Kind() == types.Bool +} + +func (c *boolFuncPrefixChecker) hasProperPrefix(name string) bool { + name = strings.ToLower(name) + excluded := []string{"exit", "quit"} + for _, ex := range excluded { + if name == ex { + return true + } + } + + prefixes := []string{"is", "has", "contains", "check", "get", "should", "need", "may", "should"} + for _, p := range prefixes { + if strings.HasPrefix(name, p) { + return true + } + } + return false +} diff --git a/floatCompare_checker.go b/floatCompare_checker.go new file mode 100644 index 0000000..931a42c --- /dev/null +++ b/floatCompare_checker.go @@ -0,0 +1,107 @@ +package checkers + +import ( + "go/ast" + "go/token" + "go/types" + + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" + "github.com/go-toolsmith/astequal" + "github.com/go-toolsmith/astp" + "golang.org/x/tools/go/ast/astutil" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "floatCompare" + info.Tags = []string{"diagnostic", "experimental"} + info.Summary = "Detects fragile float variables comparisons" + info.Before = ` +// x and y are floats +return x == y` + info.After = ` +// x and y are floats +return math.Abs(x - y) < eps` + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForLocalExpr(&floatCompareChecker{ctx: ctx}) + }) +} + +type floatCompareChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext +} + +func (c *floatCompareChecker) VisitLocalExpr(expr ast.Expr) { + binexpr, ok := expr.(*ast.BinaryExpr) + if !ok { + return + } + if binexpr.Op == token.EQL || binexpr.Op == token.NEQ { + if c.isFloatExpr(binexpr) && + c.isMultiBinaryExpr(binexpr) && + !c.isNaNCheckExpr(binexpr) && + !c.isInfCheckExpr(binexpr) { + c.warn(binexpr) + } + } +} + +func (c *floatCompareChecker) isFloatExpr(binexpr *ast.BinaryExpr) bool { + exprx := astutil.Unparen(binexpr.X) + typx, ok := c.ctx.TypesInfo.TypeOf(exprx).(*types.Basic) + return ok && typx.Info()&types.IsFloat != 0 +} + +func (c *floatCompareChecker) isNaNCheckExpr(binexpr *ast.BinaryExpr) bool { + exprx := astutil.Unparen(binexpr.X) + expry := astutil.Unparen(binexpr.Y) + return astequal.Expr(exprx, expry) +} + +func (c *floatCompareChecker) isInfCheckExpr(binexpr *ast.BinaryExpr) bool { + binx := astutil.Unparen(binexpr.X) + biny := astutil.Unparen(binexpr.Y) + expr, bin, ok := c.identExpr(binx, biny) + if !ok { + return false + } + x := astutil.Unparen(expr) + y := astutil.Unparen(bin.X) + z := astutil.Unparen(bin.Y) + return astequal.Expr(x, y) && astequal.Expr(y, z) +} + +func (c *floatCompareChecker) isMultiBinaryExpr(binexpr *ast.BinaryExpr) bool { + exprx := astutil.Unparen(binexpr.X) + expry := astutil.Unparen(binexpr.Y) + return astp.IsBinaryExpr(exprx) || astp.IsBinaryExpr(expry) +} + +func (c *floatCompareChecker) identExpr(x, y ast.Node) (ast.Expr, *ast.BinaryExpr, bool) { + expr1, ok1 := x.(*ast.BinaryExpr) + expr2, ok2 := y.(*ast.BinaryExpr) + switch { + case ok1 && !ok2: + return y.(ast.Expr), expr1, true + case !ok1 && ok2: + return x.(ast.Expr), expr2, true + + default: + return nil, nil, false + } +} + +func (c *floatCompareChecker) warn(expr *ast.BinaryExpr) { + op := ">=" + if expr.Op == token.EQL { + op = "<" + } + format := "change `%s` to `math.Abs(%s - %s) %s eps`" + if astp.IsBinaryExpr(expr.Y) { + format = "change `%s` to `math.Abs(%s - (%s)) %s eps`" + } + c.ctx.Warn(expr, format, expr, expr.X, expr.Y, op) +} diff --git a/indexOnlyLoop_checker.go b/indexOnlyLoop_checker.go new file mode 100644 index 0000000..2e63cdf --- /dev/null +++ b/indexOnlyLoop_checker.go @@ -0,0 +1,83 @@ +package checkers + +import ( + "go/ast" + "go/types" + + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" + "github.com/go-toolsmith/astequal" + "github.com/go-toolsmith/typep" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "indexOnlyLoop" + info.Tags = []string{"style", "experimental"} + info.Summary = "Detects for loops that can benefit from rewrite to range loop" + info.Details = "Suggests to use for key, v := range container form." + info.Before = ` +for i := range files { + if files[i] != nil { + files[i].Close() + } +}` + info.After = ` +for _, f := range files { + if f != nil { + f.Close() + } +}` + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForStmt(&indexOnlyLoopChecker{ctx: ctx}) + }) +} + +type indexOnlyLoopChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext +} + +func (c *indexOnlyLoopChecker) VisitStmt(stmt ast.Stmt) { + rng, ok := stmt.(*ast.RangeStmt) + if !ok || rng.Key == nil || rng.Value != nil { + return + } + iterated := c.ctx.TypesInfo.ObjectOf(identOf(rng.X)) + if iterated == nil || !c.elemTypeIsPtr(iterated) { + return // To avoid redundant traversals + } + count := 0 + ast.Inspect(rng.Body, func(n ast.Node) bool { + if n, ok := n.(*ast.IndexExpr); ok { + if !astequal.Expr(n.Index, rng.Key) { + return true + } + if iterated == c.ctx.TypesInfo.ObjectOf(identOf(n.X)) { + count++ + } + } + // Stop DFS traverse if we found more than one usage. + return count < 2 + }) + if count > 1 { + c.warn(stmt, rng.Key, iterated.Name()) + } +} + +func (c *indexOnlyLoopChecker) elemTypeIsPtr(obj types.Object) bool { + switch typ := obj.Type().(type) { + case *types.Slice: + return typep.IsPointer(typ.Elem()) + case *types.Array: + return typep.IsPointer(typ.Elem()) + default: + return false + } +} + +func (c *indexOnlyLoopChecker) warn(x, key ast.Node, iterated string) { + c.ctx.Warn(x, "%s occurs more than once in the loop; consider using for _, value := range %s", + key, iterated) +} diff --git a/sqlRowsClose_checker.go b/sqlRowsClose_checker.go new file mode 100644 index 0000000..9412166 --- /dev/null +++ b/sqlRowsClose_checker.go @@ -0,0 +1,144 @@ +package checkers + +import ( + "go/ast" + "go/types" + + "github.com/go-lintpack/lintpack" + "github.com/go-lintpack/lintpack/astwalk" +) + +func init() { + var info lintpack.CheckerInfo + info.Name = "sqlRowsClose" + info.Tags = []string{"diagnostic", "experimental"} + info.Summary = "Detects uses of *sql.Rows without call Close method" + info.Before = ` +rows, _ := db.Query( /**/ ) +for rows.Next { +}` + info.After = ` +rows, _ := db.Query( /**/ ) +for rows.Next { +} +rows.Close()` + + collection.AddChecker(&info, func(ctx *lintpack.CheckerContext) lintpack.FileWalker { + return astwalk.WalkerForFuncDecl(&sqlRowsCloseChecker{ctx: ctx}) + }) +} + +type sqlRowsCloseChecker struct { + astwalk.WalkHandler + ctx *lintpack.CheckerContext +} + +// Warning if sql.Rows local variables (including function parameters): +// 1. Not using as parameter in other functions call; +// 2. Not returning in functions results; +// 3. Not call Close method for variable; + +func (c *sqlRowsCloseChecker) VisitFuncDecl(decl *ast.FuncDecl) { + const rowsTypePTR = "*database/sql.Rows" + + localVars := []types.Object{} + returnVars := []types.Object{} + closeVars := []types.Object{} + + for _, b := range decl.Body.List { + switch b := b.(type) { + case *ast.AssignStmt: + // Detect local vars with sql.Rows types + if b.Lhs != nil { + for _, l := range b.Lhs { + if c.typeString(l) == rowsTypePTR { + localVars = append(localVars, c.getType(l)) + } + } + } + case *ast.ReturnStmt: + // Detect return vars with sql.Rows types + if b.Results != nil && len(b.Results) > 0 { + for _, r := range b.Results { + if c.typeString(r) == rowsTypePTR { + returnVars = append(returnVars, c.getType(r)) + } + } + } + case *ast.ExprStmt: + if sel := c.getCloseSelectorExpr(b.X); sel != nil { + closeVars = append(closeVars, c.getType(sel.X)) + } + case *ast.DeferStmt: + // Detect call Close for sql.Rows variables over defer declaration + + // is it `defer rowsVar.Close()` + if b, ok := b.Call.Fun.(*ast.SelectorExpr); ok { + funcName := qualifiedName(b.Sel) + if funcName == "Close" { + closeVars = append(closeVars, c.getType(b.X)) + } + } + + // looking for `rowsVar.Close()` inside `defer func() { ... }()` + if f, ok := b.Call.Fun.(*ast.FuncLit); ok { + if f.Body == nil || f.Body.List == nil { + continue + } + + for _, s := range f.Body.List { + expr, ok := s.(*ast.ExprStmt) + if !ok { + continue + } + ss := c.getCloseSelectorExpr(expr.X) + if ss != nil { + closeVars = append(closeVars, c.getType(ss.X)) + } + } + } + } + } + + // Check local variables + for _, l := range localVars { + // If local variable present in return or Close present - PASS + if !c.varInList(l, returnVars) && !c.varInList(l, closeVars) { + c.ctx.Warn(l.Parent(), "local variable db.Rows have not Close call") + } + } +} + +func (c *sqlRowsCloseChecker) getType(x ast.Node) types.Object { + return c.ctx.TypesInfo.ObjectOf(identOf(x)) +} + +func (c *sqlRowsCloseChecker) getCloseSelectorExpr(x ast.Node) *ast.SelectorExpr { + call, ok := x.(*ast.CallExpr) + if !ok { + return nil + } + if bb, ok := call.Fun.(*ast.SelectorExpr); ok { + // Detect call Close for sql.Rows variables + if qualifiedName(bb.Sel) == "Close" { + return bb + } + } + return nil +} + +func (c *sqlRowsCloseChecker) typeString(x ast.Expr) string { + if typ := c.ctx.TypesInfo.TypeOf(x); typ != nil { + return typ.String() + } + return "" +} + +func (c *sqlRowsCloseChecker) varInList(v types.Object, list []types.Object) bool { + for _, r := range list { + if v == r { + return true + } + } + return false +} diff --git a/testdata/blankParam/negative_tests.go b/testdata/blankParam/negative_tests.go new file mode 100644 index 0000000..af27fa4 --- /dev/null +++ b/testdata/blankParam/negative_tests.go @@ -0,0 +1,14 @@ +package checker_test + +func (f *foo) g1(x int, _ float64) { + _ = x +} + +func (f *foo) g2(_ int, _ float64) { +} + +func g3() (_ int, _ float64) { + return 0, 0 +} + +func external(x int) diff --git a/testdata/blankParam/positive_tests.go b/testdata/blankParam/positive_tests.go new file mode 100644 index 0000000..3a4c602 --- /dev/null +++ b/testdata/blankParam/positive_tests.go @@ -0,0 +1,26 @@ +package checker_test + +type foo struct{} + +/// rename `x` to `_` +func f1(x int, y float64) { + _ = y +} + +/// rename `x` to `_` +/// rename `y` to `_` +func (*foo) f2(x int, y int) { +} + +/// rename `y` to `_` +func (f *foo) f3(_, y int, z float64) { + _ = z +} + +/// consider to name parameters as `_` +func (f *foo) f4(int, float64) { +} + +/// rename `x` to `_` +func (f *foo) f5(x int, _ float64) { +} diff --git a/testdata/boolFuncPrefix/negative_tests.go b/testdata/boolFuncPrefix/negative_tests.go new file mode 100644 index 0000000..3950c3c --- /dev/null +++ b/testdata/boolFuncPrefix/negative_tests.go @@ -0,0 +1,16 @@ +package checker_test + +func IsEnabled() bool { return true } + +func HasString(s string, ss ...string) bool { return true } + +func (f *foo) ContainsFlag(flag int) bool { return true } + +func (f *foo) isEnabled() bool { return true } + +func has(s string, ss ...string) bool { return true } + +func (f *foo) containsFlag(flag int) bool { return true } + +func (f *foo) exit() bool { return true } +func quit() bool { return true } diff --git a/testdata/boolFuncPrefix/positive_tests.go b/testdata/boolFuncPrefix/positive_tests.go new file mode 100644 index 0000000..0cc2531 --- /dev/null +++ b/testdata/boolFuncPrefix/positive_tests.go @@ -0,0 +1,9 @@ +package checker_test + +type foo struct{} + +/// consider to add Is/Has/Contains prefix to function name +func enabled() bool { return true } + +/// consider to add Is/Has/Contains prefix to function name +func (f *foo) active() bool { return true } diff --git a/testdata/floatCompare/negative_tests.go b/testdata/floatCompare/negative_tests.go new file mode 100644 index 0000000..f521c2a --- /dev/null +++ b/testdata/floatCompare/negative_tests.go @@ -0,0 +1,65 @@ +package checker_test + +func foo1() float64 { + return 9.0 +} + +func g1() bool { + var a, b float64 = 10.0, 20.0 + return a < b +} + +func g2() bool { + var a, b int = 10, 20 + return a == b +} + +func g3() { + var a, b float64 = 10.0, 20.0 + + x := []float64{4.0, 5.0, 9.0, 10.0} + + var p *float64 = &a + + _ = b == 0.5 + + _ = 'a' == 'b' + + _ = a >= b + + _ = a == b + + _ = a != 0.6 + + _ = b < a && a >= b + + _ = a+b != a+b + + _ = a+b != (a + b) + + _ = (a + b) != (a + b) + + _ = foo1() == a + + _ = x[0] == a + + _ = *p == a + + _ = a == a+a + + _ = a+a != a + + _ = (a + a) == a + + _ = foo1()+foo1() == foo1() + + _ = x[0] != x[0] + + _ = x[0]+x[0] == x[0] + + _ = *p+*p == *p + + _ = x[0]+(x[0]) == x[0] + + _ = a+b == a+b +} diff --git a/testdata/floatCompare/positive_tests.go b/testdata/floatCompare/positive_tests.go new file mode 100644 index 0000000..50ca2f0 --- /dev/null +++ b/testdata/floatCompare/positive_tests.go @@ -0,0 +1,32 @@ +package checker_test + +func foo() float32 { + return 9.0 +} + +func f1() bool { + var p float32 = 10.0 + var q float32 = 5.0 + /// change `(p + q) == foo()` to `math.Abs((p + q) - foo()) < eps` + return (p + q) == foo() +} + +func f2() bool { + var a float32 = 10.0 + var b float32 = 5.0 + /// change `2*a+4*b == 0.5` to `math.Abs(2*a + 4*b - 0.5) < eps` + return 2*a+4*b == 0.5 +} + +func f3() bool { + var a float32 = 10.0 + var c float32 = 5.0 + /// change `c == (a + 5)` to `math.Abs(c - (a + 5)) < eps` + return c == (a + 5) +} + +func f4() bool { + var a float32 = 10.0 + /// change `foo()+foo() == a+a` to `math.Abs(foo() + foo() - (a + a)) < eps` + return foo()+foo() == a+a +} diff --git a/testdata/indexOnlyLoop/negative_tests.go b/testdata/indexOnlyLoop/negative_tests.go new file mode 100644 index 0000000..c85200b --- /dev/null +++ b/testdata/indexOnlyLoop/negative_tests.go @@ -0,0 +1,84 @@ +package checker_test + +func alreadyWithValue(files []*string) { + f := func(file *string) {} + + for _, file := range files { + f(file) + f(file) + } + + for _, file := range files { + f(files[0]) + f(files[1]) + f(file) + } +} + +func blankIdent(files []*string) { + for _ = range files { + _ = files[0] + _ = files[1] + } +} + +func printFilesIndexes(files []*string) { + f := func(s int) {} + + for i := range files { + f(i) + } +} + +func closeNonPtrFiles(files []string) { + f := func(s string) {} + + for i := range files { + println(files[i]) + f(files[i]) + } +} + +func indexReuse(filesA []*string, filesB []*string) { + f := func() {} + + for i := range filesA { + if filesA[i] == filesB[i] { + f() + } + } +} + +func channelUse() { + f := func(s string) {} + + c := make(chan string) + for val := range c { + f(val) + } +} + +func sliceUse() { + f := func(s string) {} + + v := []string{} + for _, val := range v[:] { + f(val) + } +} + +func rangeOverString() { + f := func(s rune) {} + + for _, ch := range "abcdef" { + f(ch) + } +} + +func shadowed() { + var xs []*string + for k := range xs { + var xs []int + println(xs[k] + xs[k]) + } +} diff --git a/testdata/indexOnlyLoop/positive_tests.go b/testdata/indexOnlyLoop/positive_tests.go new file mode 100644 index 0000000..06c967e --- /dev/null +++ b/testdata/indexOnlyLoop/positive_tests.go @@ -0,0 +1,36 @@ +package checker_test + +func closeFile(f *string) {} + +func closeFiles(files []*string) { + /// i occurs more than once in the loop; consider using for _, value := range files + for i := range files { + if files[i] != nil { + closeFile(files[i]) + } + } +} + +func sliceLoop(files []*string) { + /// k occurs more than once in the loop; consider using for _, value := range files + for k := range files[:] { + if files[k] != nil { + closeFile(files[k]) + } + } +} + +func nestedLoop(files []*string) { + /// k occurs more than once in the loop; consider using for _, value := range files + for k := range files[:] { + if files[k] != nil { + closeFile(files[k]) + } + + var xs []*int + /// j occurs more than once in the loop; consider using for _, value := range xs + for j := range xs { + println(*xs[j] + *xs[j]) + } + } +} diff --git a/testdata/sqlRowsClose/negative_tests.go b/testdata/sqlRowsClose/negative_tests.go new file mode 100644 index 0000000..217be1b --- /dev/null +++ b/testdata/sqlRowsClose/negative_tests.go @@ -0,0 +1,168 @@ +package checker_test + +import "database/sql" + +func normalLocalUseWithDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + + return +} + +func normalLocalUseWithAnonDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return + } + defer func() { + rows.Close() + }() + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + + return +} + +func normalLocalUseWithoutDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return + } + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + rows.Close() + + return +} + +func normalUseWhereRowsFromOtherMethodWithDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := testMethodReturningRows(db) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + + return +} + +func normalUseWhereRowsFromOtherMethodWithoutDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := testMethodReturningRows(db) + if err != nil { + return + } + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + rows.Close() + + return +} + +func normalUseWhereRowsFromOtherMethodToOtherMethodWithDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := testMethodReturningRows(db) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + testMethodWithRows(rows) + } + + return +} + +func normalUseWhereRowsFromOtherMethodToOtherMethodWithoutDefer() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := testMethodReturningRows(db) + if err != nil { + return + } + + for rows.Next() { + testMethodWithRows(rows) + } + rows.Close() + + return +} + +// Internal methods + +func testMethodReturningRows(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return nil, err + } + + return rows, nil +} + +func testMethodWithRows(rows *sql.Rows) { + var testdata string + rows.Scan(&testdata) +} + +func testPosMethodReturningRows(db *sql.DB) (*sql.Rows, error) { + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return nil, err + } + + return rows, nil +} diff --git a/testdata/sqlRowsClose/positive_tests.go b/testdata/sqlRowsClose/positive_tests.go new file mode 100644 index 0000000..f5196ed --- /dev/null +++ b/testdata/sqlRowsClose/positive_tests.go @@ -0,0 +1,49 @@ +package checker_test + +import "database/sql" + +/// local variable db.Rows have not Close call +func normalPosLocalUse() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return + } + + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } + + return +} + +/// local variable db.Rows have not Close call +func normalPosLocalUseWithCall() { + db, err := sql.Open("postgres", "") + if err != nil { + return + } + + rows, err := db.Query("SELECT * FROM testtable") + if err != nil { + return + } + + for rows.Next() { + testPosMethodCallRows(rows) + } + + return +} + +func testPosMethodCallRows(rows *sql.Rows) { + for rows.Next() { + var testdata string + rows.Scan(&testdata) + } +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..3f255b7 --- /dev/null +++ b/utils.go @@ -0,0 +1,53 @@ +package checkers + +import ( + "go/ast" +) + +// qualifiedName returns called expr fully-quallified name. +// +// It works for simple identifiers like f => "f" and identifiers +// from other package like pkg.f => "pkg.f". +// +// For all unexpected expressions returns empty string. +func qualifiedName(x ast.Expr) string { + switch x := x.(type) { + case *ast.SelectorExpr: + pkg, ok := x.X.(*ast.Ident) + if !ok { + return "" + } + return pkg.Name + "." + x.Sel.Name + case *ast.Ident: + return x.Name + default: + return "" + } +} + +// identOf returns identifier for x that can be used to obtain associated types.Object. +// Returns nil for expressions that yield temporary results, like `f().field`. +func identOf(x ast.Node) *ast.Ident { + switch x := x.(type) { + case *ast.Ident: + return x + case *ast.SelectorExpr: + return identOf(x.Sel) + case *ast.TypeAssertExpr: + // x.(type) - x may contain ident. + return identOf(x.X) + case *ast.IndexExpr: + // x[i] - x may contain ident. + return identOf(x.X) + case *ast.StarExpr: + // *x - x may contain ident. + return identOf(x.X) + case *ast.SliceExpr: + // x[:] - x may contain ident. + return identOf(x.X) + + default: + // Note that this function is not comprehensive. + return nil + } +}