From 1cc9fb952d44b939dbf717156ff5b5d860ceaf9c Mon Sep 17 00:00:00 2001 From: Rajshree Sanjayam Date: Wed, 16 Aug 2023 16:34:39 +0530 Subject: [PATCH] fix for parallel calls with prepared stmts --- sqlmock.go | 16 ++++++ sqlmock_test.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/sqlmock.go b/sqlmock.go index d074266..d1ed601 100644 --- a/sqlmock.go +++ b/sqlmock.go @@ -296,6 +296,14 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { if next.fulfilled() { next.Unlock() fulfilled++ + + if pr, ok := next.(*ExpectedPrepare); ok { + if err := c.queryMatcher.Match(pr.expectSQL, query); err == nil { + expected = pr + next.Lock() + break + } + } continue } @@ -334,6 +342,14 @@ func (c *sqlmock) prepare(query string) (*ExpectedPrepare, error) { } func (c *sqlmock) ExpectPrepare(expectedSQL string) *ExpectedPrepare { + for _, e := range c.expected { + if ep, ok := e.(*ExpectedPrepare); ok { + if ep.expectSQL == expectedSQL { + return ep + } + } + } + e := &ExpectedPrepare{expectSQL: expectedSQL, mock: c} c.expected = append(c.expected, e) return e diff --git a/sqlmock_test.go b/sqlmock_test.go index 982a32a..a0d6a8a 100644 --- a/sqlmock_test.go +++ b/sqlmock_test.go @@ -394,6 +394,50 @@ func TestUnorderedPreparedQueryExecutions(t *testing.T) { } } +func TestParallelPreparedQueryExecutions(t *testing.T) { + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + mock.MatchExpectationsInOrder(false) + + mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)"). + ExpectExec(). + WithArgs(1, "Jane Doe"). + WillReturnResult(NewResult(1, 1)) + + mock.ExpectPrepare("INSERT INTO authors \\((.+)\\) VALUES \\((.+)\\)"). + ExpectExec(). + WithArgs(0, "John Doe"). + WillReturnResult(NewResult(0, 1)) + + t.Run("Parallel1", func(t *testing.T) { + t.Parallel() + + stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } else { + _, err = stmt.Exec(0, "John Doe") + } + }) + + t.Run("Parallel2", func(t *testing.T) { + t.Parallel() + + stmt, err := db.Prepare("INSERT INTO authors (id, name) VALUES (?, ?)") + if err != nil { + t.Errorf("error '%s' was not expected while creating a prepared statement", err) + } else { + _, err = stmt.Exec(1, "Jane Doe") + } + }) + + t.Cleanup(func() { + db.Close() + }) +} + func TestUnexpectedOperations(t *testing.T) { t.Parallel() db, mock, err := New() @@ -632,6 +676,89 @@ func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) { // note this line is important for unordered expectation matching mock.MatchExpectationsInOrder(false) + data := []interface{}{ + 1, + "John Doe", + 2, + "Jane Doe", + } + rows := NewRows([]string{"id", "name"}) + rows.AddRow(data[0], data[1]) + rows.AddRow(data[2], data[3]) + + mock.ExpectExec("DROP TABLE IF EXISTS author").WillReturnResult(NewResult(0, 0)) + mock.ExpectExec("TRUNCATE TABLE").WillReturnResult(NewResult(0, 0)) + + mock.ExpectExec("CREATE TABLE IF NOT EXISTS author").WillReturnResult(NewResult(0, 0)) + + mock.ExpectQuery("SELECT").WillReturnRows(rows).WithArgs() + + mock.ExpectPrepare("INSERT INTO"). + ExpectExec(). + WithArgs( + data[0], + data[1], + data[2], + data[3], + ). + WillReturnResult(NewResult(0, 2)) + + var wg sync.WaitGroup + queries := []func() error{ + func() error { + _, err := db.Exec("CREATE TABLE IF NOT EXISTS author (a varchar(255)") + return err + }, + func() error { + _, err := db.Exec("TRUNCATE TABLE author") + return err + }, + func() error { + stmt, err := db.Prepare("INSERT INTO author (id,name) VALUES (?,?),(?,?)") + if err != nil { + return err + } + _, err = stmt.Exec(1, "John Doe", 2, "Jane Doe") + return err + }, + func() error { + _, err := db.Query("SELECT * FROM author") + return err + }, + func() error { + _, err := db.Exec("DROP TABLE IF EXISTS author") + return err + }, + } + + wg.Add(len(queries)) + for _, f := range queries { + go func(f func() error) { + if err := f(); err != nil { + t.Errorf("error was not expected: %s", err) + } + wg.Done() + }(f) + } + + wg.Wait() + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +func TestGoroutineExecutionMultiTypes(t *testing.T) { + t.Parallel() + db, mock, err := New() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + + // note this line is important for unordered expectation matching + mock.MatchExpectationsInOrder(false) + result := NewResult(1, 1) mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result)