Skip to content

Commit

Permalink
Expand Sqlizer arguments in Expr
Browse files Browse the repository at this point in the history
  • Loading branch information
cbandy authored and lann committed Dec 6, 2019
1 parent a0c3234 commit a9f8687
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 13 deletions.
52 changes: 51 additions & 1 deletion expr.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package squirrel

import (
"bytes"
"database/sql/driver"
"fmt"
"reflect"
Expand Down Expand Up @@ -28,7 +29,56 @@ func Expr(sql string, args ...interface{}) expr {
}

func (e expr) ToSql() (sql string, args []interface{}, err error) {
return e.sql, e.args, nil
simple := true
for _, arg := range e.args {
if _, ok := arg.(Sqlizer); ok {
simple = false
}
}
if simple {
return e.sql, e.args, nil
}

buf := &bytes.Buffer{}
ap := e.args
sp := e.sql

var isql string
var iargs []interface{}

for err == nil && len(ap) > 0 && len(sp) > 0 {
i := strings.Index(sp, "?")
if i < 0 {
// no more placeholders
break
}
if len(sp) > i+1 && sp[i+1:i+2] == "?" {
// escaped "??"; append it and step past
buf.WriteString(sp[:i+2])
sp = sp[i+2:]
continue
}

if as, ok := ap[0].(Sqlizer); ok {
// sqlizer argument; expand it and append the result
isql, iargs, err = as.ToSql()
buf.WriteString(sp[:i])
buf.WriteString(isql)
args = append(args, iargs...)
} else {
// normal argument; append it and the placeholder
buf.WriteString(sp[:i+1])
args = append(args, ap[0])
}

// step past the argument and placeholder
ap = ap[1:]
sp = sp[i+1:]
}

// append the remaining sql and arguments
buf.WriteString(sp)
return buf.String(), append(args, ap...), err
}

type concatExpr []interface{}
Expand Down
48 changes: 48 additions & 0 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,51 @@ func TestSqlLtOrder(t *testing.T) {
expectedArgs := []interface{}{1, 2, 3}
assert.Equal(t, expectedArgs, args)
}

func TestExprEscaped(t *testing.T) {
b := Expr("count(??)", Expr("x"))
sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "count(??)"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{Expr("x")}
assert.Equal(t, expectedArgs, args)
}

func TestExprRecursion(t *testing.T) {
{
b := Expr("count(?)", Expr("nullif(a,?)", "b"))
sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "count(nullif(a,?))"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{"b"}
assert.Equal(t, expectedArgs, args)
}
{
b := Expr("extract(? from ?)", Expr("epoch"), "2001-02-03")
sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "extract(epoch from ?)"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{"2001-02-03"}
assert.Equal(t, expectedArgs, args)
}
{
b := Expr("JOIN t1 ON ?", And{Eq{"id": 1}, Expr("NOT c1"), Expr("? @@ ?", "x", "y")})
sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "JOIN t1 ON (id = ? AND NOT c1 AND ? @@ ?)"
assert.Equal(t, expectedSql, sql)

expectedArgs := []interface{}{1, "x", "y"}
assert.Equal(t, expectedArgs, args)
}
}
11 changes: 7 additions & 4 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,13 @@ func (d *insertData) appendValuesToSQL(w io.Writer, args []interface{}) ([]inter
for r, row := range d.Values {
valueStrings := make([]string, len(row))
for v, val := range row {
e, isExpr := val.(expr)
if isExpr {
valueStrings[v] = e.sql
args = append(args, e.args...)
if vs, ok := val.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return nil, err
}
valueStrings[v] = vsql
args = append(args, vargs...)
} else {
valueStrings[v] = "?"
args = append(args, val)
Expand Down
12 changes: 4 additions & 8 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,13 @@ func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
setSqls := make([]string, len(d.SetClauses))
for i, setClause := range d.SetClauses {
var valSql string
e, isExpr := setClause.value.(expr)
if isExpr {
valSql = e.sql
args = append(args, e.args...)
} else if c, isCase := setClause.value.(CaseBuilder); isCase {
caseSql, caseArgs, err := c.ToSql()
if vs, ok := setClause.value.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return "", nil, err
}
valSql = caseSql
args = append(args, caseArgs...)
valSql = vsql
args = append(args, vargs...)
} else {
valSql = "?"
args = append(args, setClause.value)
Expand Down

0 comments on commit a9f8687

Please sign in to comment.