Commit edc234ce by Bogdan Ungureanu

Extra libs

parent 7e1d16cf
...@@ -4,9 +4,17 @@ GO libraries ...@@ -4,9 +4,17 @@ GO libraries
Libraries Libraries
--------- ---------
fs/inotify - inotify library database/driver/mysql - mysql driver
net/cluster - simple network cluster database/drier/sqlite3s - sqlite3 static library (github.com/changkong/go-sqlite3s)
net/http/pat - http pat router library database/orm/beedb - orm for mysql/pg/sqlite/oracle (github.com/astaxie/beedb)
os/daemon - daemonize library database/orm/gorp - orm for mysql/sqlite (github.com/go-gorp/gorp)
util/ini - ini file library
fs/inotify - inotify library
net/cluster - simple network cluster
net/http/pat - http pat router library
os/daemon - daemonize library
util/ini - ini file library
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
Icon?
ehthumbs.db
Thumbs.db
language: go
go:
- 1.1
- tip
before_script:
- mysql -e 'create database gotest;'
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build appengine
package mysql
import (
"appengine/cloudsql"
"net"
)
func init() {
if dials == nil {
dials = make(map[string]dialFunc)
}
dials["cloudsql"] = func(cfg *config) (net.Conn, error) {
return cloudsql.Dial(cfg.addr)
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"database/sql"
"strings"
"sync"
"sync/atomic"
"testing"
)
type TB testing.B
func (tb *TB) check(err error) {
if err != nil {
tb.Fatal(err)
}
}
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
tb.check(err)
return db
}
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
tb.check(err)
return rows
}
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
tb.check(err)
return stmt
}
func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("Error on %q: %v", query, err)
}
}
return db
}
const concurrencyLevel = 10
func BenchmarkQuery(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
var got string
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Errorf("query = %q; want one", got)
wg.Done()
return
}
}
}()
}
}
func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("DO 1"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
if _, err := stmt.Exec(); err != nil {
b.Fatal(err.Error())
}
}
}()
}
}
// data, but no db writes
var roundtripSample []byte
func initRoundtripBenchmarks() ([]byte, int, int) {
if roundtripSample == nil {
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
}
return roundtripSample, 16, len(roundtripSample)
}
func BenchmarkRoundtripTxt(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
b.StartTimer()
var result string
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sampleString[0:length]
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if result != test {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkRoundtripBin(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
b.StartTimer()
var result sql.RawBytes
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if !bytes.Equal(result, test) {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import "io"
const defaultBufSize = 4096
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
rd io.Reader
idx int
length int
}
func newBuffer(rd io.Reader) *buffer {
var b [defaultBufSize]byte
return &buffer{
buf: b[:],
rd: rd,
}
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
// move existing data to the beginning
if b.length > 0 && b.idx > 0 {
copy(b.buf[0:b.length], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
b.idx = 0
for {
n, err := b.rd.Read(b.buf[b.length:])
b.length += n
if err == nil {
if b.length < need {
continue
}
return nil
}
if b.length >= need && err == io.EOF {
return nil
}
return err
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
}
return make([]byte, length)
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length == 0 {
return b.buf
}
return nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql/driver"
"errors"
"net"
"strings"
"time"
)
type mysqlConn struct {
buf *buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *config
maxPacketAllowed int
maxWriteSize int
flags clientFlag
sequence uint8
parseTime bool
strict bool
}
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
// Handles parameters set in DSN
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// time.Time parsing
case "parseTime":
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Strict mode
case "strict":
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Compression
case "compress":
err = errors.New("Compression not implemented yet")
return
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err
}
return nil, err
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
err = mc.writeCommandPacket(comQuit)
if err == nil {
err = mc.netConn.Close()
} else {
mc.netConn.Close()
}
mc.netConn = nil
}
mc.cfg = nil
mc.buf = nil
return
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, err
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
return err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
if err = mc.readUntilEOF(); err != nil {
return err
}
err = mc.readUntilEOF()
}
return err
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
rows.columns, err = mc.readColumns(resLen)
}
return rows, err
}
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const (
minProtocolVersion byte = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
const (
iOK byte = 0x00
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
type clientFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
)
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegiserSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
const (
fieldTypeDecimal byte = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
const (
fieldTypeNewDecimal byte = iota + 0xf6
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
type fieldFlag uint16
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
const (
collation_ascii_general_ci byte = 11
collation_utf8_general_ci byte = 33
collation_utf8mb4_general_ci byte = 45
collation_utf8mb4_bin byte = 46
collation_latin1_general_ci byte = 48
collation_binary byte = 63
collation_utf8mb4_unicode_ci byte = 224
)
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql"
"database/sql/driver"
"net"
)
// This struct is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
type dialFunc func(*config) (net.Conn, error)
var dials map[string]dialFunc
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = parseDSN(dsn)
if err != nil {
return nil, err
}
// Connect to Server
if dial, ok := dials[mc.cfg.net]; ok {
mc.netConn, err = dial(mc.cfg)
} else {
nd := net.Dialer{Timeout: mc.cfg.timeout}
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
}
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
mc.Close()
return nil, err
}
}
mc.buf = newBuffer(mc.netConn)
// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
mc.Close()
return nil, err
}
// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
// Read Result Packet
err = mc.readResultOK()
if err != nil {
// Retry with old authentication method, if allowed
if mc.cfg != nil && mc.cfg.allowOldPasswords && err == ErrOldPassword {
if err = mc.writeOldAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err
}
}
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxPacketAllowed = stringToInt(maxap) - 1
if mc.maxPacketAllowed < maxPacketSize {
mc.maxWriteSize = mc.maxPacketAllowed
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
var (
dsn string
netAddr string
available bool
)
var (
tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
sDate = "2012-06-14"
tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
sDateTime = "2011-11-20 21:27:37"
tDate0 = time.Time{}
sDate0 = "0000-00-00"
sDateTime0 = "0000-00-00 00:00:00"
)
// See https://github.com/go-sql-driver/mysql/wiki/Testing
func init() {
env := func(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
user := env("MYSQL_TEST_USER", "root")
pass := env("MYSQL_TEST_PASS", "")
prot := env("MYSQL_TEST_PROT", "tcp")
addr := env("MYSQL_TEST_ADDR", "localhost:3306")
dbname := env("MYSQL_TEST_DBNAME", "gotest")
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
c, err := net.Dial(prot, addr)
if err == nil {
available = true
c.Close()
}
}
type DBTest struct {
*testing.T
db *sql.DB
}
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
defer db.Close()
db.Exec("DROP TABLE IF EXISTS test")
dbt := &DBTest{t, db}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
}
}
func (dbt *DBTest) fail(method, query string, err error) {
if len(query) > 300 {
query = "[query too large to print]"
}
dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
}
func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
res, err := dbt.db.Exec(query, args...)
if err != nil {
dbt.fail("Exec", query, err)
}
return res
}
func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
rows, err := dbt.db.Query(query, args...)
if err != nil {
dbt.fail("Query", query, err)
}
return rows
}
func TestCRUD(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// Create Table
dbt.mustExec("CREATE TABLE test (value BOOL)")
// Test for unexpected data
var out bool
rows := dbt.mustQuery("SELECT * FROM test")
if rows.Next() {
dbt.Error("unexpected data in empty table")
}
// Create Data
res := dbt.mustExec("INSERT INTO test VALUES (1)")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
id, err := res.LastInsertId()
if err != nil {
dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
}
if id != 0 {
dbt.Fatalf("Expected InsertID 0, got %d", id)
}
// Read
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if true != out {
dbt.Errorf("true != %t", out)
}
if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
// Update
res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check Update
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if false != out {
dbt.Errorf("false != %t", out)
}
if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
// Delete
res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check for unexpected rows
res = dbt.mustExec("DELETE FROM test")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 0 {
dbt.Fatalf("Expected 0 affected row, got %d", count)
}
})
}
func TestInt(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
in := int64(42)
var out int64
var rows *sql.Rows
// SIGNED
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %d != %d", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
// UNSIGNED ZEROFILL
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out)
}
} else {
dbt.Errorf("%s ZEROFILL: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestFloat(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
in := float32(42.23)
var out float32
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %g != %g", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestString(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย"
var out string
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %s != %s", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
// BLOB
dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
id := 2
in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " +
"Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet."
dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)
err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
if err != nil {
dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
} else if out != in {
dbt.Errorf("BLOB: %s != %s", in, out)
}
})
}
func TestDateTime(t *testing.T) {
type testmode struct {
selectSuffix string
args []interface{}
}
type timetest struct {
in interface{}
sOut string
tOut time.Time
tIsZero bool
}
type tester func(dbt *DBTest, rows *sql.Rows,
test *timetest, sqltype, resulttype, mode string)
type setup struct {
vartype string
dsnSuffix string
test tester
}
var (
modes = map[string]*testmode{
"text": &testmode{},
"binary": &testmode{" WHERE 1 = ?", []interface{}{1}},
}
timetests = map[string][]*timetest{
"DATE": {
{sDate, sDate, tDate, false},
{sDate0, sDate0, tDate0, true},
{tDate, sDate, tDate, false},
{tDate0, sDate0, tDate0, true},
},
"DATETIME": {
{sDateTime, sDateTime, tDateTime, false},
{sDateTime0, sDateTime0, tDate0, true},
{tDateTime, sDateTime, tDateTime, false},
{tDate0, sDateTime0, tDate0, true},
},
}
setups = []*setup{
{"string", "&parseTime=false", func(
dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
var sOut string
if err := rows.Scan(&sOut); err != nil {
dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
} else if test.sOut != sOut {
dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut)
}
}},
{"time.Time", "&parseTime=true", func(
dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
var tOut time.Time
if err := rows.Scan(&tOut); err != nil {
dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
} else if test.tOut != tOut || test.tIsZero != tOut.IsZero() {
dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero())
}
}},
}
)
var s *setup
testTime := func(dbt *DBTest) {
var rows *sql.Rows
for sqltype, tests := range timetests {
dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
for _, test := range tests {
for mode, q := range modes {
dbt.mustExec("TRUNCATE test")
dbt.mustExec("INSERT INTO test VALUES (?)", test.in)
rows = dbt.mustQuery("SELECT value FROM test"+q.selectSuffix, q.args...)
if rows.Next() {
s.test(dbt, rows, test, sqltype, s.vartype, mode)
} else {
if err := rows.Err(); err != nil {
dbt.Errorf("%s (%s %s): %s",
sqltype, s.vartype, mode, err.Error())
} else {
dbt.Errorf("%s (%s %s): no data",
sqltype, s.vartype, mode)
}
}
}
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
}
timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
for _, v := range setups {
s = v
runTests(t, timeDsn+s.dsnSuffix, testTime)
}
}
func TestNULL(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
nullStmt, err := dbt.db.Prepare("SELECT NULL")
if err != nil {
dbt.Fatal(err)
}
defer nullStmt.Close()
nonNullStmt, err := dbt.db.Prepare("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
defer nonNullStmt.Close()
// NullBool
var nb sql.NullBool
// Invalid
if err = nullStmt.QueryRow().Scan(&nb); err != nil {
dbt.Fatal(err)
}
if nb.Valid {
dbt.Error("Valid NullBool which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
dbt.Fatal(err)
}
if !nb.Valid {
dbt.Error("Invalid NullBool which should be valid")
} else if nb.Bool != true {
dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
}
// NullFloat64
var nf sql.NullFloat64
// Invalid
if err = nullStmt.QueryRow().Scan(&nf); err != nil {
dbt.Fatal(err)
}
if nf.Valid {
dbt.Error("Valid NullFloat64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
dbt.Fatal(err)
}
if !nf.Valid {
dbt.Error("Invalid NullFloat64 which should be valid")
} else if nf.Float64 != float64(1) {
dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
}
// NullInt64
var ni sql.NullInt64
// Invalid
if err = nullStmt.QueryRow().Scan(&ni); err != nil {
dbt.Fatal(err)
}
if ni.Valid {
dbt.Error("Valid NullInt64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
dbt.Fatal(err)
}
if !ni.Valid {
dbt.Error("Invalid NullInt64 which should be valid")
} else if ni.Int64 != int64(1) {
dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64)
}
// NullString
var ns sql.NullString
// Invalid
if err = nullStmt.QueryRow().Scan(&ns); err != nil {
dbt.Fatal(err)
}
if ns.Valid {
dbt.Error("Valid NullString which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
dbt.Fatal(err)
}
if !ns.Valid {
dbt.Error("Invalid NullString which should be valid")
} else if ns.String != `1` {
dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
}
// nil-bytes
var b []byte
// Read nil
if err = nullStmt.QueryRow().Scan(&b); err != nil {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("Non-nil []byte wich should be nil")
}
// Read non-nil
if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
dbt.Fatal(err)
}
if b == nil {
dbt.Error("Nil []byte wich should be non-nil")
}
// Insert nil
b = nil
success := false
if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
dbt.Fatal(err)
}
if !success {
dbt.Error("Inserting []byte(nil) as NULL failed")
}
// Check input==output with input==nil
b = nil
if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("Non-nil echo from nil input")
}
// Check input==output with input!=nil
b = []byte("")
if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
dbt.Fatal(err)
}
if b == nil {
dbt.Error("nil echo from non-nil input")
}
// Insert NULL
dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
var out interface{}
rows := dbt.mustQuery("SELECT * FROM test")
if rows.Next() {
rows.Scan(&out)
if out != nil {
dbt.Errorf("%v != nil", out)
}
} else {
dbt.Error("no data")
}
})
}
func TestLongData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var maxAllowedPacketSize int
err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
if err != nil {
dbt.Fatal(err)
}
maxAllowedPacketSize--
// don't get too ambitious
if maxAllowedPacketSize > 1<<25 {
maxAllowedPacketSize = 1 << 25
}
dbt.mustExec("CREATE TABLE test (value LONGBLOB)")
in := strings.Repeat(`a`, maxAllowedPacketSize+1)
var out string
var rows *sql.Rows
// Long text data
const nonDataQueryLen = 28 // length query w/o value
inS := in[:maxAllowedPacketSize-nonDataQueryLen]
dbt.mustExec("INSERT INTO test VALUES('" + inS + "')")
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if inS != out {
dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out))
}
if rows.Next() {
dbt.Error("LONGBLOB: unexpexted row")
}
} else {
dbt.Fatalf("LONGBLOB: no data")
}
// Empty table
dbt.mustExec("TRUNCATE TABLE test")
// Long binary data
dbt.mustExec("INSERT INTO test VALUES(?)", in)
rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1)
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out))
}
if rows.Next() {
dbt.Error("LONGBLOB: unexpexted row")
}
} else {
if err = rows.Err(); err != nil {
dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error())
} else {
dbt.Fatal("LONGBLOB: no data (err: <nil>)")
}
}
})
}
func TestLoadData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
verifyLoadDataResult := func() {
rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
dbt.Fatal(err.Error())
}
i := 0
values := [4]string{
"a string",
"a string containing a \t",
"a string containing a \n",
"a string containing both \t\n",
}
var id int
var value string
for rows.Next() {
i++
err = rows.Scan(&id, &value)
if err != nil {
dbt.Fatal(err.Error())
}
if i != id {
dbt.Fatalf("%d != %d", i, id)
}
if values[i-1] != value {
dbt.Fatalf("%s != %s", values[i-1], value)
}
}
err = rows.Err()
if err != nil {
dbt.Fatal(err.Error())
}
if i != 4 {
dbt.Fatalf("Rows count mismatch. Got %d, want 4", i)
}
}
file, err := ioutil.TempFile("", "gotest")
defer os.Remove(file.Name())
if err != nil {
dbt.Fatal(err)
}
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
file.Close()
dbt.db.Exec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
// Local File
RegisterLocalFile(file.Name())
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
verifyLoadDataResult()
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("Load non-existent file didn't fail")
} else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" {
dbt.Fatal(err.Error())
}
// Empty table
dbt.mustExec("TRUNCATE TABLE test")
// Reader
RegisterReaderHandler("test", func() io.Reader {
file, err = os.Open(file.Name())
if err != nil {
dbt.Fatal(err)
}
return file
})
dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
verifyLoadDataResult()
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("Load non-existent Reader didn't fail")
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
dbt.Fatal(err.Error())
}
})
}
func TestFoundRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 affected rows, got %d", count)
}
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 affected rows, got %d", count)
}
})
runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 matched rows, got %d", count)
}
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 3 {
dbt.Fatalf("Expected 3 matched rows, got %d", count)
}
})
}
func TestStrict(t *testing.T) {
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
runTests(t, relaxedDsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
var queries = [...]struct {
in string
codes []string
}{
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
}
var err error
var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
}
if warnings, ok := err.(MySQLWarnings); ok {
var codes = make([]string, len(warnings))
for i := range warnings {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}
for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}
} else {
dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}
// text protocol
for i := range queries {
_, err = dbt.db.Exec(queries[i].in)
checkWarnings(err, "text", i)
}
var stmt *sql.Stmt
// binary protocol
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error())
}
_, err = stmt.Exec()
checkWarnings(err, "binary", i)
err = stmt.Close()
if err != nil {
dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error())
}
}
})
}
func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
if err == ErrNoTLS {
dbt.Skip("Server does not support TLS")
} else {
dbt.Fatalf("Error on Ping: %s", err.Error())
}
}
rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
var variable, value *sql.RawBytes
for rows.Next() {
if err := rows.Scan(&variable, &value); err != nil {
dbt.Fatal(err.Error())
}
if value == nil {
dbt.Fatal("No Cipher")
}
}
}
runTests(t, dsn+"&tls=skip-verify", tlsTest)
// Verify that registering / using a custom cfg works
RegisterTLSConfig("custom-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
}
func TestReuseClosedConnection(t *testing.T) {
// this test does not use sql.database, it uses the driver directly
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
md := &MySQLDriver{}
conn, err := md.Open(dsn)
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
stmt, err := conn.Prepare("DO 1")
if err != nil {
t.Fatalf("Error preparing statement: %s", err.Error())
}
_, err = stmt.Exec(nil)
if err != nil {
t.Fatalf("Error executing statement: %s", err.Error())
}
err = conn.Close()
if err != nil {
t.Fatalf("Error closing connection: %s", err.Error())
}
defer func() {
if err := recover(); err != nil {
t.Errorf("Panic after reusing a closed connection: %v", err)
}
}()
_, err = stmt.Exec(nil)
if err != nil && err != driver.ErrBadConn {
t.Errorf("Unexpected error '%s', expected '%s'",
err.Error(), driver.ErrBadConn.Error())
}
}
func TestCharset(t *testing.T) {
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
mustSetCharset := func(charsetParam, expected string) {
runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
rows := dbt.mustQuery("SELECT @@character_set_connection")
defer rows.Close()
if !rows.Next() {
dbt.Fatalf("Error getting connection charset: %s", rows.Err())
}
var got string
rows.Scan(&got)
if got != expected {
dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
}
})
}
// non utf8 test
mustSetCharset("charset=ascii", "ascii")
// when the first charset is invalid, use the second
mustSetCharset("charset=none,utf8", "utf8")
// when the first charset is valid, use it
mustSetCharset("charset=ascii,utf8", "ascii")
mustSetCharset("charset=utf8,ascii", "utf8")
}
func TestFailingCharset(t *testing.T) {
runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
// run query to really establish connection...
_, err := dbt.db.Exec("SELECT 1")
if err == nil {
dbt.db.Close()
t.Fatalf("Connection must not succeed without a valid charset")
}
})
}
func TestRawBytesResultExceedsBuffer(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// defaultBufSize from buffer.go
expected := strings.Repeat("abc", defaultBufSize)
rows := dbt.mustQuery("SELECT '" + expected + "'")
defer rows.Close()
if !rows.Next() {
dbt.Error("expected result, got none")
}
var result sql.RawBytes
rows.Scan(&result)
if expected != string(result) {
dbt.Error("result did not match expected value")
}
})
}
func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
// Regression test for timezone handling
tzTest := func(dbt *DBTest) {
// Create table
dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
// Insert local time into database (should be converted)
usCentral, _ := time.LoadLocation("US/Central")
now := time.Now().In(usCentral)
dbt.mustExec("INSERT INTO test VALUE (?)", now)
// Retrieve time from DB
rows := dbt.mustQuery("SELECT ts FROM test")
if !rows.Next() {
dbt.Fatal("Didn't get any rows out")
}
var nowDB time.Time
err := rows.Scan(&nowDB)
if err != nil {
dbt.Fatal("Err", err)
}
// Check that dates match
if now.Unix() != nowDB.Unix() {
dbt.Errorf("Times don't match.\n")
dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
dbt.Errorf(" Now(UTC)=%v\n", nowDB)
}
}
for _, tz := range zones {
runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
}
}
// This tests for https://github.com/go-sql-driver/mysql/pull/139
//
// An extra (invisible) nil byte was being added to the beginning of positive
// time strings.
func TestTimeSign(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var sTimes = []struct {
value string
fieldType string
}{
{"12:34:56", "TIME"},
{"-12:34:56", "TIME"},
// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
// they *should* work, but only in 5.6+.
// { "12:34:56.789", "TIME(3)" },
// { "-12:34:56.789", "TIME(3)" },
}
for _, sTime := range sTimes {
dbt.db.Exec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
if rows.Next() {
var oTime string
rows.Scan(&oTime)
if oTime != sTime.value {
dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
}
} else {
dbt.Error("expecting at least one row.")
}
}
})
}
// Special cases
func TestRowsClose(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
rows, err := dbt.db.Query("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
err = rows.Close()
if err != nil {
dbt.Fatal(err)
}
if rows.Next() {
dbt.Fatal("Unexpected row after rows.Close()")
}
err = rows.Err()
if err != nil {
dbt.Fatal(err)
}
})
}
// dangling statements
// http://code.google.com/p/go/issues/detail?id=3865
func TestCloseStmtBeforeRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
rows, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows.Close()
err = stmt.Close()
if err != nil {
dbt.Fatal(err)
}
if !rows.Next() {
dbt.Fatal("Getting row failed")
} else {
err = rows.Err()
if err != nil {
dbt.Fatal(err)
}
var out bool
err = rows.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
})
}
// It is valid to have multiple Rows for the same Stmt
// http://code.google.com/p/go/issues/detail?id=3734
func TestStmtMultiRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
if err != nil {
dbt.Fatal(err)
}
rows1, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows1.Close()
rows2, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows2.Close()
var out bool
// 1
if !rows1.Next() {
dbt.Fatal("1st rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
if !rows2.Next() {
dbt.Fatal("1st rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
// 2
if !rows1.Next() {
dbt.Fatal("2nd rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows1.Next() {
dbt.Fatal("Unexpected row on rows1")
}
err = rows1.Close()
if err != nil {
dbt.Fatal(err)
}
}
if !rows2.Next() {
dbt.Fatal("2nd rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows2.Next() {
dbt.Fatal("Unexpected row on rows2")
}
err = rows2.Close()
if err != nil {
dbt.Fatal(err)
}
}
})
}
// Regression test for
// * more than 32 NULL parameters (issue 209)
// * more parameters than fit into the buffer (issue 201)
func TestPreparedManyCols(t *testing.T) {
const numParams = defaultBufSize
runTests(t, dsn, func(dbt *DBTest) {
query := "SELECT ?" + strings.Repeat(",?", numParams-1)
stmt, err := dbt.db.Prepare(query)
if err != nil {
dbt.Fatal(err)
}
defer stmt.Close()
// create more parameters than fit into the buffer
// which will take nil-values
params := make([]interface{}, numParams)
rows, err := stmt.Query(params...)
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows.Close()
})
}
func TestConcurrent(t *testing.T) {
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
t.Skip("MYSQL_TEST_CONCURRENT env var not set")
}
runTests(t, dsn, func(dbt *DBTest) {
var max int
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
if err != nil {
dbt.Fatalf("%s", err.Error())
}
dbt.Logf("Testing up to %d concurrent connections \r\n", max)
var remaining, succeeded int32 = int32(max), 0
var wg sync.WaitGroup
wg.Add(max)
var fatalError string
var once sync.Once
fatal := func(s string, vals ...interface{}) {
once.Do(func() {
fatalError = fmt.Sprintf(s, vals...)
})
}
for i := 0; i < max; i++ {
go func(id int) {
defer wg.Done()
tx, err := dbt.db.Begin()
atomic.AddInt32(&remaining, -1)
if err != nil {
if err.Error() != "Error 1040: Too many connections" {
fatal("Error on Conn %d: %s", id, err.Error())
}
return
}
// keep the connection busy until all connections are open
for remaining > 0 {
if _, err = tx.Exec("DO 1"); err != nil {
fatal("Error on Conn %d: %s", id, err.Error())
return
}
}
if err = tx.Commit(); err != nil {
fatal("Error on Conn %d: %s", id, err.Error())
return
}
// everything went fine with this connection
atomic.AddInt32(&succeeded, 1)
}(i)
}
// wait until all conections are open
wg.Wait()
if fatalError != "" {
dbt.Fatal(fatalError)
}
dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
})
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"errors"
"fmt"
"io"
"log"
"os"
)
var (
ErrInvalidConn = errors.New("Invalid Connection")
ErrMalformPkt = errors.New("Malformed Packet")
ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS")
ErrOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
ErrPktSync = errors.New("Commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
ErrBusyBuffer = errors.New("Busy buffer")
errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
)
// Logger is used to log critical error messages.
type Logger interface {
Print(v ...interface{})
}
// SetLogger is used to set the logger for critical errors.
// The initial logger is stderr.
func SetLogger(logger Logger) error {
if logger == nil {
return errors.New("logger is nil")
}
errLog = logger
return nil
}
// MySQLError is an error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}
// MySQLWarnings is an error type which represents a group of one or more MySQL
// warnings
type MySQLWarnings []MysqlWarning
func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf(
"%s %s: %s",
warning.Level,
warning.Code,
warning.Message,
)
}
return msg
}
// MysqlWarning is an error type which represents a single MySQL warning.
// Warnings are returned in groups only. See MySQLWarnings
type MysqlWarning struct {
Level string
Code string
Message string
}
func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
if err != nil {
return
}
var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)
var warning MysqlWarning
var raw []byte
var ok bool
for {
err = rows.Next(values)
switch err {
case nil:
warning = MysqlWarning{}
if raw, ok = values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok = values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok = values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}
warnings = append(warnings, warning)
case io.EOF:
return warnings
default:
rows.Close()
return
}
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"log"
"testing"
)
func TestSetLogger(t *testing.T) {
previous := errLog
defer func() {
errLog = previous
}()
const expected = "prefix: test\n"
buffer := bytes.NewBuffer(make([]byte, 0, 64))
logger := log.New(buffer, "prefix: ", 0)
SetLogger(logger)
errLog.Print("test")
if actual := buffer.String(); actual != expected {
t.Errorf("expected %q, got %q", expected, actual)
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"io"
"os"
"strings"
)
var (
fileRegister map[string]bool
readerRegister map[string]func() io.Reader
)
// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
//
// filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterLocalFile(filePath string) {
// lazy map init
if fileRegister == nil {
fileRegister = make(map[string]bool)
}
fileRegister[strings.Trim(filePath, `"`)] = true
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
delete(fileRegister, strings.Trim(filePath, `"`))
}
// RegisterReaderHandler registers a handler function which is used
// to receive a io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns a io.ReadCloser Close() is called when the
// request is finished.
//
// mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here
// return csvReader
// })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
// lazy map init
if readerRegister == nil {
readerRegister = make(map[string]func() io.Reader)
}
readerRegister[name] = handler
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
delete(readerRegister, name)
}
func deferredClose(err *error, closer io.Closer) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
if strings.HasPrefix(name, "Reader::") { // io.Reader
name = name[8:]
if handler, inMap := readerRegister[name]; inMap {
rdr = handler()
if rdr != nil {
data = make([]byte, 4+mc.maxWriteSize)
if cl, ok := rdr.(io.Closer); ok {
defer deferredClose(&err, cl)
}
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
}
} else {
err = fmt.Errorf("Reader '%s' is not registered", name)
}
} else { // File
name = strings.Trim(name, `"`)
if mc.cfg.allowAllFiles || fileRegister[name] {
var file *os.File
var fi os.FileInfo
if file, err = os.Open(name); err == nil {
defer deferredClose(&err, file)
// get file size
if fi, err = file.Stat(); err == nil {
rdr = file
if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
data = make([]byte, 4+fileSize)
} else if fileSize <= mc.maxPacketAllowed {
data = make([]byte, 4+mc.maxWriteSize)
} else {
err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
}
}
}
} else {
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
}
}
// send content packets
if err == nil {
var n int
for err == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
return ioErr
}
}
}
if err == io.EOF {
err = nil
}
}
// send empty packet (termination)
if data == nil {
data = make([]byte, 4)
}
if ioErr := mc.writePacket(data[:4]); ioErr != nil {
return ioErr
}
// read OK packet
if err == nil {
return mc.readResultOK()
} else {
mc.readPacket()
}
return err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"math"
"time"
)
// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
var payload []byte
for {
// Read packet header
data, err := mc.buf.readNext(4)
if err != nil {
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
}
// Packet Length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
if pktLen < 1 {
errLog.Print(ErrMalformPkt)
mc.Close()
return nil, driver.ErrBadConn
}
// Check Packet Sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
} else {
return nil, ErrPktSync
}
}
mc.sequence++
// Read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
errLog.Print(err)
mc.Close()
return nil, driver.ErrBadConn
}
isLastPacket := (pktLen < maxPacketSize)
// Zero allocations for non-splitting packets
if isLastPacket && payload == nil {
return data, nil
}
payload = append(payload, data...)
if isLastPacket {
return payload, nil
}
}
}
// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4
if pktLen > mc.maxPacketAllowed {
return ErrPktTooLarge
}
for {
var size int
if pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
size = maxPacketSize
} else {
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
size = pktLen
}
data[3] = mc.sequence
// Write packet
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}
// Handle error
if err == nil { // n != len(data)
errLog.Print(ErrMalformPkt)
} else {
errLog.Print(err)
}
return driver.ErrBadConn
}
}
/******************************************************************************
* Initialisation Process *
******************************************************************************/
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
if data[0] == iERR {
return nil, mc.handleErrorPacket(data)
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
data[0],
minProtocolVersion,
)
}
// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
// first part of the password cipher [8 bytes]
cipher := data[pos : pos+8]
// (filler) always 0x00 [1 byte]
pos += 8 + 1
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, ErrNoTLS
}
pos += 2
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
// capability flags (upper 2 bytes) [2 bytes]
// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
//
// The web documentation is ambiguous about the length. However,
// according to mysql-5.7/sql/auth/sql_authentication.cc line 538,
// the 13th byte is "\0 byte, terminating the second part of
// a scramble". So the second part of the password cipher is
// a NULL terminated string that's at least 13 bytes with the
// last byte being NULL.
//
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
cipher = append(cipher, data[pos:pos+12]...)
// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// \NUL otherwise
//
//if data[len(data)-1] == 0 {
// return
//}
//return ErrMalformPkt
return cipher, nil
}
// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], cipher)
return b[:], nil
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
mc.flags&clientLongFlag
if mc.cfg.clientFoundRows {
clientFlags |= clientFoundRows
}
// To enable TLS / SSL
if mc.cfg.tls != nil {
clientFlags |= clientSSL
}
// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
// To specify a db name
if n := len(mc.cfg.dbname); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// ClientFlags [32 bit]
data[4] = byte(clientFlags)
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
// MaxPacketSize [32 bit] (none)
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
// Charset [1 byte]
data[12] = collation_utf8_general_ci
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil {
// Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
// Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
if err := tlsConn.Handshake(); err != nil {
return err
}
mc.netConn = tlsConn
mc.buf.rd = tlsConn
}
// Filler [23 bytes] (all 0x00)
pos := 13 + 23
// User [null terminated string]
if len(mc.cfg.user) > 0 {
pos += copy(data[pos:], mc.cfg.user)
}
data[pos] = 0x00
pos++
// ScrambleBuffer [length encoded integer]
data[pos] = byte(len(scrambleBuff))
pos += 1 + copy(data[pos+1:], scrambleBuff)
// Databasename [null terminated string]
if len(mc.cfg.dbname) > 0 {
pos += copy(data[pos:], mc.cfg.dbname)
data[pos] = 0x00
}
// Send Auth packet
return mc.writePacket(data)
}
// Client old authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
// Calculate the packet lenght and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// Add the scrambled password [null terminated string]
copy(data[4:], scrambleBuff)
data[4+pktLen-1] = 0x00
return mc.writePacket(data)
}
/******************************************************************************
* Command Packets *
******************************************************************************/
func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// Add command byte
data[4] = command
// Send CMD packet
return mc.writePacket(data)
}
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence
mc.sequence = 0
pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// Add command byte
data[4] = command
// Add arg
copy(data[5:], arg)
// Send CMD packet
return mc.writePacket(data)
}
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// Add command byte
data[4] = command
// Add arg [32 bit]
data[5] = byte(arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)
// Send CMD packet
return mc.writePacket(data)
}
/******************************************************************************
* Result Packets *
******************************************************************************/
// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
data, err := mc.readPacket()
if err == nil {
// packet indicator
switch data[0] {
case iOK:
return mc.handleOkPacket(data)
case iEOF:
// someone is using old_passwords
return ErrOldPassword
default: // Error otherwise
return mc.handleErrorPacket(data)
}
}
return err
}
// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
if err == nil {
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
}
return 0, ErrMalformPkt
}
return 0, err
}
// Error Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
if data[0] != iERR {
return ErrMalformPkt
}
// 0xff [1 byte]
// Error Number [16 bit uint]
errno := binary.LittleEndian.Uint16(data[1:3])
pos := 3
// SQL State [optional: # + 5bytes string]
if data[3] == 0x23 {
//sqlstate := string(data[4 : 4+5])
pos = 9
}
// Error Message [string]
return &MySQLError{
Number: errno,
Message: string(data[pos:]),
}
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int
// 0x00 [1 byte]
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
// server_status [2 bytes]
// warning count [2 bytes]
if !mc.strict {
return nil
} else {
pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)
for i := 0; ; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if i == count {
return columns, nil
}
return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
}
// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
return nil, err
}
// Database [len coded string]
n, err := skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Table [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Original table [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Name [len coded string]
name, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
columns[i].name = string(name)
pos += n
// Original name [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
// Filler [1 byte]
// Charset [16 bit uint]
// Length [32 bit uint]
pos += n + 1 + 2 + 4
// Field type [byte]
columns[i].fieldType = data[pos]
pos++
// Flags [16 bit uint]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
//pos += 2
// Decimals [8 bit uint]
//pos++
// Default value [len coded binary]
//if pos < len(data) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
data, err := mc.readPacket()
if err != nil {
return err
}
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
return io.EOF
}
// RowSet Packet
var n int
var isNull bool
pos := 0
for i := range dest {
// Read bytes and convert to string
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err == nil {
if !isNull {
if !mc.parseTime {
continue
} else {
switch rows.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
mc.cfg.loc,
)
if err == nil {
continue
}
default:
continue
}
}
} else {
dest[i] = nil
continue
}
}
return err // err != nil
}
return nil
}
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF() error {
for {
data, err := mc.readPacket()
// No Err and no EOF Packet
if err == nil && data[0] != iEOF {
continue
}
return err // Err or EOF
}
}
/******************************************************************************
* Prepared Statements *
******************************************************************************/
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
data, err := stmt.mc.readPacket()
if err == nil {
// packet indicator [1 byte]
if data[0] != iOK {
return 0, stmt.mc.handleErrorPacket(data)
}
// statement id [4 bytes]
stmt.id = binary.LittleEndian.Uint32(data[1:5])
// Column count [16 bit uint]
columnCount := binary.LittleEndian.Uint16(data[5:7])
// Param count [16 bit uint]
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
// Reserved [8 bit]
// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
} else {
// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
}
return 0, err
}
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
maxLen := stmt.mc.maxPacketAllowed - 1
pktLen := maxLen
// After the header (bytes 0-3) follows before the data:
// 1 byte command
// 4 bytes stmtID
// 2 bytes paramID
const dataOffset = 1 + 4 + 2
// Can not use the write buffer since
// a) the buffer is too small
// b) it is in use
data := make([]byte, 4+1+4+2+len(arg))
copy(data[4+dataOffset:], arg)
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
if dataOffset+argLen < maxLen {
pktLen = dataOffset + argLen
}
stmt.mc.sequence = 0
// Add command byte [1 byte]
data[4] = comStmtSendLongData
// Add stmtID [32 bit]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit]
data[9] = byte(paramID)
data[10] = byte(paramID >> 8)
// Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen])
if err == nil {
data = data[pktLen-dataOffset:]
continue
}
return err
}
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil
}
// Execute Prepared Statement
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
"Arguments count mismatch (Got: %d Has: %d)",
len(args),
stmt.paramCount,
)
}
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Reset packet-sequence
mc.sequence = 0
var data []byte
if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
}
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// command [1 byte]
data[4] = comStmtExecute
// statement_id [4 bytes]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 {
pos := minPktLen
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
// which makes the required allocation size hard to guess.
tmp := make([]byte, pos+maskLen+typesLen)
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
nullMask[i] = 0
}
pos += maskLen
}
// newParameterBoundFlag 1 [1 byte]
data[pos] = 0x01
pos++
// type of each parameter [len(args)*2 bytes]
paramTypes := data[pos:]
pos += len(args) * 2
// value of each parameter [n bytes]
paramValues := data[pos:pos]
valuesCap := cap(paramValues)
for i, arg := range args {
// build NULL-bitmap
if arg == nil {
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i+1] = 0x00
continue
}
// cache types and values
switch v := arg.(type) {
case int64:
paramTypes[i+i] = fieldTypeLongLong
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64:
paramTypes[i+i] = fieldTypeDouble
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool:
paramTypes[i+i] = fieldTypeTiny
paramTypes[i+i+1] = 0x00
if v {
paramValues = append(paramValues, 0x01)
} else {
paramValues = append(paramValues, 0x00)
}
case []byte:
// Common case (non-nil value) first
if v != nil {
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, v); err != nil {
return err
}
}
continue
}
// Handle []byte(nil) as a NULL value
nullMask[i/8] |= 1 << (uint(i) & 7)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i+1] = 0x00
case string:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
return err
}
}
case time.Time:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
var val []byte
if v.IsZero() {
val = []byte("0000-00-00")
} else {
val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
}
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(val)),
)
paramValues = append(paramValues, val...)
default:
return fmt.Errorf("Can't convert type: %T", arg)
}
}
// Check if param values exceeded the available buffer
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
}
pos += len(paramValues)
data = data[:pos]
}
return mc.writePacket(data)
}
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
if err != nil {
return err
}
// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
return io.EOF
}
// Error otherwise
return rows.mc.handleErrorPacket(data)
}
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
pos := 1 + (len(dest)+7+2)>>3
nullMask := data[1:pos]
for i := range dest {
// Field is NULL
// (byte >> bit-pos) % 2 == 1
if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
dest[i] = nil
continue
}
// Convert to byte-coded string
switch rows.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue
// Numeric Types
case fieldTypeTiny:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
}
pos++
continue
case fieldTypeShort, fieldTypeYear:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
}
pos += 2
continue
case fieldTypeInt24, fieldTypeLong:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
}
pos += 4
continue
case fieldTypeLongLong:
if rows.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
} else {
dest[i] = int64(val)
}
} else {
dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
}
pos += 8
continue
case fieldTypeFloat:
dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
pos += 4
continue
case fieldTypeDouble:
dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
pos += 8
continue
// Length coded Binary Strings
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
pos += n
if err == nil {
if !isNull {
continue
} else {
dest[i] = nil
continue
}
}
return err
// Date YYYY-MM-DD
case fieldTypeDate, fieldTypeNewDate:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if isNull {
dest[i] = nil
continue
}
if rows.mc.parseTime {
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
} else {
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false)
}
if err == nil {
pos += int(num)
continue
} else {
return err
}
// Time [-][H]HH:MM:SS[.fractal]
case fieldTypeTime:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if num == 0 {
if isNull {
dest[i] = nil
continue
} else {
dest[i] = []byte("00:00:00")
continue
}
}
var sign string
if data[pos] == 1 {
sign = "-"
}
switch num {
case 8:
dest[i] = []byte(fmt.Sprintf(
sign+"%02d:%02d:%02d",
uint16(data[pos+1])*24+uint16(data[pos+5]),
data[pos+6],
data[pos+7],
))
pos += 8
continue
case 12:
dest[i] = []byte(fmt.Sprintf(
sign+"%02d:%02d:%02d.%06d",
uint16(data[pos+1])*24+uint16(data[pos+5]),
data[pos+6],
data[pos+7],
binary.LittleEndian.Uint32(data[pos+8:pos+12]),
))
pos += 12
continue
default:
return fmt.Errorf("Invalid TIME-packet length %d", num)
}
// Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
case fieldTypeTimestamp, fieldTypeDateTime:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if isNull {
dest[i] = nil
continue
}
if rows.mc.parseTime {
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
} else {
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true)
}
if err == nil {
pos += int(num)
continue
} else {
return err
}
// Please report if this happens!
default:
return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType)
}
}
return nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlResult struct {
affectedRows int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
)
type mysqlField struct {
fieldType byte
flags fieldFlag
name string
}
type mysqlRows struct {
mc *mysqlConn
columns []mysqlField
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
for i := range columns {
columns[i] = rows.columns[i].name
}
return columns
}
func (rows *mysqlRows) Close() error {
mc := rows.mc
if mc == nil {
return nil
}
if mc.netConn == nil {
return ErrInvalidConn
}
// Remove unread packets from stream
err := mc.readUntilEOF()
rows.mc = nil
return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return ErrInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField // cached from the first query
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
if resLen > 0 {
// Columns
err = mc.readUntilEOF()
if err != nil {
return nil, err
}
// Rows
err = mc.readUntilEOF()
}
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
}
return nil, err
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
rows.mc = mc
if resLen > 0 {
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
rows.columns, err = mc.readColumns(resLen)
stmt.columns = rows.columns
} else {
rows.columns = stmt.columns
err = mc.readUntilEOF()
}
}
return rows, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/sha1"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"net/url"
"strings"
"time"
)
var (
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
)
func init() {
tlsConfigRegister = make(map[string]*tls.Config)
}
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// if err != nil {
// log.Fatal(err)
// }
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
// log.Fatal("Failed to append PEM.")
// }
// clientCert := make([]tls.Certificate, 0, 1)
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
// if err != nil {
// log.Fatal(err)
// }
// clientCert = append(clientCert, certs)
// mysql.RegisterTLSConfig("custom", &tls.Config{
// RootCAs: rootCertPool,
// Certificates: clientCert,
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("Key '%s' is reserved", key)
}
tlsConfigRegister[key] = config
return nil
}
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
delete(tlsConfigRegister, key)
}
// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (cfg *config, err error) {
cfg = new(config)
// TODO: use strings.IndexByte when we can depend on Go 1.2
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.dbname = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
// Set default network if empty
if cfg.net == "" {
cfg.net = "tcp"
}
// Set default address if empty
if cfg.addr == "" {
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}
}
// Set default location if empty
if cfg.loc == nil {
cfg.loc = time.UTC
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}
// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
return fmt.Errorf("Invalid value / unknown config name: %s", value)
}
}
default:
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
}
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
/******************************************************************************
* Authentication *
******************************************************************************/
// Encrypt password using 4.1+ method
func scramblePassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
// Encrypt password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}
const myRndMaxVal = 0x3FFFFFFF
// Pseudo random number generator
func newMyRnd(seed1, seed2 uint32) *myRnd {
return &myRnd{
seed1: seed1 % myRndMaxVal,
seed2: seed2 % myRndMaxVal,
}
}
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}
// Generate binary hash from byte string using insecure pre 4.1 method
func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32
result[0] = 1345345333
result[1] = 0x12345671
for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}
tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}
// Remove sign bit (1<<31)-1)
result[0] &= 0x7FFFFFFF
result[1] &= 0x7FFFFFFF
return
}
// Encrypt password using insecure pre 4.1 method
func scrambleOldPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
scramble = scramble[:8]
hashPw := pwHash(password)
hashSc := pwHash(scramble)
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
var out [8]byte
for i := range out {
out[i] = r.NextByte() + 64
}
mask := r.NextByte()
for i := range out {
out[i] ^= mask
}
return out[:]
}
/******************************************************************************
* Time related utils *
******************************************************************************/
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
switch len(str) {
case 10: // YYYY-MM-DD
if str == "0000-00-00" {
return
}
t, err = time.Parse(timeFormat[:10], str)
case 19: // YYYY-MM-DD HH:MM:SS
if str == "0000-00-00 00:00:00" {
return
}
t, err = time.Parse(timeFormat, str)
default:
err = fmt.Errorf("Invalid Time-String: %s", str)
return
}
// Adjust location
if err == nil && loc != time.UTC {
y, mo, d := t.Date()
h, mi, s := t.Clock()
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
}
return
}
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
switch num {
case 0:
return time.Time{}, nil
case 4:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
0, 0, 0, 0,
loc,
), nil
case 7:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
0,
loc,
), nil
case 11:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
loc,
), nil
}
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
// if the DATE or DATETIME has the zero value.
// It must never be changed.
// The current behavior depends on database/sql copying the result.
var zeroDateTime = []byte("0000-00-00 00:00:00")
func formatBinaryDateTime(src []byte, withTime bool) (driver.Value, error) {
if len(src) == 0 {
if withTime {
return zeroDateTime, nil
}
return zeroDateTime[:10], nil
}
var dst []byte
if withTime {
if len(src) == 11 {
dst = []byte("0000-00-00 00:00:00.000000")
} else {
dst = []byte("0000-00-00 00:00:00")
}
} else {
dst = []byte("0000-00-00")
}
switch len(src) {
case 11:
microsecs := binary.LittleEndian.Uint32(src[7:11])
tmp32 := microsecs / 10
dst[25] += byte(microsecs - 10*tmp32)
tmp32, microsecs = tmp32/10, tmp32
dst[24] += byte(microsecs - 10*tmp32)
tmp32, microsecs = tmp32/10, tmp32
dst[23] += byte(microsecs - 10*tmp32)
tmp32, microsecs = tmp32/10, tmp32
dst[22] += byte(microsecs - 10*tmp32)
tmp32, microsecs = tmp32/10, tmp32
dst[21] += byte(microsecs - 10*tmp32)
dst[20] += byte(microsecs / 10)
fallthrough
case 7:
second := src[6]
tmp := second / 10
dst[18] += second - 10*tmp
dst[17] += tmp
minute := src[5]
tmp = minute / 10
dst[15] += minute - 10*tmp
dst[14] += tmp
hour := src[4]
tmp = hour / 10
dst[12] += hour - 10*tmp
dst[11] += tmp
fallthrough
case 4:
day := src[3]
tmp := day / 10
dst[9] += day - 10*tmp
dst[8] += tmp
month := src[2]
tmp = month / 10
dst[6] += month - 10*tmp
dst[5] += tmp
year := binary.LittleEndian.Uint16(src[:2])
tmp16 := year / 10
dst[3] += byte(year - 10*tmp16)
tmp16, year = tmp16/10, tmp16
dst[2] += byte(year - 10*tmp16)
tmp16, year = tmp16/10, tmp16
dst[1] += byte(year - 10*tmp16)
dst[0] += byte(tmp16)
return dst, nil
}
var t string
if withTime {
t = "DATETIME"
} else {
t = "DATE"
}
return nil, fmt.Errorf("invalid %s-packet length %d", t, len(src))
}
/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
func uint64ToBytes(n uint64) []byte {
return []byte{
byte(n),
byte(n >> 8),
byte(n >> 16),
byte(n >> 24),
byte(n >> 32),
byte(n >> 40),
byte(n >> 48),
byte(n >> 56),
}
}
func uint64ToString(n uint64) []byte {
var a [20]byte
i := 20
// U+0030 = 0
// ...
// U+0039 = 9
var q uint64
for n >= 10 {
i--
q = n / 10
a[i] = uint8(n-q*10) + 0x30
n = q
}
i--
a[i] = uint8(n) + 0x30
return a[i:]
}
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, wheter the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := readLengthEncodedInteger(b)
if num < 1 {
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n], false, n, nil
}
return nil, false, n, io.EOF
}
// returns the number of bytes skipped and an error, in case the string is
// longer than the input slice
func skipLengthEncodedString(b []byte) (int, error) {
// Get length
num, _, n := readLengthEncodedInteger(b)
if num < 1 {
return n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return n, nil
}
return n, io.EOF
}
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
switch b[0] {
// 251: NULL
case 0xfb:
return 0, true, 1
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<56,
false, 9
}
// 0-250: value of first byte
return uint64(b[0]), false, 1
}
// encodes a uint64 value and appends it to the given bytes slice
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
)
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}
func TestDSNParser(t *testing.T) {
var cfg *config
var err error
var res string
for i, tst := range testDSNs {
cfg, err = parseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
res = fmt.Sprintf("%+v", cfg)
if res != fmt.Sprintf(tst.out, tst.loc) {
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
"user:pass@tcp(1.2.3.4:3306)", // no trailing slash
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := parseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := parseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
error bool
valid bool
time time.Time
}{
{tDate, false, true, tDate},
{sDate, false, true, tDate},
{[]byte(sDate), false, true, tDate},
{tDateTime, false, true, tDateTime},
{sDateTime, false, true, tDateTime},
{[]byte(sDateTime), false, true, tDateTime},
{tDate0, false, true, tDate0},
{sDate0, false, true, tDate0},
{[]byte(sDate0), false, true, tDate0},
{sDateTime0, false, true, tDate0},
{[]byte(sDateTime0), false, true, tDate0},
{"", true, false, tDate0},
{"1234", true, false, tDate0},
{0, true, false, tDate0},
}
var nt = NullTime{}
var err error
for _, tst := range scanTests {
err = nt.Scan(tst.in)
if (err != nil) != tst.error {
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
}
if nt.Valid != tst.valid {
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
}
if nt.Time != tst.time {
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
}
}
}
func TestLengthEncodedInteger(t *testing.T) {
var integerTests = []struct {
num uint64
encoded []byte
}{
{0x0000000000000000, []byte{0x00}},
{0x0000000000000012, []byte{0x12}},
{0x00000000000000fa, []byte{0xfa}},
{0x0000000000000100, []byte{0xfc, 0x00, 0x01}},
{0x0000000000001234, []byte{0xfc, 0x34, 0x12}},
{0x000000000000ffff, []byte{0xfc, 0xff, 0xff}},
{0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}},
{0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}},
{0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}},
{0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}},
{0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}},
{0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
}
for _, tst := range integerTests {
num, isNull, numLen := readLengthEncodedInteger(tst.encoded)
if isNull {
t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num)
}
if num != tst.num {
t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num)
}
if numLen != len(tst.encoded) {
t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
}
encoded := appendLengthEncodedInteger(nil, num)
if !bytes.Equal(encoded, tst.encoded) {
t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
}
}
}
func TestOldPass(t *testing.T) {
scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2}
vectors := []struct {
pass string
out string
}{
{" pass", "47575c5a435b4251"},
{"pass ", "47575c5a435b4251"},
{"123\t456", "575c47505b5b5559"},
{"C0mpl!ca ted#PASS123", "5d5d554849584a45"},
}
for _, tuple := range vectors {
ours := scrambleOldPassword(scramble, []byte(tuple.pass))
if tuple.out != fmt.Sprintf("%x", ours) {
t.Errorf("Failed old password %q", tuple.pass)
}
}
}
func TestFormatBinaryDateTime(t *testing.T) {
rawDate := [11]byte{}
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
rawDate[2] = 12 // months
rawDate[3] = 30 // days
rawDate[4] = 15 // hours
rawDate[5] = 46 // minutes
rawDate[6] = 23 // seconds
binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
expect := func(expected string, length int, withTime bool) {
actual, _ := formatBinaryDateTime(rawDate[:length], withTime)
bytes, ok := actual.([]byte)
if !ok {
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
}
if string(bytes) != expected {
t.Errorf(
"expected %q, got %q for length %d, withTime %v",
bytes, actual, length, withTime,
)
}
}
expect("0000-00-00", 0, false)
expect("0000-00-00 00:00:00", 0, true)
expect("1978-12-30", 4, false)
expect("1978-12-30 15:46:23", 7, true)
expect("1978-12-30 15:46:23.987654", 11, true)
}
This source diff could not be displayed because it is too large. You can view the blob instead.
package sqlite3s
/*
#include <sqlite3.h>
#include <stdlib.h>
#include <string.h>
#ifndef SQLITE_OPEN_READWRITE
# define SQLITE_OPEN_READWRITE 0
#endif
#ifndef SQLITE_OPEN_FULLMUTEX
# define SQLITE_OPEN_FULLMUTEX 0
#endif
static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI
return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
#else
return sqlite3_open_v2(filename, ppDb, flags, zVfs);
#endif
}
static int
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
}
static int
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
}
#include <stdio.h>
#include <stdint.h>
static long
_sqlite3_last_insert_rowid(sqlite3* db) {
return (long) sqlite3_last_insert_rowid(db);
}
static long
_sqlite3_changes(sqlite3* db) {
return (long) sqlite3_changes(db);
}
*/
import "C"
import (
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"time"
"unsafe"
)
// Timestamp formats understood by both this module and SQLite.
// The first format in the slice will be used when saving time values
// into the database. When parsing a string from a timestamp or
// datetime column, the formats are tried in order.
var SQLiteTimestampFormats = []string{
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
}
func init() {
sql.Register("sqlite3", &SQLiteDriver{})
}
// Driver struct.
type SQLiteDriver struct {
}
// Conn struct.
type SQLiteConn struct {
db *C.sqlite3
}
// Tx struct.
type SQLiteTx struct {
c *SQLiteConn
}
// Stmt struct.
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
t string
closed bool
}
// Result struct.
type SQLiteResult struct {
id int64
changes int64
}
// Rows struct.
type SQLiteRows struct {
s *SQLiteStmt
nc int
cols []string
decltype []string
}
// Commit transaction.
func (tx *SQLiteTx) Commit() error {
if err := tx.c.exec("COMMIT"); err != nil {
return err
}
return nil
}
// Rollback transaction.
func (tx *SQLiteTx) Rollback() error {
if err := tx.c.exec("ROLLBACK"); err != nil {
return err
}
return nil
}
func (c *SQLiteConn) exec(cmd string) error {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}
// Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if err := c.exec("BEGIN"); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
}
// Open database and return a new connection.
// You can specify DSN string with URI filename.
// test.db
// file:test.db?cache=shared&mode=memory
// :memory:
// file::memory:
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}
var db *C.sqlite3
name := C.CString(dsn)
defer C.free(unsafe.Pointer(name))
rv := C._sqlite3_open_v2(name, &db,
C.SQLITE_OPEN_FULLMUTEX|
C.SQLITE_OPEN_READWRITE|
C.SQLITE_OPEN_CREATE,
nil)
if rv != 0 {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
rv = C.sqlite3_busy_timeout(db, 5000)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
return &SQLiteConn{db}, nil
}
// Close the connection.
func (c *SQLiteConn) Close() error {
s := C.sqlite3_next_stmt(c.db, nil)
for s != nil {
C.sqlite3_finalize(s)
s = C.sqlite3_next_stmt(c.db, nil)
}
rv := C.sqlite3_close(c.db)
if rv != C.SQLITE_OK {
return errors.New("error while closing sqlite database connection")
}
c.db = nil
return nil
}
// Prepare query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt
var perror *C.char
rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &perror)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
var t string
if perror != nil && C.strlen(perror) > 0 {
t = C.GoString(perror)
}
return &SQLiteStmt{c: c, s: s, t: t}, nil
}
// Close the statement.
func (s *SQLiteStmt) Close() error {
if s.closed {
return nil
}
s.closed = true
if s.c == nil || s.c.db == nil {
return errors.New("sqlite statement with already closed database connection")
}
rv := C.sqlite3_finalize(s.s)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
}
return nil
}
// Return a number of parameters.
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}
func (s *SQLiteStmt) bind(args []driver.Value) error {
rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
}
for i, v := range args {
n := C.int(i + 1)
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
b := []byte{0}
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
case int:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case int32:
rv = C.sqlite3_bind_int(s.s, n, C.int(v))
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case byte:
rv = C.sqlite3_bind_int(s.s, n, C.int(v))
case bool:
if bool(v) {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float32:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
var p *byte
if len(v) > 0 {
p = &v[0]
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
}
}
return nil
}
// Query the statment with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
if err := s.bind(args); err != nil {
return nil, err
}
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil}, nil
}
// Return last inserted ID.
func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.id, nil
}
// Return how many rows affected.
func (r *SQLiteResult) RowsAffected() (int64, error) {
return r.changes, nil
}
// Execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil {
return nil, err
}
rv := C.sqlite3_step(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
}
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
}
// Close the rows.
func (rc *SQLiteRows) Close() error {
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
}
return nil
}
// Return column names.
func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
}
}
return rc.cols
}
// Move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
}
if rv != C.SQLITE_ROW {
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
}
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] {
case "timestamp", "datetime":
dest[i] = time.Unix(val, 0)
case "boolean":
dest[i] = val > 0
default:
dest[i] = val
}
case C.SQLITE_FLOAT:
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
case C.SQLITE_BLOB:
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
p := C.sqlite3_column_blob(rc.s.s, C.int(i))
switch dest[i].(type) {
case sql.RawBytes:
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
default:
slice := make([]byte, n)
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
dest[i] = slice
}
case C.SQLITE_NULL:
dest[i] = nil
case C.SQLITE_TEXT:
var err error
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
switch rc.decltype[i] {
case "timestamp", "datetime":
for _, format := range SQLiteTimestampFormats {
if dest[i], err = time.Parse(format, s); err == nil {
break
}
}
if err != nil {
// The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{}
}
default:
dest[i] = s
}
}
}
return nil
}
This source diff could not be displayed because it is too large. You can view the blob instead.
package sqlite3s
/*
#cgo CFLAGS: -DNDEBUG
#cgo CFLAGS: -I .
#cgo linux LDFLAGS: -ldl
*/
import "C"
package sqlite3s
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"os"
"path/filepath"
"testing"
"time"
)
func TempFilename() string {
randBytes := make([]byte, 16)
rand.Read(randBytes)
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db")
}
func TestOpen(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
t.Error("Failed to create ./foo.db")
}
}
func TestInsert(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
res, err := db.Exec("insert into foo(id) values(123)")
if err != nil {
t.Fatal("Failed to insert record:", err)
}
affected, _ := res.RowsAffected()
if affected != 1 {
t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
}
rows, err := db.Query("select id from foo")
if err != nil {
t.Fatal("Failed to select records:", err)
}
defer rows.Close()
rows.Next()
var result int
rows.Scan(&result)
if result != 123 {
t.Errorf("Fetched %q; expected %q", 123, result)
}
}
func TestUpdate(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
res, err := db.Exec("insert into foo(id) values(123)")
if err != nil {
t.Fatal("Failed to insert record:", err)
}
expected, err := res.LastInsertId()
if err != nil {
t.Fatal("Failed to get LastInsertId:", err)
}
affected, _ := res.RowsAffected()
if err != nil {
t.Fatal("Failed to get RowsAffected:", err)
}
if affected != 1 {
t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
}
res, err = db.Exec("update foo set id = 234")
if err != nil {
t.Fatal("Failed to update record:", err)
}
lastId, err := res.LastInsertId()
if err != nil {
t.Fatal("Failed to get LastInsertId:", err)
}
if expected != lastId {
t.Errorf("Expected %q for last Id, but %q:", expected, lastId)
}
affected, _ = res.RowsAffected()
if err != nil {
t.Fatal("Failed to get RowsAffected:", err)
}
if affected != 1 {
t.Fatalf("Expected %d for affected rows, but %d:", 1, affected)
}
rows, err := db.Query("select id from foo")
if err != nil {
t.Fatal("Failed to select records:", err)
}
defer rows.Close()
rows.Next()
var result int
rows.Scan(&result)
if result != 234 {
t.Errorf("Fetched %q; expected %q", 234, result)
}
}
func TestDelete(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
res, err := db.Exec("insert into foo(id) values(123)")
if err != nil {
t.Fatal("Failed to insert record:", err)
}
expected, err := res.LastInsertId()
if err != nil {
t.Fatal("Failed to get LastInsertId:", err)
}
affected, err := res.RowsAffected()
if err != nil {
t.Fatal("Failed to get RowsAffected:", err)
}
if affected != 1 {
t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected)
}
res, err = db.Exec("delete from foo where id = 123")
if err != nil {
t.Fatal("Failed to delete record:", err)
}
lastId, err := res.LastInsertId()
if err != nil {
t.Fatal("Failed to get LastInsertId:", err)
}
if expected != lastId {
t.Errorf("Expected %q for last Id, but %q:", expected, lastId)
}
affected, err = res.RowsAffected()
if err != nil {
t.Fatal("Failed to get RowsAffected:", err)
}
if affected != 1 {
t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected)
}
rows, err := db.Query("select id from foo")
if err != nil {
t.Fatal("Failed to select records:", err)
}
defer rows.Close()
if rows.Next() {
t.Error("Fetched row but expected not rows")
}
}
func TestBooleanRoundtrip(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true)
if err != nil {
t.Fatal("Failed to insert true value:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false)
if err != nil {
t.Fatal("Failed to insert false value:", err)
}
rows, err := db.Query("SELECT id, value FROM foo")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
for rows.Next() {
var id int
var value bool
if err := rows.Scan(&id, &value); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id == 1 && !value {
t.Error("Value for id 1 should be true, not false")
} else if id == 2 && value {
t.Error("Value for id 2 should be false, not true")
}
}
}
func TestTimestamp(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tests := []struct {
value interface{}
expected time.Time
}{
{"nonsense", time.Time{}},
{"0000-00-00 00:00:00", time.Time{}},
{timestamp1, timestamp1},
{timestamp1.Unix(), timestamp1},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05"), timestamp1},
{timestamp2, timestamp2},
{"2006-01-02 15:04:05.123456789", timestamp2},
{"2006-01-02T15:04:05.123456789", timestamp2},
{"2012-11-04", timestamp3},
{"2012-11-04 00:00", timestamp3},
{"2012-11-04 00:00:00", timestamp3},
{"2012-11-04 00:00:00.000", timestamp3},
{"2012-11-04T00:00", timestamp3},
{"2012-11-04T00:00:00", timestamp3},
{"2012-11-04T00:00:00.000", timestamp3},
}
for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
if err != nil {
t.Fatal("Failed to insert timestamp:", err)
}
}
rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
seen := 0
for rows.Next() {
var id int
var ts, dt time.Time
if err := rows.Scan(&id, &ts, &dt); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id < 0 || id >= len(tests) {
t.Error("Bad row id: ", id)
continue
}
seen++
if !tests[id].expected.Equal(ts) {
t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
}
if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
}
}
if seen != len(tests) {
t.Errorf("Expected to see %d rows", len(tests))
}
}
func TestBoolean(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
bool1 := true
_, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1)
if err != nil {
t.Fatal("Failed to insert boolean:", err)
}
bool2 := false
_, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2)
if err != nil {
t.Fatal("Failed to insert boolean:", err)
}
bool3 := "nonsense"
_, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3)
if err != nil {
t.Fatal("Failed to insert nonsense:", err)
}
rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1)
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
counter := 0
var id int
var fbool bool
for rows.Next() {
if err := rows.Scan(&id, &fbool); err != nil {
t.Fatal("Unable to scan results:", err)
}
counter++
}
if counter != 1 {
t.Fatalf("Expected 1 row but %v", counter)
}
if id != 1 && fbool != true {
t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool)
}
rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2)
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
counter = 0
for rows.Next() {
if err := rows.Scan(&id, &fbool); err != nil {
t.Fatal("Unable to scan results:", err)
}
counter++
}
if counter != 1 {
t.Fatalf("Expected 1 row but %v", counter)
}
if id != 2 && fbool != false {
t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool)
}
// make sure "nonsense" triggered an error
rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3)
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
rows.Next()
err = rows.Scan(&id, &fbool)
if err == nil {
t.Error("Expected error from \"nonsense\" bool")
}
}
package beedb
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
var OnDebug = false
var PluralizeTableNames = false
type Model struct {
Db *sql.DB
TableName string
LimitStr int
OffsetStr int
WhereStr string
ParamStr []interface{}
OrderStr string
ColumnStr string
PrimaryKey string
JoinStr string
GroupByStr string
HavingStr string
QuoteIdentifier string
ParamIdentifier string
ParamIteration int
}
/**
* Add New sql.DB in the future i will add ConnectionPool.Get()
*/
func New(db *sql.DB, options ...interface{}) (m Model) {
if len(options) == 0 {
m = Model{Db: db, ColumnStr: "*", PrimaryKey: "Id", QuoteIdentifier: "`", ParamIdentifier: "?", ParamIteration: 1}
} else if options[0] == "pg" {
m = Model{Db: db, ColumnStr: "id", PrimaryKey: "Id", QuoteIdentifier: "\"", ParamIdentifier: options[0].(string), ParamIteration: 1}
} else if options[0] == "mssql" {
m = Model{Db: db, ColumnStr: "id", PrimaryKey: "id", QuoteIdentifier: "", ParamIdentifier: options[0].(string), ParamIteration: 1}
}
return
}
func (orm *Model) SetTable(tbname string) *Model {
orm.TableName = tbname
return orm
}
func (orm *Model) SetPK(pk string) *Model {
orm.PrimaryKey = pk
return orm
}
func (orm *Model) Where(querystring interface{}, args ...interface{}) *Model {
switch querystring := querystring.(type) {
case string:
orm.WhereStr = querystring
case int:
if orm.ParamIdentifier == "pg" {
orm.WhereStr = fmt.Sprintf("%v%v%v = $%v", orm.QuoteIdentifier, orm.PrimaryKey, orm.QuoteIdentifier, orm.ParamIteration)
} else {
orm.WhereStr = fmt.Sprintf("%v%v%v = ?", orm.QuoteIdentifier, orm.PrimaryKey, orm.QuoteIdentifier)
orm.ParamIteration++
}
args = append(args, querystring)
}
orm.ParamStr = args
return orm
}
func (orm *Model) Limit(start int, size ...int) *Model {
orm.LimitStr = start
if len(size) > 0 {
orm.OffsetStr = size[0]
}
return orm
}
func (orm *Model) Offset(offset int) *Model {
orm.OffsetStr = offset
return orm
}
func (orm *Model) OrderBy(order string) *Model {
orm.OrderStr = order
return orm
}
func (orm *Model) Select(colums string) *Model {
orm.ColumnStr = colums
return orm
}
func (orm *Model) ScanPK(output interface{}) *Model {
if reflect.TypeOf(reflect.Indirect(reflect.ValueOf(output)).Interface()).Kind() == reflect.Slice {
sliceValue := reflect.Indirect(reflect.ValueOf(output))
sliceElementType := sliceValue.Type().Elem()
for i := 0; i < sliceElementType.NumField(); i++ {
bb := sliceElementType.Field(i).Tag
if bb.Get("beedb") == "PK" || reflect.ValueOf(bb).String() == "PK" {
orm.PrimaryKey = sliceElementType.Field(i).Name
}
}
} else {
tt := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(output)).Interface())
for i := 0; i < tt.NumField(); i++ {
bb := tt.Field(i).Tag
if bb.Get("beedb") == "PK" || reflect.ValueOf(bb).String() == "PK" {
orm.PrimaryKey = tt.Field(i).Name
}
}
}
return orm
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (orm *Model) Join(join_operator, tablename, condition string) *Model {
if orm.JoinStr != "" {
orm.JoinStr = orm.JoinStr + fmt.Sprintf(" %v JOIN %v ON %v", join_operator, tablename, condition)
} else {
orm.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
return orm
}
func (orm *Model) GroupBy(keys string) *Model {
orm.GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
return orm
}
func (orm *Model) Having(conditions string) *Model {
orm.HavingStr = fmt.Sprintf("HAVING %v", conditions)
return orm
}
func (orm *Model) Find(output interface{}) error {
orm.ScanPK(output)
var keys []string
results, err := scanStructIntoMap(output)
if err != nil {
return err
}
if orm.TableName == "" {
orm.TableName = getTableName(output)
}
// If we've already specific columns with Select(), use that
if orm.ColumnStr == "*" {
for key, _ := range results {
keys = append(keys, key)
}
orm.ColumnStr = strings.Join(keys, ", ")
}
orm.Limit(1)
resultsSlice, err := orm.FindMap()
if err != nil {
return err
}
if len(resultsSlice) == 0 {
return errors.New("No record found")
} else if len(resultsSlice) == 1 {
results := resultsSlice[0]
err := scanMapIntoStruct(output, results)
if err != nil {
return err
}
} else {
return errors.New("More than one record")
}
return nil
}
func (orm *Model) FindAll(rowsSlicePtr interface{}) error {
orm.ScanPK(rowsSlicePtr)
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return errors.New("needs a pointer to a slice")
}
sliceElementType := sliceValue.Type().Elem()
st := reflect.New(sliceElementType)
var keys []string
results, err := scanStructIntoMap(st.Interface())
if err != nil {
return err
}
if orm.TableName == "" {
orm.TableName = getTableName(getTypeName(rowsSlicePtr))
}
// If we've already specific columns with Select(), use that
if orm.ColumnStr == "*" {
for key, _ := range results {
keys = append(keys, key)
}
orm.ColumnStr = strings.Join(keys, ", ")
}
resultsSlice, err := orm.FindMap()
if err != nil {
return err
}
for _, results := range resultsSlice {
newValue := reflect.New(sliceElementType)
err := scanMapIntoStruct(newValue.Interface(), results)
if err != nil {
return err
}
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
}
return nil
}
func (orm *Model) FindMap() (resultsSlice []map[string][]byte, err error) {
defer orm.InitModel()
sqls := orm.generateSql()
if OnDebug {
fmt.Println(sqls)
fmt.Println(orm)
}
s, err := orm.Db.Prepare(sqls)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(orm.ParamStr...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte)
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer)
}
if err := res.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
continue
}
aa := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
var str string
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
result[key] = []byte(str)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
if aa.Elem().Kind() == reflect.Uint8 {
result[key] = rawValue.Interface().([]byte)
break
}
case reflect.String:
str = vv.String()
result[key] = []byte(str)
//时间类型
case reflect.Struct:
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
case reflect.Bool:
if vv.Bool() {
result[key] = []byte("1")
} else {
result[key] = []byte("0")
}
}
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func (orm *Model) generateSql() (a string) {
if orm.ParamIdentifier == "mssql" {
if orm.OffsetStr > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
orm.PrimaryKey,
orm.ColumnStr,
orm.TableName)
if orm.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, orm.WhereStr)
}
a = fmt.Sprintf("select * from (%v) "+
"as a where rownum between %v and %v",
a,
orm.OffsetStr,
orm.LimitStr)
} else if orm.LimitStr > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", orm.LimitStr, orm.ColumnStr, orm.TableName)
if orm.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, orm.WhereStr)
}
if orm.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, orm.GroupByStr)
}
if orm.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, orm.HavingStr)
}
if orm.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, orm.OrderStr)
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", orm.ColumnStr, orm.TableName)
if orm.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, orm.WhereStr)
}
if orm.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, orm.GroupByStr)
}
if orm.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, orm.HavingStr)
}
if orm.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, orm.OrderStr)
}
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", orm.ColumnStr, orm.TableName)
if orm.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, orm.JoinStr)
}
if orm.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, orm.WhereStr)
}
if orm.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, orm.GroupByStr)
}
if orm.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, orm.HavingStr)
}
if orm.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, orm.OrderStr)
}
if orm.OffsetStr > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, orm.LimitStr, orm.OffsetStr)
} else if orm.LimitStr > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, orm.LimitStr)
}
}
return
}
//Execute sql
func (orm *Model) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := orm.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
//if the struct has PrimaryKey == 0 insert else update
func (orm *Model) Save(output interface{}) error {
orm.ScanPK(output)
results, err := scanStructIntoMap(output)
if err != nil {
return err
}
if orm.TableName == "" {
orm.TableName = getTableName(output)
}
id := results[orm.PrimaryKey]
delete(results, orm.PrimaryKey)
if id == nil {
return fmt.Errorf("Unable to save because primary key %q was not found in struct", orm.PrimaryKey)
}
if reflect.ValueOf(id).Int() == 0 {
structPtr := reflect.ValueOf(output)
structVal := structPtr.Elem()
structField := structVal.FieldByName(orm.PrimaryKey)
id, err := orm.Insert(results)
if err != nil {
return err
}
var v interface{}
x, err := strconv.Atoi(strconv.FormatInt(id, 10))
if err != nil {
return err
}
v = x
structField.Set(reflect.ValueOf(v))
return nil
} else {
var condition string
if orm.ParamIdentifier == "pg" {
condition = fmt.Sprintf("%v%v%v=$%v", orm.QuoteIdentifier, strings.ToLower(orm.PrimaryKey), orm.QuoteIdentifier, orm.ParamIteration)
} else {
condition = fmt.Sprintf("%v%v%v=?", orm.QuoteIdentifier, orm.PrimaryKey, orm.QuoteIdentifier)
}
orm.Where(condition, id)
_, err := orm.Update(results)
if err != nil {
return err
}
}
return nil
}
//inert one info
func (orm *Model) Insert(properties map[string]interface{}) (int64, error) {
defer orm.InitModel()
var keys []string
var placeholders []string
var args []interface{}
for key, val := range properties {
keys = append(keys, key)
if orm.ParamIdentifier == "pg" {
ds := fmt.Sprintf("$%d", orm.ParamIteration)
placeholders = append(placeholders, ds)
} else {
placeholders = append(placeholders, "?")
}
orm.ParamIteration++
args = append(args, val)
}
ss := fmt.Sprintf("%v,%v", orm.QuoteIdentifier, orm.QuoteIdentifier)
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
orm.QuoteIdentifier,
orm.TableName,
orm.QuoteIdentifier,
orm.QuoteIdentifier,
strings.Join(keys, ss),
orm.QuoteIdentifier,
strings.Join(placeholders, ", "))
if OnDebug {
fmt.Println(statement)
fmt.Println(orm)
}
if orm.ParamIdentifier == "pg" {
statement = fmt.Sprintf("%v RETURNING %v", statement, snakeCasedName(orm.PrimaryKey))
var id int64
orm.Db.QueryRow(statement, args...).Scan(&id)
return id, nil
} else {
res, err := orm.Exec(statement, args...)
if err != nil {
return -1, err
}
id, err := res.LastInsertId()
if err != nil {
return -1, err
}
return id, nil
}
return -1, nil
}
//insert batch info
func (orm *Model) InsertBatch(rows []map[string]interface{}) ([]int64, error) {
var ids []int64
tablename := orm.TableName
if len(rows) <= 0 {
return ids, nil
}
for i := 0; i < len(rows); i++ {
orm.TableName = tablename
id, err := orm.Insert(rows[i])
if err != nil {
return ids, err
}
ids = append(ids, id)
}
return ids, nil
}
// update info
func (orm *Model) Update(properties map[string]interface{}) (int64, error) {
defer orm.InitModel()
var updates []string
var args []interface{}
for key, val := range properties {
if orm.ParamIdentifier == "pg" {
ds := fmt.Sprintf("$%d", orm.ParamIteration)
updates = append(updates, fmt.Sprintf("%v%v%v = %v", orm.QuoteIdentifier, key, orm.QuoteIdentifier, ds))
} else {
updates = append(updates, fmt.Sprintf("%v%v%v = ?", orm.QuoteIdentifier, key, orm.QuoteIdentifier))
}
args = append(args, val)
orm.ParamIteration++
}
args = append(args, orm.ParamStr...)
if orm.ParamIdentifier == "pg" {
if n := len(orm.ParamStr); n > 0 {
for i := 1; i <= n; i++ {
orm.WhereStr = strings.Replace(orm.WhereStr, "$"+strconv.Itoa(i), "$"+strconv.Itoa(orm.ParamIteration), 1)
}
}
}
var condition string
if orm.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", orm.WhereStr)
} else {
condition = ""
}
statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
orm.QuoteIdentifier,
orm.TableName,
orm.QuoteIdentifier,
strings.Join(updates, ", "),
condition)
if OnDebug {
fmt.Println(statement)
fmt.Println(orm)
}
res, err := orm.Exec(statement, args...)
if err != nil {
return -1, err
}
id, err := res.RowsAffected()
if err != nil {
return -1, err
}
return id, nil
}
func (orm *Model) Delete(output interface{}) (int64, error) {
defer orm.InitModel()
orm.ScanPK(output)
results, err := scanStructIntoMap(output)
if err != nil {
return 0, err
}
if orm.TableName == "" {
orm.TableName = getTableName(output)
}
id := results[strings.ToLower(orm.PrimaryKey)]
condition := fmt.Sprintf("%v%v%v='%v'", orm.QuoteIdentifier, strings.ToLower(orm.PrimaryKey), orm.QuoteIdentifier, id)
statement := fmt.Sprintf("DELETE FROM %v%v%v WHERE %v",
orm.QuoteIdentifier,
orm.TableName,
orm.QuoteIdentifier,
condition)
if OnDebug {
fmt.Println(statement)
fmt.Println(orm)
}
res, err := orm.Exec(statement)
if err != nil {
return -1, err
}
Affectid, err := res.RowsAffected()
if err != nil {
return -1, err
}
return Affectid, nil
}
func (orm *Model) DeleteAll(rowsSlicePtr interface{}) (int64, error) {
defer orm.InitModel()
orm.ScanPK(rowsSlicePtr)
if orm.TableName == "" {
//TODO: fix table name
orm.TableName = getTableName(getTypeName(rowsSlicePtr))
}
var ids []string
val := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if val.Len() == 0 {
return 0, nil
}
for i := 0; i < val.Len(); i++ {
results, err := scanStructIntoMap(val.Index(i).Interface())
if err != nil {
return 0, err
}
id := results[strings.ToLower(orm.PrimaryKey)]
switch id.(type) {
case string:
ids = append(ids, id.(string))
case int, int64, int32:
str := strconv.Itoa(id.(int))
ids = append(ids, str)
}
}
condition := fmt.Sprintf("%v%v%v in ('%v')", orm.QuoteIdentifier, strings.ToLower(orm.PrimaryKey), orm.QuoteIdentifier, strings.Join(ids, "','"))
statement := fmt.Sprintf("DELETE FROM %v%v%v WHERE %v",
orm.QuoteIdentifier,
orm.TableName,
orm.QuoteIdentifier,
condition)
if OnDebug {
fmt.Println(statement)
fmt.Println(orm)
}
res, err := orm.Exec(statement)
if err != nil {
return -1, err
}
Affectid, err := res.RowsAffected()
if err != nil {
return -1, err
}
return Affectid, nil
}
func (orm *Model) DeleteRow() (int64, error) {
defer orm.InitModel()
var condition string
if orm.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", orm.WhereStr)
} else {
condition = ""
}
statement := fmt.Sprintf("DELETE FROM %v%v%v %v",
orm.QuoteIdentifier,
orm.TableName,
orm.QuoteIdentifier,
condition)
if OnDebug {
fmt.Println(statement)
fmt.Println(orm)
}
res, err := orm.Exec(statement, orm.ParamStr...)
if err != nil {
return -1, err
}
Affectid, err := res.RowsAffected()
if err != nil {
return -1, err
}
return Affectid, nil
}
func (orm *Model) InitModel() {
orm.TableName = ""
orm.LimitStr = 0
orm.OffsetStr = 0
orm.WhereStr = ""
orm.ParamStr = make([]interface{}, 0)
orm.OrderStr = ""
orm.ColumnStr = "*"
orm.PrimaryKey = "id"
orm.JoinStr = ""
orm.GroupByStr = ""
orm.HavingStr = ""
orm.ParamIteration = 1
}
package beedb
import (
"errors"
"reflect"
"strconv"
"strings"
"time"
)
func getTypeName(obj interface{}) (typestr string) {
typ := reflect.TypeOf(obj)
typestr = typ.String()
lastDotIndex := strings.LastIndex(typestr, ".")
if lastDotIndex != -1 {
typestr = typestr[lastDotIndex+1:]
}
return
}
func snakeCasedName(name string) string {
newstr := make([]rune, 0)
firstTime := true
for _, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if firstTime == true {
firstTime = false
} else {
newstr = append(newstr, '_')
}
chr -= ('A' - 'a')
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func titleCasedName(name string) string {
newstr := make([]rune, 0)
upNextChar := true
for _, chr := range name {
switch {
case upNextChar:
upNextChar = false
chr -= ('a' - 'A')
case chr == '_':
upNextChar = true
continue
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func pluralizeString(str string) string {
if strings.HasSuffix(str, "data") {
return str
}
if strings.HasSuffix(str, "y") {
str = str[:len(str)-1] + "ie"
}
return str + "s"
}
func scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct {
return errors.New("expected a pointer to a struct")
}
dataStructType := dataStruct.Type()
for i := 0; i < dataStructType.NumField(); i++ {
field := dataStructType.Field(i)
fieldv := dataStruct.Field(i)
err := scanMapElement(fieldv, field, objMap)
if err != nil {
return err
}
}
return nil
}
func scanMapElement(fieldv reflect.Value, field reflect.StructField, objMap map[string][]byte) error {
objFieldName := field.Name
bb := field.Tag
sqlTag := bb.Get("sql")
if bb.Get("beedb") == "-" || sqlTag == "-" || reflect.ValueOf(bb).String() == "-" {
return nil
}
sqlTags := strings.Split(sqlTag, ",")
sqlFieldName := objFieldName
if len(sqlTags[0]) > 0 {
sqlFieldName = sqlTags[0]
}
inline := false
//omitempty := false //TODO!
// CHECK INLINE
if len(sqlTags) > 1 {
if stringArrayContains("inline", sqlTags[1:]) {
inline = true
}
}
if inline {
if field.Type.Kind() == reflect.Struct && field.Type.String() != "time.Time" {
for i := 0; i < field.Type.NumField(); i++ {
err := scanMapElement(fieldv.Field(i), field.Type.Field(i), objMap)
if err != nil {
return err
}
}
} else {
return errors.New("A non struct type can't be inline.")
}
}
// not inline
data, ok := objMap[sqlFieldName]
if !ok {
return nil
}
var v interface{}
switch field.Type.Kind() {
case reflect.Slice:
v = data
case reflect.String:
v = string(data)
case reflect.Bool:
v = string(data) == "1"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
x, err := strconv.Atoi(string(data))
if err != nil {
return errors.New("arg " + sqlFieldName + " as int: " + err.Error())
}
v = x
case reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + sqlFieldName + " as int: " + err.Error())
}
v = x
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + sqlFieldName + " as float64: " + err.Error())
}
v = x
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + sqlFieldName + " as int: " + err.Error())
}
v = x
//Supports Time type only (for now)
case reflect.Struct:
if fieldv.Type().String() != "time.Time" {
return errors.New("unsupported struct type in Scan: " + fieldv.Type().String())
}
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
}
fieldv.Set(reflect.ValueOf(v))
return nil
}
func scanStructIntoMap(obj interface{}) (map[string]interface{}, error) {
dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct {
return nil, errors.New("expected a pointer to a struct")
}
dataStructType := dataStruct.Type()
mapped := make(map[string]interface{})
for i := 0; i < dataStructType.NumField(); i++ {
field := dataStructType.Field(i)
fieldv := dataStruct.Field(i)
fieldName := field.Name
bb := field.Tag
sqlTag := bb.Get("sql")
sqlTags := strings.Split(sqlTag, ",")
var mapKey string
inline := false
if bb.Get("beedb") == "-" || sqlTag == "-" || reflect.ValueOf(bb).String() == "-" {
continue
} else if len(sqlTag) > 0 {
//TODO: support tags that are common in json like omitempty
if sqlTags[0] == "-" {
continue
}
mapKey = sqlTags[0]
} else {
mapKey = fieldName
}
if len(sqlTags) > 1 {
if stringArrayContains("inline", sqlTags[1:]) {
inline = true
}
}
if inline {
// get an inner map and then put it inside the outer map
map2, err2 := scanStructIntoMap(fieldv.Interface())
if err2 != nil {
return mapped, err2
}
for k, v := range map2 {
mapped[k] = v
}
} else {
value := dataStruct.FieldByName(fieldName).Interface()
mapped[mapKey] = value
}
}
return mapped, nil
}
func StructName(s interface{}) string {
v := reflect.TypeOf(s)
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
return v.Name()
}
func getTableName(s interface{}) string {
v := reflect.TypeOf(s)
if v.Kind() == reflect.String {
s2, _ := s.(string)
return snakeCasedName(s2)
}
tn := scanTableName(s)
if len(tn) > 0 {
return tn
}
return getTableName(StructName(s))
}
func scanTableName(s interface{}) string {
if reflect.TypeOf(reflect.Indirect(reflect.ValueOf(s)).Interface()).Kind() == reflect.Slice {
sliceValue := reflect.Indirect(reflect.ValueOf(s))
sliceElementType := sliceValue.Type().Elem()
for i := 0; i < sliceElementType.NumField(); i++ {
bb := sliceElementType.Field(i).Tag
if len(bb.Get("tname")) > 0 {
return bb.Get("tname")
}
}
} else {
tt := reflect.TypeOf(reflect.Indirect(reflect.ValueOf(s)).Interface())
for i := 0; i < tt.NumField(); i++ {
bb := tt.Field(i).Tag
if len(bb.Get("tname")) > 0 {
return bb.Get("tname")
}
}
}
return ""
}
func stringArrayContains(needle string, haystack []string) bool {
for _, v := range haystack {
if needle == v {
return true
}
}
return false
}
package beedb
import (
"testing"
"time"
)
type User struct {
SQLModel `sql:",inline"`
Name string `sql:"name" tname:"fn_group"`
Auth int `sql:"auth"`
}
type SQLModel struct {
Id int `beedb:"PK" sql:"id"`
Created time.Time `sql:"created"`
Modified time.Time `sql:"modified"`
}
func TestMapToStruct(t *testing.T) {
target := &User{}
input := map[string][]byte{
"name": []byte("Test User"),
"auth": []byte("1"),
"id": []byte("1"),
"created": []byte("2014-01-01 10:10:10"),
"modified": []byte("2014-01-01 10:10:10"),
}
err := scanMapIntoStruct(target, input)
if err != nil {
t.Errorf(err.Error())
}
_, err = scanStructIntoMap(target)
if err != nil {
t.Errorf(err.Error())
}
}
_test
_testmain.go
_obj
*~
*.6
6.out
gorptest.bin
tmp
language: go
go:
- 1.1
- tip
services:
- mysql
- postgres
- sqlite3
before_script:
- mysql -e "CREATE DATABASE gorptest;"
- mysql -u root -e "GRANT ALL ON gorptest.* TO gorptest@localhost IDENTIFIED BY 'gorptest'"
- psql -c "CREATE DATABASE gorptest;" -U postgres
- psql -c "CREATE USER "gorptest" WITH SUPERUSER PASSWORD 'gorptest';" -U postgres
- go get github.com/lib/pq
- go get github.com/mattn/go-sqlite3
- go get github.com/ziutek/mymysql/godrv
- go get github.com/go-sql-driver/mysql
script: ./test_all.sh
package gorp
import (
"errors"
"fmt"
"reflect"
"strings"
)
// The Dialect interface encapsulates behaviors that differ across
// SQL databases. At present the Dialect is only used by CreateTables()
// but this could change in the future
type Dialect interface {
// adds a suffix to any query, usually ";"
QuerySuffix() string
// ToSqlType returns the SQL column type to use when creating a
// table of the given Go Type. maxsize can be used to switch based on
// size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB,
// or LONGBLOB depending on the maxsize
ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string
// string to append to primary key column definitions
AutoIncrStr() string
// string to bind autoincrement columns to. Empty string will
// remove reference to those columns in the INSERT statement.
AutoIncrBindValue() string
AutoIncrInsertSuffix(col *ColumnMap) string
// string to append to "create table" statement for vendor specific
// table attributes
CreateTableSuffix() string
// string to truncate tables
TruncateClause() string
// bind variable string to use when forming SQL statements
// in many dbs it is "?", but Postgres appears to use $1
//
// i is a zero based index of the bind variable in this statement
//
BindVar(i int) string
// Handles quoting of a field name to ensure that it doesn't raise any
// SQL parsing exceptions by using a reserved word as a field name.
QuoteField(field string) string
// Handles building up of a schema.database string that is compatible with
// the given dialect
//
// schema - The schema that <table> lives in
// table - The table name
QuotedTableForQuery(schema string, table string) string
// Existance clause for table creation / deletion
IfSchemaNotExists(command, schema string) string
IfTableExists(command, schema, table string) string
IfTableNotExists(command, schema, table string) string
}
// IntegerAutoIncrInserter is implemented by dialects that can perform
// inserts with automatically incremented integer primary keys. If
// the dialect can handle automatic assignment of more than just
// integers, see TargetedAutoIncrInserter.
type IntegerAutoIncrInserter interface {
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)
}
// TargetedAutoIncrInserter is implemented by dialects that can
// perform automatic assignment of any primary key type (i.e. strings
// for uuids, integers for serials, etc).
type TargetedAutoIncrInserter interface {
// InsertAutoIncrToTarget runs an insert operation and assigns the
// automatically generated primary key directly to the passed in
// target. The target should be a pointer to the primary key
// field of the value being inserted.
InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error
}
func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
res, err := exec.Exec(insertSql, params...)
if err != nil {
return 0, err
}
return res.LastInsertId()
}
///////////////////////////////////////////////////////
// sqlite3 //
/////////////
type SqliteDialect struct {
suffix string
}
func (d SqliteDialect) QuerySuffix() string { return ";" }
func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "integer"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "integer"
case reflect.Float64, reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "blob"
}
}
switch val.Name() {
case "NullInt64":
return "integer"
case "NullFloat64":
return "real"
case "NullBool":
return "integer"
case "Time":
return "datetime"
}
if maxsize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// Returns autoincrement
func (d SqliteDialect) AutoIncrStr() string {
return "autoincrement"
}
func (d SqliteDialect) AutoIncrBindValue() string {
return "null"
}
func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns suffix
func (d SqliteDialect) CreateTableSuffix() string {
return d.suffix
}
// With sqlite, there technically isn't a TRUNCATE statement,
// but a DELETE FROM uses a truncate optimization:
// http://www.sqlite.org/lang_delete.html
func (d SqliteDialect) TruncateClause() string {
return "delete from"
}
// Returns "?"
func (d SqliteDialect) BindVar(i int) string {
return "?"
}
func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
func (d SqliteDialect) QuoteField(f string) string {
return `"` + f + `"`
}
// sqlite does not have schemas like PostgreSQL does, so just escape it like normal
func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string {
return d.QuoteField(table)
}
func (d SqliteDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d SqliteDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d SqliteDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
///////////////////////////////////////////////////////
// PostgreSQL //
////////////////
type PostgresDialect struct {
suffix string
}
func (d PostgresDialect) QuerySuffix() string { return ";" }
func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "Time":
return "timestamp with time zone"
}
if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}
}
// Returns empty string
func (d PostgresDialect) AutoIncrStr() string {
return ""
}
func (d PostgresDialect) AutoIncrBindValue() string {
return "default"
}
func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}
// Returns suffix
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
}
func (d PostgresDialect) TruncateClause() string {
return "truncate"
}
// Returns "$(i+1)"
func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}
func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
rows, err := exec.query(insertSql, params...)
if err != nil {
return err
}
defer rows.Close()
if rows.Next() {
err := rows.Scan(target)
return err
}
return errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
}
func (d PostgresDialect) QuoteField(f string) string {
return `"` + strings.ToLower(f) + `"`
}
func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d PostgresDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d PostgresDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
///////////////////////////////////////////////////////
// MySQL //
///////////
// Implementation of Dialect for MySQL databases.
type MySQLDialect struct {
// Engine is the storage engine to use "InnoDB" vs "MyISAM" for example
Engine string
// Encoding is the character encoding to use for created tables
Encoding string
}
func (d MySQLDialect) QuerySuffix() string { return ";" }
func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int8:
return "tinyint"
case reflect.Uint8:
return "tinyint unsigned"
case reflect.Int16:
return "smallint"
case reflect.Uint16:
return "smallint unsigned"
case reflect.Int, reflect.Int32:
return "int"
case reflect.Uint, reflect.Uint32:
return "int unsigned"
case reflect.Int64:
return "bigint"
case reflect.Uint64:
return "bigint unsigned"
case reflect.Float64, reflect.Float32:
return "double"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "mediumblob"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double"
case "NullBool":
return "tinyint"
case "Time":
return "datetime"
}
if maxsize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// Returns auto_increment
func (d MySQLDialect) AutoIncrStr() string {
return "auto_increment"
}
func (d MySQLDialect) AutoIncrBindValue() string {
return "null"
}
func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns engine=%s charset=%s based on values stored on struct
func (d MySQLDialect) CreateTableSuffix() string {
if d.Engine == "" || d.Encoding == "" {
msg := "gorp - undefined"
if d.Engine == "" {
msg += " MySQLDialect.Engine"
}
if d.Engine == "" && d.Encoding == "" {
msg += ","
}
if d.Encoding == "" {
msg += " MySQLDialect.Encoding"
}
msg += ". Check that your MySQLDialect was correctly initialized when declared."
panic(msg)
}
return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
}
func (d MySQLDialect) TruncateClause() string {
return "truncate"
}
// Returns "?"
func (d MySQLDialect) BindVar(i int) string {
return "?"
}
func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
func (d MySQLDialect) QuoteField(f string) string {
return "`" + f + "`"
}
func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d MySQLDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d MySQLDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d MySQLDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
///////////////////////////////////////////////////////
// Sql Server //
////////////////
// Implementation of Dialect for Microsoft SQL Server databases.
// Tested on SQL Server 2008 with driver: github.com/denisenkom/go-mssqldb
type SqlServerDialect struct {
suffix string
}
func (d SqlServerDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "bit"
case reflect.Int8:
return "tinyint"
case reflect.Uint8:
return "smallint"
case reflect.Int16:
return "smallint"
case reflect.Uint16:
return "int"
case reflect.Int, reflect.Int32:
return "int"
case reflect.Uint, reflect.Uint32:
return "bigint"
case reflect.Int64:
return "bigint"
case reflect.Uint64:
return "bigint"
case reflect.Float32:
return "real"
case reflect.Float64:
return "float(53)"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "varbinary"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "float(53)"
case "NullBool":
return "tinyint"
case "Time":
return "datetime"
}
if maxsize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// Returns auto_increment
func (d SqlServerDialect) AutoIncrStr() string {
return "identity(0,1)"
}
// Empty string removes autoincrement columns from the INSERT statements.
func (d SqlServerDialect) AutoIncrBindValue() string {
return ""
}
func (d SqlServerDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns suffix
func (d SqlServerDialect) CreateTableSuffix() string {
return d.suffix
}
func (d SqlServerDialect) TruncateClause() string {
return "delete from"
}
// Returns "?"
func (d SqlServerDialect) BindVar(i int) string {
return "?"
}
func (d SqlServerDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
func (d SqlServerDialect) QuoteField(f string) string {
return `"` + f + `"`
}
func (d SqlServerDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return table
}
return schema + "." + table
}
func (d SqlServerDialect) QuerySuffix() string { return ";" }
func (d SqlServerDialect) IfSchemaNotExists(command, schema string) string {
s := fmt.Sprintf("if not exists (select name from sys.schemas where name = '%s') %s", schema, command)
return s
}
func (d SqlServerDialect) IfTableExists(command, schema, table string) string {
var schema_clause string
if strings.TrimSpace(schema) != "" {
schema_clause = fmt.Sprintf("table_schema = '%s' and ", schema)
}
s := fmt.Sprintf("if exists (select * from information_schema.tables where %stable_name = '%s') %s", schema_clause, table, command)
return s
}
func (d SqlServerDialect) IfTableNotExists(command, schema, table string) string {
var schema_clause string
if strings.TrimSpace(schema) != "" {
schema_clause = fmt.Sprintf("table_schema = '%s' and ", schema)
}
s := fmt.Sprintf("if not exists (select * from information_schema.tables where %stable_name = '%s') %s", schema_clause, table, command)
return s
}
///////////////////////////////////////////////////////
// Oracle //
///////////
// Implementation of Dialect for Oracle databases.
type OracleDialect struct{}
func (d OracleDialect) QuerySuffix() string { return "" }
func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "NullTime", "Time":
return "timestamp with time zone"
}
if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}
}
// Returns empty string
func (d OracleDialect) AutoIncrStr() string {
return ""
}
func (d OracleDialect) AutoIncrBindValue() string {
return "default"
}
func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + col.ColumnName
}
// Returns suffix
func (d OracleDialect) CreateTableSuffix() string {
return ""
}
func (d OracleDialect) TruncateClause() string {
return "truncate"
}
// Returns "$(i+1)"
func (d OracleDialect) BindVar(i int) string {
return fmt.Sprintf(":%d", i+1)
}
func (d OracleDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
rows, err := exec.query(insertSql, params...)
if err != nil {
return 0, err
}
defer rows.Close()
if rows.Next() {
var id int64
err := rows.Scan(&id)
return id, err
}
return 0, errors.New("No serial value returned for insert: " + insertSql + " Encountered error: " + rows.Err().Error())
}
func (d OracleDialect) QuoteField(f string) string {
return `"` + strings.ToUpper(f) + `"`
}
func (d OracleDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d OracleDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d OracleDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d OracleDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
package gorp
import (
"fmt"
)
// A non-fatal error, when a select query returns columns that do not exist
// as fields in the struct it is being mapped to
type NoFieldInTypeError struct {
TypeName string
MissingColNames []string
}
func (err *NoFieldInTypeError) Error() string {
return fmt.Sprintf("gorp: No fields %+v in type %s", err.MissingColNames, err.TypeName)
}
// returns true if the error is non-fatal (ie, we shouldn't immediately return)
func NonFatalError(err error) bool {
switch err.(type) {
case *NoFieldInTypeError:
return true
default:
return false
}
}
// Copyright 2012 James Cooper. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
// Package gorp provides a simple way to marshal Go structs to and from
// SQL databases. It uses the database/sql package, and should work with any
// compliant database/sql driver.
//
// Source code and project home:
// https://github.com/coopernurse/gorp
//
package gorp
import (
"bytes"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
// Oracle String (empty string is null)
type OracleString struct {
sql.NullString
}
// Scan implements the Scanner interface.
func (os *OracleString) Scan(value interface{}) error {
if value == nil {
os.String, os.Valid = "", false
return nil
}
os.Valid = true
return os.NullString.Scan(value)
}
// Value implements the driver Valuer interface.
func (os OracleString) Value() (driver.Value, error) {
if !os.Valid || os.String == "" {
return nil, nil
}
return os.String, nil
}
// A nullable Time value
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
var zeroVal reflect.Value
var versFieldConst = "[gorp_ver_field]"
// OptimisticLockError is returned by Update() or Delete() if the
// struct being modified has a Version field and the value is not equal to
// the current value in the database
type OptimisticLockError struct {
// Table name where the lock error occurred
TableName string
// Primary key values of the row being updated/deleted
Keys []interface{}
// true if a row was found with those keys, indicating the
// LocalVersion is stale. false if no value was found with those
// keys, suggesting the row has been deleted since loaded, or
// was never inserted to begin with
RowExists bool
// Version value on the struct passed to Update/Delete. This value is
// out of sync with the database.
LocalVersion int64
}
// Error returns a description of the cause of the lock error
func (e OptimisticLockError) Error() string {
if e.RowExists {
return fmt.Sprintf("gorp: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion)
}
return fmt.Sprintf("gorp: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys)
}
// The TypeConverter interface provides a way to map a value of one
// type to another type when persisting to, or loading from, a database.
//
// Example use cases: Implement type converter to convert bool types to "y"/"n" strings,
// or serialize a struct member as a JSON blob.
type TypeConverter interface {
// ToDb converts val to another type. Called before INSERT/UPDATE operations
ToDb(val interface{}) (interface{}, error)
// FromDb returns a CustomScanner appropriate for this type. This will be used
// to hold values returned from SELECT queries.
//
// In particular the CustomScanner returned should implement a Binder
// function appropriate for the Go type you wish to convert the db value to
//
// If bool==false, then no custom scanner will be used for this field.
FromDb(target interface{}) (CustomScanner, bool)
}
// CustomScanner binds a database column value to a Go type
type CustomScanner struct {
// After a row is scanned, Holder will contain the value from the database column.
// Initialize the CustomScanner with the concrete Go type you wish the database
// driver to scan the raw column into.
Holder interface{}
// Target typically holds a pointer to the target struct field to bind the Holder
// value to.
Target interface{}
// Binder is a custom function that converts the holder value to the target type
// and sets target accordingly. This function should return error if a problem
// occurs converting the holder to the target.
Binder func(holder interface{}, target interface{}) error
}
// Bind is called automatically by gorp after Scan()
func (me CustomScanner) Bind() error {
return me.Binder(me.Holder, me.Target)
}
// DbMap is the root gorp mapping object. Create one of these for each
// database schema you wish to map. Each DbMap contains a list of
// mapped tables.
//
// Example:
//
// dialect := gorp.MySQLDialect{"InnoDB", "UTF8"}
// dbmap := &gorp.DbMap{Db: db, Dialect: dialect}
//
type DbMap struct {
// Db handle to use with this map
Db *sql.DB
// Dialect implementation to use with this map
Dialect Dialect
TypeConverter TypeConverter
tables []*TableMap
logger GorpLogger
logPrefix string
}
// TableMap represents a mapping between a Go struct and a database table
// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
type TableMap struct {
// Name of database table.
TableName string
SchemaName string
gotype reflect.Type
Columns []*ColumnMap
keys []*ColumnMap
uniqueTogether [][]string
version *ColumnMap
insertPlan bindPlan
updatePlan bindPlan
deletePlan bindPlan
getPlan bindPlan
dbmap *DbMap
}
// ResetSql removes cached insert/update/select/delete SQL strings
// associated with this TableMap. Call this if you've modified
// any column names or the table name itself.
func (t *TableMap) ResetSql() {
t.insertPlan = bindPlan{}
t.updatePlan = bindPlan{}
t.deletePlan = bindPlan{}
t.getPlan = bindPlan{}
}
// SetKeys lets you specify the fields on a struct that map to primary
// key columns on the table. If isAutoIncr is set, result.LastInsertId()
// will be used after INSERT to bind the generated id to the Go struct.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
//
// Panics if isAutoIncr is true, and fieldNames length != 1
//
func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
if isAutoIncr && len(fieldNames) != 1 {
panic(fmt.Sprintf(
"gorp: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
len(fieldNames)))
}
t.keys = make([]*ColumnMap, 0)
for _, name := range fieldNames {
colmap := t.ColMap(name)
colmap.isPK = true
colmap.isAutoIncr = isAutoIncr
t.keys = append(t.keys, colmap)
}
t.ResetSql()
return t
}
// SetUniqueTogether lets you specify uniqueness constraints across multiple
// columns on the table. Each call adds an additional constraint for the
// specified columns.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
//
// Panics if fieldNames length < 2.
//
func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
if len(fieldNames) < 2 {
panic(fmt.Sprintf(
"gorp: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint."))
}
columns := make([]string, 0)
for _, name := range fieldNames {
columns = append(columns, name)
}
t.uniqueTogether = append(t.uniqueTogether, columns)
t.ResetSql()
return t
}
// ColMap returns the ColumnMap pointer matching the given struct field
// name. It panics if the struct does not contain a field matching this
// name.
func (t *TableMap) ColMap(field string) *ColumnMap {
col := colMapOrNil(t, field)
if col == nil {
e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
t.TableName, t.gotype.Name(), field)
panic(e)
}
return col
}
func colMapOrNil(t *TableMap, field string) *ColumnMap {
for _, col := range t.Columns {
if col.fieldName == field || col.ColumnName == field {
return col
}
}
return nil
}
// SetVersionCol sets the column to use as the Version field. By default
// the "Version" field is used. Returns the column found, or panics
// if the struct does not contain a field matching this name.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
func (t *TableMap) SetVersionCol(field string) *ColumnMap {
c := t.ColMap(field)
t.version = c
t.ResetSql()
return c
}
type bindPlan struct {
query string
argFields []string
keyFields []string
versField string
autoIncrIdx int
autoIncrFieldName string
}
func (plan bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) {
bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField}
if plan.versField != "" {
bi.existingVersion = elem.FieldByName(plan.versField).Int()
}
var err error
for i := 0; i < len(plan.argFields); i++ {
k := plan.argFields[i]
if k == versFieldConst {
newVer := bi.existingVersion + 1
bi.args = append(bi.args, newVer)
if bi.existingVersion == 0 {
elem.FieldByName(plan.versField).SetInt(int64(newVer))
}
} else {
val := elem.FieldByName(k).Interface()
if conv != nil {
val, err = conv.ToDb(val)
if err != nil {
return bindInstance{}, err
}
}
bi.args = append(bi.args, val)
}
}
for i := 0; i < len(plan.keyFields); i++ {
k := plan.keyFields[i]
val := elem.FieldByName(k).Interface()
if conv != nil {
val, err = conv.ToDb(val)
if err != nil {
return bindInstance{}, err
}
}
bi.keys = append(bi.keys, val)
}
return bi, nil
}
type bindInstance struct {
query string
args []interface{}
keys []interface{}
existingVersion int64
versField string
autoIncrIdx int
autoIncrFieldName string
}
func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
plan := t.insertPlan
if plan.query == "" {
plan.autoIncrIdx = -1
s := bytes.Buffer{}
s2 := bytes.Buffer{}
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
x := 0
first := true
for y := range t.Columns {
col := t.Columns[y]
if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") {
if !col.Transient {
if !first {
s.WriteString(",")
s2.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
if col.isAutoIncr {
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
plan.autoIncrIdx = y
plan.autoIncrFieldName = col.fieldName
} else {
s2.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
}
first = false
}
} else {
plan.autoIncrIdx = y
plan.autoIncrFieldName = col.fieldName
}
}
s.WriteString(") values (")
s.WriteString(s2.String())
s.WriteString(")")
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
t.insertPlan = plan
}
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindUpdate(elem reflect.Value) (bindInstance, error) {
plan := t.updatePlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
x := 0
for y := range t.Columns {
col := t.Columns[y]
if !col.isAutoIncr && !col.Transient {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
}
}
s.WriteString(" where ")
for y := range t.keys {
col := t.keys[y]
if y > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, col.fieldName)
plan.keyFields = append(plan.keyFields, col.fieldName)
x++
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
t.updatePlan = plan
}
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) {
plan := t.deletePlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
for y := range t.Columns {
col := t.Columns[y]
if !col.Transient {
if col == t.version {
plan.versField = col.fieldName
}
}
}
s.WriteString(" where ")
for x := range t.keys {
k := t.keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, k.fieldName)
plan.argFields = append(plan.argFields, k.fieldName)
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields)))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
t.deletePlan = plan
}
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindGet() bindPlan {
plan := t.getPlan
if plan.query == "" {
s := bytes.Buffer{}
s.WriteString("select ")
x := 0
for _, col := range t.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
plan.argFields = append(plan.argFields, col.fieldName)
x++
}
}
s.WriteString(" from ")
s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))
s.WriteString(" where ")
for x := range t.keys {
col := t.keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, col.fieldName)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
t.getPlan = plan
}
return plan
}
// ColumnMap represents a mapping between a Go struct field and a single
// column in a table.
// Unique and MaxSize only inform the
// CreateTables() function and are not used by Insert/Update/Delete/Get.
type ColumnMap struct {
// Column name in db table
ColumnName string
// If true, this column is skipped in generated SQL statements
Transient bool
// If true, " unique" is added to create table statements.
// Not used elsewhere
Unique bool
// Passed to Dialect.ToSqlType() to assist in informing the
// correct column type to map to in CreateTables()
// Not used elsewhere
MaxSize int
fieldName string
gotype reflect.Type
isPK bool
isAutoIncr bool
isNotNull bool
}
// Rename allows you to specify the column name in the table
//
// Example: table.ColMap("Updated").Rename("date_updated")
//
func (c *ColumnMap) Rename(colname string) *ColumnMap {
c.ColumnName = colname
return c
}
// SetTransient allows you to mark the column as transient. If true
// this column will be skipped when SQL statements are generated
func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
c.Transient = b
return c
}
// SetUnique adds "unique" to the create table statements for this
// column, if b is true.
func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
c.Unique = b
return c
}
// SetNotNull adds "not null" to the create table statements for this
// column, if nn is true.
func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
c.isNotNull = nn
return c
}
// SetMaxSize specifies the max length of values of this column. This is
// passed to the dialect.ToSqlType() function, which can use the value
// to alter the generated type for "create table" statements
func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
c.MaxSize = size
return c
}
// Transaction represents a database transaction.
// Insert/Update/Delete/Get/Exec operations will be run in the context
// of that transaction. Transactions should be terminated with
// a call to Commit() or Rollback()
type Transaction struct {
dbmap *DbMap
tx *sql.Tx
closed bool
}
// SqlExecutor exposes gorp operations that can be run from Pre/Post
// hooks. This hides whether the current operation that triggered the
// hook is in a transaction.
//
// See the DbMap function docs for each of the functions below for more
// information.
type SqlExecutor interface {
Get(i interface{}, keys ...interface{}) (interface{}, error)
Insert(list ...interface{}) error
Update(list ...interface{}) (int64, error)
Delete(list ...interface{}) (int64, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Select(i interface{}, query string,
args ...interface{}) ([]interface{}, error)
SelectInt(query string, args ...interface{}) (int64, error)
SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
SelectFloat(query string, args ...interface{}) (float64, error)
SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
SelectStr(query string, args ...interface{}) (string, error)
SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
SelectOne(holder interface{}, query string, args ...interface{}) error
query(query string, args ...interface{}) (*sql.Rows, error)
queryRow(query string, args ...interface{}) *sql.Row
}
// Compile-time check that DbMap and Transaction implement the SqlExecutor
// interface.
var _, _ SqlExecutor = &DbMap{}, &Transaction{}
type GorpLogger interface {
Printf(format string, v ...interface{})
}
// TraceOn turns on SQL statement logging for this DbMap. After this is
// called, all SQL statements will be sent to the logger. If prefix is
// a non-empty string, it will be written to the front of all logged
// strings, which can aid in filtering log lines.
//
// Use TraceOn if you want to spy on the SQL statements that gorp
// generates.
//
// Note that the base log.Logger type satisfies GorpLogger, but adapters can
// easily be written for other logging packages (e.g., the golang-sanctioned
// glog framework).
func (m *DbMap) TraceOn(prefix string, logger GorpLogger) {
m.logger = logger
if prefix == "" {
m.logPrefix = prefix
} else {
m.logPrefix = fmt.Sprintf("%s ", prefix)
}
}
// TraceOff turns off tracing. It is idempotent.
func (m *DbMap) TraceOff() {
m.logger = nil
m.logPrefix = ""
}
// AddTable registers the given interface type with gorp. The table name
// will be given the name of the TypeOf(i). You must call this function,
// or AddTableWithName, for any struct type you wish to persist with
// the given DbMap.
//
// This operation is idempotent. If i's type is already mapped, the
// existing *TableMap is returned
func (m *DbMap) AddTable(i interface{}) *TableMap {
return m.AddTableWithName(i, "")
}
// AddTableWithName has the same behavior as AddTable, but sets
// table.TableName to name.
func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap {
return m.AddTableWithNameAndSchema(i, "", name)
}
// AddTableWithNameAndSchema has the same behavior as AddTable, but sets
// table.TableName to name.
func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap {
t := reflect.TypeOf(i)
if name == "" {
name = t.Name()
}
// check if we have a table for this type already
// if so, update the name and return the existing pointer
for i := range m.tables {
table := m.tables[i]
if table.gotype == t {
table.TableName = name
return table
}
}
tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m}
tmap.Columns, tmap.version = m.readStructColumns(t)
m.tables = append(m.tables, tmap)
return tmap
}
func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap, version *ColumnMap) {
n := t.NumField()
for i := 0; i < n; i++ {
f := t.Field(i)
if f.Anonymous && f.Type.Kind() == reflect.Struct {
// Recursively add nested fields in embedded structs.
subcols, subversion := m.readStructColumns(f.Type)
// Don't append nested fields that have the same field
// name as an already-mapped field.
for _, subcol := range subcols {
shouldAppend := true
for _, col := range cols {
if !subcol.Transient && subcol.fieldName == col.fieldName {
shouldAppend = false
break
}
}
if shouldAppend {
cols = append(cols, subcol)
}
}
if subversion != nil {
version = subversion
}
} else {
columnName := f.Tag.Get("db")
if columnName == "" {
columnName = f.Name
}
gotype := f.Type
if m.TypeConverter != nil {
// Make a new pointer to a value of type gotype and
// pass it to the TypeConverter's FromDb method to see
// if a different type should be used for the column
// type during table creation.
value := reflect.New(gotype).Interface()
scanner, useHolder := m.TypeConverter.FromDb(value)
if useHolder {
gotype = reflect.TypeOf(scanner.Holder)
}
}
cm := &ColumnMap{
ColumnName: columnName,
Transient: columnName == "-",
fieldName: f.Name,
gotype: gotype,
}
// Check for nested fields of the same field name and
// override them.
shouldAppend := true
for index, col := range cols {
if !col.Transient && col.fieldName == cm.fieldName {
cols[index] = cm
shouldAppend = false
break
}
}
if shouldAppend {
cols = append(cols, cm)
}
if cm.fieldName == "Version" {
version = cm
}
}
}
return
}
// CreateTables iterates through TableMaps registered to this DbMap and
// executes "create table" statements against the database for each.
//
// This is particularly useful in unit tests where you want to create
// and destroy the schema automatically.
func (m *DbMap) CreateTables() error {
return m.createTables(false)
}
// CreateTablesIfNotExists is similar to CreateTables, but starts
// each statement with "create table if not exists" so that existing
// tables do not raise errors
func (m *DbMap) CreateTablesIfNotExists() error {
return m.createTables(true)
}
func (m *DbMap) createTables(ifNotExists bool) error {
var err error
for i := range m.tables {
table := m.tables[i]
s := bytes.Buffer{}
if strings.TrimSpace(table.SchemaName) != "" {
schemaCreate := "create schema"
if ifNotExists {
s.WriteString(m.Dialect.IfSchemaNotExists(schemaCreate, table.SchemaName))
} else {
s.WriteString(schemaCreate)
}
s.WriteString(fmt.Sprintf(" %s;", table.SchemaName))
}
tableCreate := "create table"
if ifNotExists {
s.WriteString(m.Dialect.IfTableNotExists(tableCreate, table.SchemaName, table.TableName))
} else {
s.WriteString(tableCreate)
}
s.WriteString(fmt.Sprintf(" %s (", m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
x := 0
for _, col := range table.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(", ")
}
stype := m.Dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
s.WriteString(fmt.Sprintf("%s %s", m.Dialect.QuoteField(col.ColumnName), stype))
if col.isPK || col.isNotNull {
s.WriteString(" not null")
}
if col.isPK && len(table.keys) == 1 {
s.WriteString(" primary key")
}
if col.Unique {
s.WriteString(" unique")
}
if col.isAutoIncr {
s.WriteString(fmt.Sprintf(" %s", m.Dialect.AutoIncrStr()))
}
x++
}
}
if len(table.keys) > 1 {
s.WriteString(", primary key (")
for x := range table.keys {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(m.Dialect.QuoteField(table.keys[x].ColumnName))
}
s.WriteString(")")
}
if len(table.uniqueTogether) > 0 {
for _, columns := range table.uniqueTogether {
s.WriteString(", unique (")
for i, column := range columns {
if i > 0 {
s.WriteString(", ")
}
s.WriteString(m.Dialect.QuoteField(column))
}
s.WriteString(")")
}
}
s.WriteString(") ")
s.WriteString(m.Dialect.CreateTableSuffix())
s.WriteString(m.Dialect.QuerySuffix())
_, err = m.Exec(s.String())
if err != nil {
break
}
}
return err
}
// DropTable drops an individual table. Will throw an error
// if the table does not exist.
func (m *DbMap) DropTable(table interface{}) error {
t := reflect.TypeOf(table)
return m.dropTable(t, false)
}
// DropTable drops an individual table. Will NOT throw an error
// if the table does not exist.
func (m *DbMap) DropTableIfExists(table interface{}) error {
t := reflect.TypeOf(table)
return m.dropTable(t, true)
}
// DropTables iterates through TableMaps registered to this DbMap and
// executes "drop table" statements against the database for each.
func (m *DbMap) DropTables() error {
return m.dropTables(false)
}
// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to
// avoid errors for tables that do not exist.
func (m *DbMap) DropTablesIfExists() error {
return m.dropTables(true)
}
// Goes through all the registered tables, dropping them one by one.
// If an error is encountered, then it is returned and the rest of
// the tables are not dropped.
func (m *DbMap) dropTables(addIfExists bool) (err error) {
for _, table := range m.tables {
err = m.dropTableImpl(table, addIfExists)
if err != nil {
return
}
}
return err
}
// Implementation of dropping a single table.
func (m *DbMap) dropTable(t reflect.Type, addIfExists bool) error {
table := tableOrNil(m, t)
if table == nil {
return errors.New(fmt.Sprintf("table %s was not registered!", table.TableName))
}
return m.dropTableImpl(table, addIfExists)
}
func (m *DbMap) dropTableImpl(table *TableMap, ifExists bool) (err error) {
tableDrop := "drop table"
if ifExists {
tableDrop = m.Dialect.IfTableExists(tableDrop, table.SchemaName, table.TableName)
}
_, err = m.Exec(fmt.Sprintf("%s %s;", tableDrop, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
return err
}
// TruncateTables iterates through TableMaps registered to this DbMap and
// executes "truncate table" statements against the database for each, or in the case of
// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization
// (http://www.sqlite.org/lang_delete.html)
func (m *DbMap) TruncateTables() error {
var err error
for i := range m.tables {
table := m.tables[i]
_, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName)))
if e != nil {
err = e
}
}
return err
}
// Insert runs a SQL INSERT statement for each element in list. List
// items must be pointers.
//
// Any interface whose TableMap has an auto-increment primary key will
// have its last insert id bound to the PK field on the struct.
//
// The hook functions PreInsert() and/or PostInsert() will be executed
// before/after the INSERT statement if the interface defines them.
//
// Panics if any interface in the list has not been registered with AddTable
func (m *DbMap) Insert(list ...interface{}) error {
return insert(m, m, list...)
}
// Update runs a SQL UPDATE statement for each element in list. List
// items must be pointers.
//
// The hook functions PreUpdate() and/or PostUpdate() will be executed
// before/after the UPDATE statement if the interface defines them.
//
// Returns the number of rows updated.
//
// Returns an error if SetKeys has not been called on the TableMap
// Panics if any interface in the list has not been registered with AddTable
func (m *DbMap) Update(list ...interface{}) (int64, error) {
return update(m, m, list...)
}
// Delete runs a SQL DELETE statement for each element in list. List
// items must be pointers.
//
// The hook functions PreDelete() and/or PostDelete() will be executed
// before/after the DELETE statement if the interface defines them.
//
// Returns the number of rows deleted.
//
// Returns an error if SetKeys has not been called on the TableMap
// Panics if any interface in the list has not been registered with AddTable
func (m *DbMap) Delete(list ...interface{}) (int64, error) {
return delete(m, m, list...)
}
// Get runs a SQL SELECT to fetch a single row from the table based on the
// primary key(s)
//
// i should be an empty value for the struct to load. keys should be
// the primary key value(s) for the row to load. If multiple keys
// exist on the table, the order should match the column order
// specified in SetKeys() when the table mapping was defined.
//
// The hook function PostGet() will be executed after the SELECT
// statement if the interface defines them.
//
// Returns a pointer to a struct that matches or nil if no row is found.
//
// Returns an error if SetKeys has not been called on the TableMap
// Panics if any interface in the list has not been registered with AddTable
func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) {
return get(m, m, i, keys...)
}
// Select runs an arbitrary SQL query, binding the columns in the result
// to fields on the struct specified by i. args represent the bind
// parameters for the SQL statement.
//
// Column names on the SELECT statement should be aliased to the field names
// on the struct i. Returns an error if one or more columns in the result
// do not match. It is OK if fields on i are not part of the SQL
// statement.
//
// The hook function PostGet() will be executed after the SELECT
// statement if the interface defines them.
//
// Values are returned in one of two ways:
// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to
// matching rows of type i.
// 2. If i is a pointer to a slice, the results will be appended to that slice
// and nil returned.
//
// i does NOT need to be registered with AddTable()
func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return hookedselect(m, m, i, query, args...)
}
// Exec runs an arbitrary SQL statement. args represent the bind parameters.
// This is equivalent to running: Exec() using database/sql
func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) {
m.trace(query, args...)
return m.Db.Exec(query, args...)
}
// SelectInt is a convenience wrapper around the gorp.SelectInt function
func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) {
return SelectInt(m, query, args...)
}
// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function
func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
return SelectNullInt(m, query, args...)
}
// SelectFloat is a convenience wrapper around the gorp.SelectFlot function
func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) {
return SelectFloat(m, query, args...)
}
// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function
func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
return SelectNullFloat(m, query, args...)
}
// SelectStr is a convenience wrapper around the gorp.SelectStr function
func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) {
return SelectStr(m, query, args...)
}
// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function
func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return SelectNullStr(m, query, args...)
}
// SelectOne is a convenience wrapper around the gorp.SelectOne function
func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error {
return SelectOne(m, m, holder, query, args...)
}
// Begin starts a gorp Transaction
func (m *DbMap) Begin() (*Transaction, error) {
m.trace("begin;")
tx, err := m.Db.Begin()
if err != nil {
return nil, err
}
return &Transaction{m, tx, false}, nil
}
// TableFor returns the *TableMap corresponding to the given Go Type
// If no table is mapped to that type an error is returned.
// If checkPK is true and the mapped table has no registered PKs, an error is returned.
func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) {
table := tableOrNil(m, t)
if table == nil {
return nil, errors.New(fmt.Sprintf("No table found for type: %v", t.Name()))
}
if checkPK && len(table.keys) < 1 {
e := fmt.Sprintf("gorp: No keys defined for table: %s",
table.TableName)
return nil, errors.New(e)
}
return table, nil
}
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the returned statement.
// This is equivalent to running: Prepare() using database/sql
func (m *DbMap) Prepare(query string) (*sql.Stmt, error) {
m.trace(query, nil)
return m.Db.Prepare(query)
}
func tableOrNil(m *DbMap, t reflect.Type) *TableMap {
for i := range m.tables {
table := m.tables[i]
if table.gotype == t {
return table
}
}
return nil
}
func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) {
ptrv := reflect.ValueOf(ptr)
if ptrv.Kind() != reflect.Ptr {
e := fmt.Sprintf("gorp: passed non-pointer: %v (kind=%v)", ptr,
ptrv.Kind())
return nil, reflect.Value{}, errors.New(e)
}
elem := ptrv.Elem()
etype := reflect.TypeOf(elem.Interface())
t, err := m.TableFor(etype, checkPK)
if err != nil {
return nil, reflect.Value{}, err
}
return t, elem, nil
}
func (m *DbMap) queryRow(query string, args ...interface{}) *sql.Row {
m.trace(query, args...)
return m.Db.QueryRow(query, args...)
}
func (m *DbMap) query(query string, args ...interface{}) (*sql.Rows, error) {
m.trace(query, args...)
return m.Db.Query(query, args...)
}
func (m *DbMap) trace(query string, args ...interface{}) {
if m.logger != nil {
var margs = argsString(args...)
m.logger.Printf("%s%s [%s]", m.logPrefix, query, margs)
}
}
func argsString(args ...interface{}) string {
var margs string
for i, a := range args {
var v interface{} = a
if x, ok := v.(driver.Valuer); ok {
y, err := x.Value()
if err == nil {
v = y
}
}
switch v.(type) {
case string:
v = fmt.Sprintf("%q", v)
default:
v = fmt.Sprintf("%v", v)
}
margs += fmt.Sprintf("%d:%s", i+1, v)
if i+1 < len(args) {
margs += " "
}
}
return margs
}
///////////////
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
func (t *Transaction) Insert(list ...interface{}) error {
return insert(t.dbmap, t, list...)
}
// Update had the same behavior as DbMap.Update(), but runs in a transaction.
func (t *Transaction) Update(list ...interface{}) (int64, error) {
return update(t.dbmap, t, list...)
}
// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
return delete(t.dbmap, t, list...)
}
// Get has the same behavior as DbMap.Get(), but runs in a transaction.
func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) {
return get(t.dbmap, t, i, keys...)
}
// Select has the same behavior as DbMap.Select(), but runs in a transaction.
func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
return hookedselect(t.dbmap, t, i, query, args...)
}
// Exec has the same behavior as DbMap.Exec(), but runs in a transaction.
func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) {
t.dbmap.trace(query, args...)
return t.tx.Exec(query, args...)
}
// SelectInt is a convenience wrapper around the gorp.SelectInt function.
func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) {
return SelectInt(t, query, args...)
}
// SelectNullInt is a convenience wrapper around the gorp.SelectNullInt function.
func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
return SelectNullInt(t, query, args...)
}
// SelectFloat is a convenience wrapper around the gorp.SelectFloat function.
func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) {
return SelectFloat(t, query, args...)
}
// SelectNullFloat is a convenience wrapper around the gorp.SelectNullFloat function.
func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
return SelectNullFloat(t, query, args...)
}
// SelectStr is a convenience wrapper around the gorp.SelectStr function.
func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) {
return SelectStr(t, query, args...)
}
// SelectNullStr is a convenience wrapper around the gorp.SelectNullStr function.
func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return SelectNullStr(t, query, args...)
}
// SelectOne is a convenience wrapper around the gorp.SelectOne function.
func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error {
return SelectOne(t.dbmap, t, holder, query, args...)
}
// Commit commits the underlying database transaction.
func (t *Transaction) Commit() error {
if !t.closed {
t.closed = true
t.dbmap.trace("commit;")
return t.tx.Commit()
}
return sql.ErrTxDone
}
// Rollback rolls back the underlying database transaction.
func (t *Transaction) Rollback() error {
if !t.closed {
t.closed = true
t.dbmap.trace("rollback;")
return t.tx.Rollback()
}
return sql.ErrTxDone
}
// Savepoint creates a savepoint with the given name. The name is interpolated
// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is
// derived from user input.
func (t *Transaction) Savepoint(name string) error {
query := "savepoint " + t.dbmap.Dialect.QuoteField(name)
t.dbmap.trace(query, nil)
_, err := t.tx.Exec(query)
return err
}
// RollbackToSavepoint rolls back to the savepoint with the given name. The
// name is interpolated directly into the SQL SAVEPOINT statement, so you must
// sanitize it if it is derived from user input.
func (t *Transaction) RollbackToSavepoint(savepoint string) error {
query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
t.dbmap.trace(query, nil)
_, err := t.tx.Exec(query)
return err
}
// ReleaseSavepint releases the savepoint with the given name. The name is
// interpolated directly into the SQL SAVEPOINT statement, so you must sanitize
// it if it is derived from user input.
func (t *Transaction) ReleaseSavepoint(savepoint string) error {
query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
t.dbmap.trace(query, nil)
_, err := t.tx.Exec(query)
return err
}
// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction.
func (t *Transaction) Prepare(query string) (*sql.Stmt, error) {
t.dbmap.trace(query, nil)
return t.tx.Prepare(query)
}
func (t *Transaction) queryRow(query string, args ...interface{}) *sql.Row {
t.dbmap.trace(query, args...)
return t.tx.QueryRow(query, args...)
}
func (t *Transaction) query(query string, args ...interface{}) (*sql.Rows, error) {
t.dbmap.trace(query, args...)
return t.tx.Query(query, args...)
}
///////////////
// SelectInt executes the given query, which should be a SELECT statement for a single
// integer column, and returns the value of the first row returned. If no rows are
// found, zero is returned.
func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) {
var h int64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
}
// SelectNullInt executes the given query, which should be a SELECT statement for a single
// integer column, and returns the value of the first row returned. If no rows are
// found, the empty sql.NullInt64 value is returned.
func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) {
var h sql.NullInt64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectFloat executes the given query, which should be a SELECT statement for a single
// float column, and returns the value of the first row returned. If no rows are
// found, zero is returned.
func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) {
var h float64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
}
// SelectNullFloat executes the given query, which should be a SELECT statement for a single
// float column, and returns the value of the first row returned. If no rows are
// found, the empty sql.NullInt64 value is returned.
func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) {
var h sql.NullFloat64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectStr executes the given query, which should be a SELECT statement for a single
// char/varchar column, and returns the value of the first row returned. If no rows are
// found, an empty string is returned.
func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) {
var h string
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return "", err
}
return h, nil
}
// SelectNullStr executes the given query, which should be a SELECT
// statement for a single char/varchar column, and returns the value
// of the first row returned. If no rows are found, the empty
// sql.NullString is returned.
func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) {
var h sql.NullString
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectOne executes the given query (which should be a SELECT statement)
// and binds the result to holder, which must be a pointer.
//
// If no row is found, an error (sql.ErrNoRows specifically) will be returned
//
// If more than one row is found, an error will be returned.
//
func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
t := reflect.TypeOf(holder)
if t.Kind() == reflect.Ptr {
t = t.Elem()
} else {
return fmt.Errorf("gorp: SelectOne holder must be a pointer, but got: %t", holder)
}
// Handle pointer to pointer
isptr := false
if t.Kind() == reflect.Ptr {
isptr = true
t = t.Elem()
}
if t.Kind() == reflect.Struct {
var nonFatalErr error
list, err := hookedselect(m, e, holder, query, args...)
if err != nil {
if !NonFatalError(err) {
return err
}
nonFatalErr = err
}
dest := reflect.ValueOf(holder)
if isptr {
dest = dest.Elem()
}
if list != nil && len(list) > 0 {
// check for multiple rows
if len(list) > 1 {
return fmt.Errorf("gorp: multiple rows returned for: %s - %v", query, args)
}
// Initialize if nil
if dest.IsNil() {
dest.Set(reflect.New(t))
}
// only one row found
src := reflect.ValueOf(list[0])
dest.Elem().Set(src.Elem())
} else {
// No rows found, return a proper error.
return sql.ErrNoRows
}
return nonFatalErr
}
return selectVal(e, holder, query, args...)
}
func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
if len(args) == 1 {
switch m := e.(type) {
case *DbMap:
query, args = maybeExpandNamedQuery(m, query, args)
case *Transaction:
query, args = maybeExpandNamedQuery(m.dbmap, query, args)
}
}
rows, err := e.query(query, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return sql.ErrNoRows
}
return rows.Scan(holder)
}
///////////////
func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
args ...interface{}) ([]interface{}, error) {
var nonFatalErr error
list, err := rawselect(m, exec, i, query, args...)
if err != nil {
if !NonFatalError(err) {
return nil, err
}
nonFatalErr = err
}
// Determine where the results are: written to i, or returned in list
if t, _ := toSliceType(i); t == nil {
for _, v := range list {
if v, ok := v.(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(i))
for i := 0; i < resultsValue.Len(); i++ {
if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
}
return list, nonFatalErr
}
func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
args ...interface{}) ([]interface{}, error) {
var (
appendToSlice = false // Write results to i directly?
intoStruct = true // Selecting into a struct?
pointerElements = true // Are the slice elements pointers (vs values)?
)
var nonFatalErr error
// get type for i, verifying it's a supported destination
t, err := toType(i)
if err != nil {
var err2 error
if t, err2 = toSliceType(i); t == nil {
if err2 != nil {
return nil, err2
}
return nil, err
}
pointerElements = t.Kind() == reflect.Ptr
if pointerElements {
t = t.Elem()
}
appendToSlice = true
intoStruct = t.Kind() == reflect.Struct
}
// If the caller supplied a single struct/map argument, assume a "named
// parameter" query. Extract the named arguments from the struct/map, create
// the flat arg slice, and rewrite the query to use the dialect's placeholder.
if len(args) == 1 {
query, args = maybeExpandNamedQuery(m, query, args)
}
// Run the query
rows, err := exec.query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Fetch the column names as returned from db
cols, err := rows.Columns()
if err != nil {
return nil, err
}
if !intoStruct && len(cols) > 1 {
return nil, fmt.Errorf("gorp: select into non-struct slice requires 1 column, got %d", len(cols))
}
var colToFieldIndex [][]int
if intoStruct {
if colToFieldIndex, err = columnToFieldIndex(m, t, cols); err != nil {
if !NonFatalError(err) {
return nil, err
}
nonFatalErr = err
}
}
conv := m.TypeConverter
// Add results to one of these two slices.
var (
list = make([]interface{}, 0)
sliceValue = reflect.Indirect(reflect.ValueOf(i))
)
for {
if !rows.Next() {
// if error occured return rawselect
if rows.Err() != nil {
return nil, rows.Err()
}
// time to exit from outer "for" loop
break
}
v := reflect.New(t)
dest := make([]interface{}, len(cols))
custScan := make([]CustomScanner, 0)
for x := range cols {
f := v.Elem()
if intoStruct {
index := colToFieldIndex[x]
if index == nil {
// this field is not present in the struct, so create a dummy
// value for rows.Scan to scan into
var dummy sql.RawBytes
dest[x] = &dummy
continue
}
f = f.FieldByIndex(index)
}
target := f.Addr().Interface()
if conv != nil {
scanner, ok := conv.FromDb(target)
if ok {
target = scanner.Holder
custScan = append(custScan, scanner)
}
}
dest[x] = target
}
err = rows.Scan(dest...)
if err != nil {
return nil, err
}
for _, c := range custScan {
err = c.Bind()
if err != nil {
return nil, err
}
}
if appendToSlice {
if !pointerElements {
v = v.Elem()
}
sliceValue.Set(reflect.Append(sliceValue, v))
} else {
list = append(list, v.Interface())
}
}
if appendToSlice && sliceValue.IsNil() {
sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0))
}
return list, nonFatalErr
}
// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used
// as input to a named query. If so, it rewrites the query to use
// dialect-dependent bindvars and instantiates the corresponding slice of
// parameters by extracting data from the map / struct.
// If not, returns the input values unchanged.
func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
arg := reflect.ValueOf(args[0])
for arg.Kind() == reflect.Ptr {
arg = arg.Elem()
}
switch {
case arg.Kind() == reflect.Map && arg.Type().Key().Kind() == reflect.String:
return expandNamedQuery(m, query, func(key string) reflect.Value {
return arg.MapIndex(reflect.ValueOf(key))
})
// #84 - ignore time.Time structs here - there may be a cleaner way to do this
case arg.Kind() == reflect.Struct && !(arg.Type().PkgPath() == "time" && arg.Type().Name() == "Time"):
return expandNamedQuery(m, query, arg.FieldByName)
}
return query, args
}
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)
// expandNamedQuery accepts a query with placeholders of the form ":key", and a
// single arg of Kind Struct or Map[string]. It returns the query with the
// dialect's placeholders, and a slice of args ready for positional insertion
// into the query.
func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
var (
n int
args []interface{}
)
return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
val := keyGetter(key[1:])
if !val.IsValid() {
return key
}
args = append(args, val.Interface())
newVar := m.Dialect.BindVar(n)
n++
return newVar
}), args
}
func columnToFieldIndex(m *DbMap, t reflect.Type, cols []string) ([][]int, error) {
colToFieldIndex := make([][]int, len(cols))
// check if type t is a mapped table - if so we'll
// check the table for column aliasing below
tableMapped := false
table := tableOrNil(m, t)
if table != nil {
tableMapped = true
}
// Loop over column names and find field in i to bind to
// based on column name. all returned columns must match
// a field in the i struct
missingColNames := []string{}
for x := range cols {
colName := strings.ToLower(cols[x])
field, found := t.FieldByNameFunc(func(fieldName string) bool {
field, _ := t.FieldByName(fieldName)
fieldName = field.Tag.Get("db")
if fieldName == "-" {
return false
} else if fieldName == "" {
fieldName = field.Name
}
if tableMapped {
colMap := colMapOrNil(table, fieldName)
if colMap != nil {
fieldName = colMap.ColumnName
}
}
return colName == strings.ToLower(fieldName)
})
if found {
colToFieldIndex[x] = field.Index
}
if colToFieldIndex[x] == nil {
missingColNames = append(missingColNames, colName)
}
}
if len(missingColNames) > 0 {
return colToFieldIndex, &NoFieldInTypeError{
TypeName: t.Name(),
MissingColNames: missingColNames,
}
}
return colToFieldIndex, nil
}
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
// try to find field by exact match
f := val.FieldByName(fieldName)
if f != zeroVal {
return &f
}
// try to find by case insensitive match - only the Postgres driver
// seems to require this - in the case where columns are aliased in the sql
fieldNameL := strings.ToLower(fieldName)
fieldCount := val.NumField()
t := val.Type()
for i := 0; i < fieldCount; i++ {
sf := t.Field(i)
if strings.ToLower(sf.Name) == fieldNameL {
f := val.Field(i)
return &f
}
}
return nil
}
// toSliceType returns the element type of the given object, if the object is a
// "*[]*Element" or "*[]Element". If not, returns nil.
// err is returned if the user was trying to pass a pointer-to-slice but failed.
func toSliceType(i interface{}) (reflect.Type, error) {
t := reflect.TypeOf(i)
if t.Kind() != reflect.Ptr {
// If it's a slice, return a more helpful error message
if t.Kind() == reflect.Slice {
return nil, fmt.Errorf("gorp: Cannot SELECT into a non-pointer slice: %v", t)
}
return nil, nil
}
if t = t.Elem(); t.Kind() != reflect.Slice {
return nil, nil
}
return t.Elem(), nil
}
func toType(i interface{}) (reflect.Type, error) {
t := reflect.TypeOf(i)
// If a Pointer to a type, follow
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("gorp: Cannot SELECT into this type: %v", reflect.TypeOf(i))
}
return t, nil
}
func get(m *DbMap, exec SqlExecutor, i interface{},
keys ...interface{}) (interface{}, error) {
t, err := toType(i)
if err != nil {
return nil, err
}
table, err := m.TableFor(t, true)
if err != nil {
return nil, err
}
plan := table.bindGet()
v := reflect.New(t)
dest := make([]interface{}, len(plan.argFields))
conv := m.TypeConverter
custScan := make([]CustomScanner, 0)
for x, fieldName := range plan.argFields {
f := v.Elem().FieldByName(fieldName)
target := f.Addr().Interface()
if conv != nil {
scanner, ok := conv.FromDb(target)
if ok {
target = scanner.Holder
custScan = append(custScan, scanner)
}
}
dest[x] = target
}
row := exec.queryRow(plan.query, keys...)
err = row.Scan(dest...)
if err != nil {
if err == sql.ErrNoRows {
err = nil
}
return nil, err
}
for _, c := range custScan {
err = c.Bind()
if err != nil {
return nil, err
}
}
if v, ok := v.Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
return v.Interface(), nil
}
func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreDelete); ok {
err = v.PreDelete(exec)
if err != nil {
return -1, err
}
}
bi, err := table.bindDelete(elem)
if err != nil {
return -1, err
}
res, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, exec, table.TableName,
bi.existingVersion, elem, bi.keys...)
}
count += rows
if v, ok := eval.(HasPostDelete); ok {
err := v.PostDelete(exec)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func update(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreUpdate); ok {
err = v.PreUpdate(exec)
if err != nil {
return -1, err
}
}
bi, err := table.bindUpdate(elem)
if err != nil {
return -1, err
}
res, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, exec, table.TableName,
bi.existingVersion, elem, bi.keys...)
}
if bi.versField != "" {
elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
}
count += rows
if v, ok := eval.(HasPostUpdate); ok {
err = v.PostUpdate(exec)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, false)
if err != nil {
return err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreInsert); ok {
err := v.PreInsert(exec)
if err != nil {
return err
}
}
bi, err := table.bindInsert(elem)
if err != nil {
return err
}
if bi.autoIncrIdx > -1 {
f := elem.FieldByName(bi.autoIncrFieldName)
switch inserter := m.Dialect.(type) {
case IntegerAutoIncrInserter:
id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
if err != nil {
return err
}
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
} else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
f.SetUint(uint64(id))
} else {
return fmt.Errorf("gorp: Cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
}
case TargetedAutoIncrInserter:
err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
if err != nil {
return err
}
default:
return fmt.Errorf("gorp: Cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
}
} else {
_, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return err
}
}
if v, ok := eval.(HasPostInsert); ok {
err := v.PostInsert(exec)
if err != nil {
return err
}
}
}
return nil
}
func lockError(m *DbMap, exec SqlExecutor, tableName string,
existingVer int64, elem reflect.Value,
keys ...interface{}) (int64, error) {
existing, err := get(m, exec, elem.Interface(), keys...)
if err != nil {
return -1, err
}
ole := OptimisticLockError{tableName, keys, true, existingVer}
if existing == nil {
ole.RowExists = false
}
return -1, ole
}
// PostUpdate() will be executed after the GET statement.
type HasPostGet interface {
PostGet(SqlExecutor) error
}
// PostUpdate() will be executed after the DELETE statement
type HasPostDelete interface {
PostDelete(SqlExecutor) error
}
// PostUpdate() will be executed after the UPDATE statement
type HasPostUpdate interface {
PostUpdate(SqlExecutor) error
}
// PostInsert() will be executed after the INSERT statement
type HasPostInsert interface {
PostInsert(SqlExecutor) error
}
// PreDelete() will be executed before the DELETE statement.
type HasPreDelete interface {
PreDelete(SqlExecutor) error
}
// PreUpdate() will be executed before UPDATE statement.
type HasPreUpdate interface {
PreUpdate(SqlExecutor) error
}
// PreInsert() will be executed before INSERT statement.
type HasPreInsert interface {
PreInsert(SqlExecutor) error
}
package gorp
import (
"bytes"
"database/sql"
"encoding/json"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"log"
"math/rand"
"os"
"reflect"
"strings"
"testing"
"time"
)
// verify interface compliance
var _ Dialect = SqliteDialect{}
var _ Dialect = PostgresDialect{}
var _ Dialect = MySQLDialect{}
var _ Dialect = SqlServerDialect{}
var _ Dialect = OracleDialect{}
type testable interface {
GetId() int64
Rand()
}
type Invoice struct {
Id int64
Created int64
Updated int64
Memo string
PersonId int64
IsPaid bool
}
func (me *Invoice) GetId() int64 { return me.Id }
func (me *Invoice) Rand() {
me.Memo = fmt.Sprintf("random %d", rand.Int63())
me.Created = rand.Int63()
me.Updated = rand.Int63()
}
type InvoiceTag struct {
Id int64 `db:"myid"`
Created int64 `db:"myCreated"`
Updated int64 `db:"date_updated"`
Memo string
PersonId int64 `db:"person_id"`
IsPaid bool `db:"is_Paid"`
}
func (me *InvoiceTag) GetId() int64 { return me.Id }
func (me *InvoiceTag) Rand() {
me.Memo = fmt.Sprintf("random %d", rand.Int63())
me.Created = rand.Int63()
me.Updated = rand.Int63()
}
// See: https://github.com/coopernurse/gorp/issues/175
type AliasTransientField struct {
Id int64 `db:"id"`
Bar int64 `db:"-"`
BarStr string `db:"bar"`
}
func (me *AliasTransientField) GetId() int64 { return me.Id }
func (me *AliasTransientField) Rand() {
me.BarStr = fmt.Sprintf("random %d", rand.Int63())
}
type OverriddenInvoice struct {
Invoice
Id string
}
type Person struct {
Id int64
Created int64
Updated int64
FName string
LName string
Version int64
}
type FNameOnly struct {
FName string
}
type InvoicePersonView struct {
InvoiceId int64
PersonId int64
Memo string
FName string
LegacyVersion int64
}
type TableWithNull struct {
Id int64
Str sql.NullString
Int64 sql.NullInt64
Float64 sql.NullFloat64
Bool sql.NullBool
Bytes []byte
}
type WithIgnoredColumn struct {
internal int64 `db:"-"`
Id int64
Created int64
}
type IdCreated struct {
Id int64
Created int64
}
type IdCreatedExternal struct {
IdCreated
External int64
}
type WithStringPk struct {
Id string
Name string
}
type CustomStringType string
type TypeConversionExample struct {
Id int64
PersonJSON Person
Name CustomStringType
}
type PersonUInt32 struct {
Id uint32
Name string
}
type PersonUInt64 struct {
Id uint64
Name string
}
type PersonUInt16 struct {
Id uint16
Name string
}
type WithEmbeddedStruct struct {
Id int64
Names
}
type WithEmbeddedStructBeforeAutoincrField struct {
Names
Id int64
}
type WithEmbeddedAutoincr struct {
WithEmbeddedStruct
MiddleName string
}
type Names struct {
FirstName string
LastName string
}
type UniqueColumns struct {
FirstName string
LastName string
City string
ZipCode int64
}
type SingleColumnTable struct {
SomeId string
}
type CustomDate struct {
time.Time
}
type WithCustomDate struct {
Id int64
Added CustomDate
}
type testTypeConverter struct{}
func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) {
switch t := val.(type) {
case Person:
b, err := json.Marshal(t)
if err != nil {
return "", err
}
return string(b), nil
case CustomStringType:
return string(t), nil
case CustomDate:
return t.Time, nil
}
return val, nil
}
func (me testTypeConverter) FromDb(target interface{}) (CustomScanner, bool) {
switch target.(type) {
case *Person:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New("FromDb: Unable to convert Person to *string")
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return CustomScanner{new(string), target, binder}, true
case *CustomStringType:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New("FromDb: Unable to convert CustomStringType to *string")
}
st, ok := target.(*CustomStringType)
if !ok {
return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomStringType: ", reflect.TypeOf(target)))
}
*st = CustomStringType(*s)
return nil
}
return CustomScanner{new(string), target, binder}, true
case *CustomDate:
binder := func(holder, target interface{}) error {
t, ok := holder.(*time.Time)
if !ok {
return errors.New("FromDb: Unable to convert CustomDate to *time.Time")
}
dateTarget, ok := target.(*CustomDate)
if !ok {
return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomDate: ", reflect.TypeOf(target)))
}
dateTarget.Time = *t
return nil
}
return CustomScanner{new(time.Time), target, binder}, true
}
return CustomScanner{}, false
}
func (p *Person) PreInsert(s SqlExecutor) error {
p.Created = time.Now().UnixNano()
p.Updated = p.Created
if p.FName == "badname" {
return fmt.Errorf("Invalid name: %s", p.FName)
}
return nil
}
func (p *Person) PostInsert(s SqlExecutor) error {
p.LName = "postinsert"
return nil
}
func (p *Person) PreUpdate(s SqlExecutor) error {
p.FName = "preupdate"
return nil
}
func (p *Person) PostUpdate(s SqlExecutor) error {
p.LName = "postupdate"
return nil
}
func (p *Person) PreDelete(s SqlExecutor) error {
p.FName = "predelete"
return nil
}
func (p *Person) PostDelete(s SqlExecutor) error {
p.LName = "postdelete"
return nil
}
func (p *Person) PostGet(s SqlExecutor) error {
p.LName = "postget"
return nil
}
type PersistentUser struct {
Key int32
Id string
PassedTraining bool
}
func TestCreateTablesIfNotExists(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
err := dbmap.CreateTablesIfNotExists()
if err != nil {
t.Error(err)
}
}
func TestTruncateTables(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
err := dbmap.CreateTablesIfNotExists()
if err != nil {
t.Error(err)
}
// Insert some data
p1 := &Person{0, 0, 0, "Bob", "Smith", 0}
dbmap.Insert(p1)
inv := &Invoice{0, 0, 1, "my invoice", 0, true}
dbmap.Insert(inv)
err = dbmap.TruncateTables()
if err != nil {
t.Error(err)
}
// Make sure all rows are deleted
rows, _ := dbmap.Select(Person{}, "SELECT * FROM person_test")
if len(rows) != 0 {
t.Errorf("Expected 0 person rows, got %d", len(rows))
}
rows, _ = dbmap.Select(Invoice{}, "SELECT * FROM invoice_test")
if len(rows) != 0 {
t.Errorf("Expected 0 invoice rows, got %d", len(rows))
}
}
func TestCustomDateType(t *testing.T) {
dbmap := newDbMap()
dbmap.TypeConverter = testTypeConverter{}
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(WithCustomDate{}).SetKeys(true, "Id")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
test1 := &WithCustomDate{Added: CustomDate{Time: time.Now().Truncate(time.Second)}}
err = dbmap.Insert(test1)
if err != nil {
t.Errorf("Could not insert struct with custom date field: %s", err)
t.FailNow()
}
// Unfortunately, the mysql driver doesn't handle time.Time
// values properly during Get(). I can't find a way to work
// around that problem - every other type that I've tried is just
// silently converted. time.Time is the only type that causes
// the issue that this test checks for. As such, if the driver is
// mysql, we'll just skip the rest of this test.
if _, driver := dialectAndDriver(); driver == "mysql" {
t.Skip("TestCustomDateType can't run Get() with the mysql driver; skipping the rest of this test...")
}
result, err := dbmap.Get(new(WithCustomDate), test1.Id)
if err != nil {
t.Errorf("Could not get struct with custom date field: %s", err)
t.FailNow()
}
test2 := result.(*WithCustomDate)
if test2.Added.UTC() != test1.Added.UTC() {
t.Errorf("Custom dates do not match: %v != %v", test2.Added.UTC(), test1.Added.UTC())
}
}
func TestUIntPrimaryKey(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(PersonUInt64{}).SetKeys(true, "Id")
dbmap.AddTable(PersonUInt32{}).SetKeys(true, "Id")
dbmap.AddTable(PersonUInt16{}).SetKeys(true, "Id")
err := dbmap.CreateTablesIfNotExists()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
p1 := &PersonUInt64{0, "name1"}
p2 := &PersonUInt32{0, "name2"}
p3 := &PersonUInt16{0, "name3"}
err = dbmap.Insert(p1, p2, p3)
if err != nil {
t.Error(err)
}
if p1.Id != 1 {
t.Errorf("%d != 1", p1.Id)
}
if p2.Id != 1 {
t.Errorf("%d != 1", p2.Id)
}
if p3.Id != 1 {
t.Errorf("%d != 1", p3.Id)
}
}
func TestSetUniqueTogether(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName").SetUniqueTogether("City", "ZipCode")
err := dbmap.CreateTablesIfNotExists()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014}
err = dbmap.Insert(n1)
if err != nil {
t.Error(err)
}
// Should fail because of the first constraint
n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085}
err = dbmap.Insert(n2)
if err == nil {
t.Error(err)
}
// "unique" for Postgres/SQLite, "Duplicate entry" for MySQL
errLower := strings.ToLower(err.Error())
if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") {
t.Error(err)
}
// Should also fail because of the second unique-together
n3 := &UniqueColumns{"Steve", "Wozniak", "Cupertino", 95014}
err = dbmap.Insert(n3)
if err == nil {
t.Error(err)
}
// "unique" for Postgres/SQLite, "Duplicate entry" for MySQL
errLower = strings.ToLower(err.Error())
if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") {
t.Error(err)
}
// This one should finally succeed
n4 := &UniqueColumns{"Steve", "Wozniak", "Sunnyvale", 94085}
err = dbmap.Insert(n4)
if err != nil {
t.Error(err)
}
}
func TestPersistentUser(t *testing.T) {
dbmap := newDbMap()
dbmap.Exec("drop table if exists PersistentUser")
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key")
table.ColMap("Key").Rename("mykey")
err := dbmap.CreateTablesIfNotExists()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
pu := &PersistentUser{43, "33r", false}
err = dbmap.Insert(pu)
if err != nil {
panic(err)
}
// prove we can pass a pointer into Get
pu2, err := dbmap.Get(pu, pu.Key)
if err != nil {
panic(err)
}
if !reflect.DeepEqual(pu, pu2) {
t.Errorf("%v!=%v", pu, pu2)
}
arr, err := dbmap.Select(pu, "select * from PersistentUser")
if err != nil {
panic(err)
}
if !reflect.DeepEqual(pu, arr[0]) {
t.Errorf("%v!=%v", pu, arr[0])
}
// prove we can get the results back in a slice
var puArr []*PersistentUser
_, err = dbmap.Select(&puArr, "select * from PersistentUser")
if err != nil {
panic(err)
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu, puArr[0]) {
t.Errorf("%v!=%v", pu, puArr[0])
}
// prove we can get the results back in a non-pointer slice
var puValues []PersistentUser
_, err = dbmap.Select(&puValues, "select * from PersistentUser")
if err != nil {
panic(err)
}
if len(puValues) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(*pu, puValues[0]) {
t.Errorf("%v!=%v", *pu, puValues[0])
}
// prove we can get the results back in a string slice
var idArr []*string
_, err = dbmap.Select(&idArr, "select Id from PersistentUser")
if err != nil {
panic(err)
}
if len(idArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu.Id, *idArr[0]) {
t.Errorf("%v!=%v", pu.Id, *idArr[0])
}
// prove we can get the results back in an int slice
var keyArr []*int32
_, err = dbmap.Select(&keyArr, "select mykey from PersistentUser")
if err != nil {
panic(err)
}
if len(keyArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu.Key, *keyArr[0]) {
t.Errorf("%v!=%v", pu.Key, *keyArr[0])
}
// prove we can get the results back in a bool slice
var passedArr []*bool
_, err = dbmap.Select(&passedArr, "select PassedTraining from PersistentUser")
if err != nil {
panic(err)
}
if len(passedArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu.PassedTraining, *passedArr[0]) {
t.Errorf("%v!=%v", pu.PassedTraining, *passedArr[0])
}
// prove we can get the results back in a non-pointer slice
var stringArr []string
_, err = dbmap.Select(&stringArr, "select Id from PersistentUser")
if err != nil {
panic(err)
}
if len(stringArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu.Id, stringArr[0]) {
t.Errorf("%v!=%v", pu.Id, stringArr[0])
}
}
func TestNamedQueryMap(t *testing.T) {
dbmap := newDbMap()
dbmap.Exec("drop table if exists PersistentUser")
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key")
table.ColMap("Key").Rename("mykey")
err := dbmap.CreateTablesIfNotExists()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
pu := &PersistentUser{43, "33r", false}
pu2 := &PersistentUser{500, "abc", false}
err = dbmap.Insert(pu, pu2)
if err != nil {
panic(err)
}
// Test simple case
var puArr []*PersistentUser
_, err = dbmap.Select(&puArr, "select * from PersistentUser where mykey = :Key", map[string]interface{}{
"Key": 43,
})
if err != nil {
t.Errorf("Failed to select: %s", err)
t.FailNow()
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu, puArr[0]) {
t.Errorf("%v!=%v", pu, puArr[0])
}
// Test more specific map value type is ok
puArr = nil
_, err = dbmap.Select(&puArr, "select * from PersistentUser where mykey = :Key", map[string]int{
"Key": 43,
})
if err != nil {
t.Errorf("Failed to select: %s", err)
t.FailNow()
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
// Test multiple parameters set.
puArr = nil
_, err = dbmap.Select(&puArr, `
select * from PersistentUser
where mykey = :Key
and PassedTraining = :PassedTraining
and Id = :Id`, map[string]interface{}{
"Key": 43,
"PassedTraining": false,
"Id": "33r",
})
if err != nil {
t.Errorf("Failed to select: %s", err)
t.FailNow()
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
// Test colon within a non-key string
// Test having extra, unused properties in the map.
puArr = nil
_, err = dbmap.Select(&puArr, `
select * from PersistentUser
where mykey = :Key
and Id != 'abc:def'`, map[string]interface{}{
"Key": 43,
"PassedTraining": false,
})
if err != nil {
t.Errorf("Failed to select: %s", err)
t.FailNow()
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
}
func TestNamedQueryStruct(t *testing.T) {
dbmap := newDbMap()
dbmap.Exec("drop table if exists PersistentUser")
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key")
table.ColMap("Key").Rename("mykey")
err := dbmap.CreateTablesIfNotExists()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
pu := &PersistentUser{43, "33r", false}
pu2 := &PersistentUser{500, "abc", false}
err = dbmap.Insert(pu, pu2)
if err != nil {
panic(err)
}
// Test select self
var puArr []*PersistentUser
_, err = dbmap.Select(&puArr, `
select * from PersistentUser
where mykey = :Key
and PassedTraining = :PassedTraining
and Id = :Id`, pu)
if err != nil {
t.Errorf("Failed to select: %s", err)
t.FailNow()
}
if len(puArr) != 1 {
t.Errorf("Expected one persistentuser, found none")
}
if !reflect.DeepEqual(pu, puArr[0]) {
t.Errorf("%v!=%v", pu, puArr[0])
}
}
// Ensure that the slices containing SQL results are non-nil when the result set is empty.
func TestReturnsNonNilSlice(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
noResultsSQL := "select * from invoice_test where id=99999"
var r1 []*Invoice
_rawselect(dbmap, &r1, noResultsSQL)
if r1 == nil {
t.Errorf("r1==nil")
}
r2 := _rawselect(dbmap, Invoice{}, noResultsSQL)
if r2 == nil {
t.Errorf("r2==nil")
}
}
func TestOverrideVersionCol(t *testing.T) {
dbmap := newDbMap()
t1 := dbmap.AddTable(InvoicePersonView{}).SetKeys(false, "InvoiceId", "PersonId")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
c1 := t1.SetVersionCol("LegacyVersion")
if c1.ColumnName != "LegacyVersion" {
t.Errorf("Wrong col returned: %v", c1)
}
ipv := &InvoicePersonView{1, 2, "memo", "fname", 0}
_update(dbmap, ipv)
if ipv.LegacyVersion != 1 {
t.Errorf("LegacyVersion not updated: %d", ipv.LegacyVersion)
}
}
func TestOptimisticLocking(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &Person{0, 0, 0, "Bob", "Smith", 0}
dbmap.Insert(p1) // Version is now 1
if p1.Version != 1 {
t.Errorf("Insert didn't incr Version: %d != %d", 1, p1.Version)
return
}
if p1.Id == 0 {
t.Errorf("Insert didn't return a generated PK")
return
}
obj, err := dbmap.Get(Person{}, p1.Id)
if err != nil {
panic(err)
}
p2 := obj.(*Person)
p2.LName = "Edwards"
dbmap.Update(p2) // Version is now 2
if p2.Version != 2 {
t.Errorf("Update didn't incr Version: %d != %d", 2, p2.Version)
}
p1.LName = "Howard"
count, err := dbmap.Update(p1)
if _, ok := err.(OptimisticLockError); !ok {
t.Errorf("update - Expected OptimisticLockError, got: %v", err)
}
if count != -1 {
t.Errorf("update - Expected -1 count, got: %d", count)
}
count, err = dbmap.Delete(p1)
if _, ok := err.(OptimisticLockError); !ok {
t.Errorf("delete - Expected OptimisticLockError, got: %v", err)
}
if count != -1 {
t.Errorf("delete - Expected -1 count, got: %d", count)
}
}
// what happens if a legacy table has a null value?
func TestDoubleAddTable(t *testing.T) {
dbmap := newDbMap()
t1 := dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id")
t2 := dbmap.AddTable(TableWithNull{})
if t1 != t2 {
t.Errorf("%v != %v", t1, t2)
}
}
// what happens if a legacy table has a null value?
func TestNullValues(t *testing.T) {
dbmap := initDbMapNulls()
defer dropAndClose(dbmap)
// insert a row directly
_rawexec(dbmap, "insert into TableWithNull values (10, null, "+
"null, null, null, null)")
// try to load it
expected := &TableWithNull{Id: 10}
obj := _get(dbmap, TableWithNull{}, 10)
t1 := obj.(*TableWithNull)
if !reflect.DeepEqual(expected, t1) {
t.Errorf("%v != %v", expected, t1)
}
// update it
t1.Str = sql.NullString{"hi", true}
expected.Str = t1.Str
t1.Int64 = sql.NullInt64{999, true}
expected.Int64 = t1.Int64
t1.Float64 = sql.NullFloat64{53.33, true}
expected.Float64 = t1.Float64
t1.Bool = sql.NullBool{true, true}
expected.Bool = t1.Bool
t1.Bytes = []byte{1, 30, 31, 33}
expected.Bytes = t1.Bytes
_update(dbmap, t1)
obj = _get(dbmap, TableWithNull{}, 10)
t1 = obj.(*TableWithNull)
if t1.Str.String != "hi" {
t.Errorf("%s != hi", t1.Str.String)
}
if !reflect.DeepEqual(expected, t1) {
t.Errorf("%v != %v", expected, t1)
}
}
func TestColumnProps(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id")
t1.ColMap("Created").Rename("date_created")
t1.ColMap("Updated").SetTransient(true)
t1.ColMap("Memo").SetMaxSize(10)
t1.ColMap("PersonId").SetUnique(true)
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
// test transient
inv := &Invoice{0, 0, 1, "my invoice", 0, true}
_insert(dbmap, inv)
obj := _get(dbmap, Invoice{}, inv.Id)
inv = obj.(*Invoice)
if inv.Updated != 0 {
t.Errorf("Saved transient column 'Updated'")
}
// test max size
inv.Memo = "this memo is too long"
err = dbmap.Insert(inv)
if err == nil {
t.Errorf("max size exceeded, but Insert did not fail.")
}
// test unique - same person id
inv = &Invoice{0, 0, 1, "my invoice2", 0, false}
err = dbmap.Insert(inv)
if err == nil {
t.Errorf("same PersonId inserted, but Insert did not fail.")
}
}
func TestRawSelect(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &Person{0, 0, 0, "bob", "smith", 0}
_insert(dbmap, p1)
inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, true}
_insert(dbmap, inv1)
expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0}
query := "select i.Id InvoiceId, p.Id PersonId, i.Memo, p.FName " +
"from invoice_test i, person_test p " +
"where i.PersonId = p.Id"
list := _rawselect(dbmap, InvoicePersonView{}, query)
if len(list) != 1 {
t.Errorf("len(list) != 1: %d", len(list))
} else if !reflect.DeepEqual(expected, list[0]) {
t.Errorf("%v != %v", expected, list[0])
}
}
func TestHooks(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &Person{0, 0, 0, "bob", "smith", 0}
_insert(dbmap, p1)
if p1.Created == 0 || p1.Updated == 0 {
t.Errorf("p1.PreInsert() didn't run: %v", p1)
} else if p1.LName != "postinsert" {
t.Errorf("p1.PostInsert() didn't run: %v", p1)
}
obj := _get(dbmap, Person{}, p1.Id)
p1 = obj.(*Person)
if p1.LName != "postget" {
t.Errorf("p1.PostGet() didn't run: %v", p1)
}
_update(dbmap, p1)
if p1.FName != "preupdate" {
t.Errorf("p1.PreUpdate() didn't run: %v", p1)
} else if p1.LName != "postupdate" {
t.Errorf("p1.PostUpdate() didn't run: %v", p1)
}
var persons []*Person
bindVar := dbmap.Dialect.BindVar(0)
_rawselect(dbmap, &persons, "select * from person_test where id = "+bindVar, p1.Id)
if persons[0].LName != "postget" {
t.Errorf("p1.PostGet() didn't run after select: %v", p1)
}
_del(dbmap, p1)
if p1.FName != "predelete" {
t.Errorf("p1.PreDelete() didn't run: %v", p1)
} else if p1.LName != "postdelete" {
t.Errorf("p1.PostDelete() didn't run: %v", p1)
}
// Test error case
p2 := &Person{0, 0, 0, "badname", "", 0}
err := dbmap.Insert(p2)
if err == nil {
t.Errorf("p2.PreInsert() didn't return an error")
}
}
func TestTransaction(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
inv1 := &Invoice{0, 100, 200, "t1", 0, true}
inv2 := &Invoice{0, 100, 200, "t2", 0, false}
trans, err := dbmap.Begin()
if err != nil {
panic(err)
}
trans.Insert(inv1, inv2)
err = trans.Commit()
if err != nil {
panic(err)
}
obj, err := dbmap.Get(Invoice{}, inv1.Id)
if err != nil {
panic(err)
}
if !reflect.DeepEqual(inv1, obj) {
t.Errorf("%v != %v", inv1, obj)
}
obj, err = dbmap.Get(Invoice{}, inv2.Id)
if err != nil {
panic(err)
}
if !reflect.DeepEqual(inv2, obj) {
t.Errorf("%v != %v", inv2, obj)
}
}
func TestSavepoint(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
inv1 := &Invoice{0, 100, 200, "unpaid", 0, false}
trans, err := dbmap.Begin()
if err != nil {
panic(err)
}
trans.Insert(inv1)
var checkMemo = func(want string) {
memo, err := trans.SelectStr("select memo from invoice_test")
if err != nil {
panic(err)
}
if memo != want {
t.Errorf("%q != %q", want, memo)
}
}
checkMemo("unpaid")
err = trans.Savepoint("foo")
if err != nil {
panic(err)
}
checkMemo("unpaid")
inv1.Memo = "paid"
_, err = trans.Update(inv1)
if err != nil {
panic(err)
}
checkMemo("paid")
err = trans.RollbackToSavepoint("foo")
if err != nil {
panic(err)
}
checkMemo("unpaid")
err = trans.Rollback()
if err != nil {
panic(err)
}
}
func TestMultiple(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
inv1 := &Invoice{0, 100, 200, "a", 0, false}
inv2 := &Invoice{0, 100, 200, "b", 0, true}
_insert(dbmap, inv1, inv2)
inv1.Memo = "c"
inv2.Memo = "d"
_update(dbmap, inv1, inv2)
count := _del(dbmap, inv1, inv2)
if count != 2 {
t.Errorf("%d != 2", count)
}
}
func TestCrud(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
inv := &Invoice{0, 100, 200, "first order", 0, true}
testCrudInternal(t, dbmap, inv)
invtag := &InvoiceTag{0, 300, 400, "some order", 33, false}
testCrudInternal(t, dbmap, invtag)
foo := &AliasTransientField{BarStr: "some bar"}
testCrudInternal(t, dbmap, foo)
}
func testCrudInternal(t *testing.T, dbmap *DbMap, val testable) {
table, _, err := dbmap.tableForPointer(val, false)
if err != nil {
t.Errorf("couldn't call TableFor: val=%v err=%v", val, err)
}
_, err = dbmap.Exec("delete from " + table.TableName)
if err != nil {
t.Errorf("couldn't delete rows from: val=%v err=%v", val, err)
}
// INSERT row
_insert(dbmap, val)
if val.GetId() == 0 {
t.Errorf("val.GetId() was not set on INSERT")
return
}
// SELECT row
val2 := _get(dbmap, val, val.GetId())
if !reflect.DeepEqual(val, val2) {
t.Errorf("%v != %v", val, val2)
}
// UPDATE row and SELECT
val.Rand()
count := _update(dbmap, val)
if count != 1 {
t.Errorf("update 1 != %d", count)
}
val2 = _get(dbmap, val, val.GetId())
if !reflect.DeepEqual(val, val2) {
t.Errorf("%v != %v", val, val2)
}
// Select *
rows, err := dbmap.Select(val, "select * from "+table.TableName)
if err != nil {
t.Errorf("couldn't select * from %s err=%v", table.TableName, err)
} else if len(rows) != 1 {
t.Errorf("unexpected row count in %s: %d", table.TableName, len(rows))
} else if !reflect.DeepEqual(val, rows[0]) {
t.Errorf("select * result: %v != %v", val, rows[0])
}
// DELETE row
deleted := _del(dbmap, val)
if deleted != 1 {
t.Errorf("Did not delete row with Id: %d", val.GetId())
return
}
// VERIFY deleted
val2 = _get(dbmap, val, val.GetId())
if val2 != nil {
t.Errorf("Found invoice with id: %d after Delete()", val.GetId())
}
}
func TestWithIgnoredColumn(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
ic := &WithIgnoredColumn{-1, 0, 1}
_insert(dbmap, ic)
expected := &WithIgnoredColumn{0, 1, 1}
ic2 := _get(dbmap, WithIgnoredColumn{}, ic.Id).(*WithIgnoredColumn)
if !reflect.DeepEqual(expected, ic2) {
t.Errorf("%v != %v", expected, ic2)
}
if _del(dbmap, ic) != 1 {
t.Errorf("Did not delete row with Id: %d", ic.Id)
return
}
if _get(dbmap, WithIgnoredColumn{}, ic.Id) != nil {
t.Errorf("Found id: %d after Delete()", ic.Id)
}
}
func TestTypeConversionExample(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p := Person{FName: "Bob", LName: "Smith"}
tc := &TypeConversionExample{-1, p, CustomStringType("hi")}
_insert(dbmap, tc)
expected := &TypeConversionExample{1, p, CustomStringType("hi")}
tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample)
if !reflect.DeepEqual(expected, tc2) {
t.Errorf("tc2 %v != %v", expected, tc2)
}
tc2.Name = CustomStringType("hi2")
tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"}
_update(dbmap, tc2)
expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")}
tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample)
if !reflect.DeepEqual(expected, tc3) {
t.Errorf("tc3 %v != %v", expected, tc3)
}
if _del(dbmap, tc) != 1 {
t.Errorf("Did not delete row with Id: %d", tc.Id)
}
}
func TestWithEmbeddedStruct(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
es := &WithEmbeddedStruct{-1, Names{FirstName: "Alice", LastName: "Smith"}}
_insert(dbmap, es)
expected := &WithEmbeddedStruct{1, Names{FirstName: "Alice", LastName: "Smith"}}
es2 := _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct)
if !reflect.DeepEqual(expected, es2) {
t.Errorf("%v != %v", expected, es2)
}
es2.FirstName = "Bob"
expected.FirstName = "Bob"
_update(dbmap, es2)
es2 = _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct)
if !reflect.DeepEqual(expected, es2) {
t.Errorf("%v != %v", expected, es2)
}
ess := _rawselect(dbmap, WithEmbeddedStruct{}, "select * from embedded_struct_test")
if !reflect.DeepEqual(es2, ess[0]) {
t.Errorf("%v != %v", es2, ess[0])
}
}
func TestWithEmbeddedStructBeforeAutoincr(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
esba := &WithEmbeddedStructBeforeAutoincrField{Names: Names{FirstName: "Alice", LastName: "Smith"}}
_insert(dbmap, esba)
var expectedAutoincrId int64 = 1
if esba.Id != expectedAutoincrId {
t.Errorf("%d != %d", expectedAutoincrId, esba.Id)
}
}
func TestWithEmbeddedAutoincr(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
esa := &WithEmbeddedAutoincr{
WithEmbeddedStruct: WithEmbeddedStruct{Names: Names{FirstName: "Alice", LastName: "Smith"}},
MiddleName: "Rose",
}
_insert(dbmap, esa)
var expectedAutoincrId int64 = 1
if esa.Id != expectedAutoincrId {
t.Errorf("%d != %d", expectedAutoincrId, esa.Id)
}
}
func TestSelectVal(t *testing.T) {
dbmap := initDbMapNulls()
defer dropAndClose(dbmap)
bindVar := dbmap.Dialect.BindVar(0)
t1 := TableWithNull{Str: sql.NullString{"abc", true},
Int64: sql.NullInt64{78, true},
Float64: sql.NullFloat64{32.2, true},
Bool: sql.NullBool{true, true},
Bytes: []byte("hi")}
_insert(dbmap, &t1)
// SelectInt
i64 := selectInt(dbmap, "select Int64 from TableWithNull where Str='abc'")
if i64 != 78 {
t.Errorf("int64 %d != 78", i64)
}
i64 = selectInt(dbmap, "select count(*) from TableWithNull")
if i64 != 1 {
t.Errorf("int64 count %d != 1", i64)
}
i64 = selectInt(dbmap, "select count(*) from TableWithNull where Str="+bindVar, "asdfasdf")
if i64 != 0 {
t.Errorf("int64 no rows %d != 0", i64)
}
// SelectNullInt
n := selectNullInt(dbmap, "select Int64 from TableWithNull where Str='notfound'")
if !reflect.DeepEqual(n, sql.NullInt64{0, false}) {
t.Errorf("nullint %v != 0,false", n)
}
n = selectNullInt(dbmap, "select Int64 from TableWithNull where Str='abc'")
if !reflect.DeepEqual(n, sql.NullInt64{78, true}) {
t.Errorf("nullint %v != 78, true", n)
}
// SelectFloat
f64 := selectFloat(dbmap, "select Float64 from TableWithNull where Str='abc'")
if f64 != 32.2 {
t.Errorf("float64 %d != 32.2", f64)
}
f64 = selectFloat(dbmap, "select min(Float64) from TableWithNull")
if f64 != 32.2 {
t.Errorf("float64 min %d != 32.2", f64)
}
f64 = selectFloat(dbmap, "select count(*) from TableWithNull where Str="+bindVar, "asdfasdf")
if f64 != 0 {
t.Errorf("float64 no rows %d != 0", f64)
}
// SelectNullFloat
nf := selectNullFloat(dbmap, "select Float64 from TableWithNull where Str='notfound'")
if !reflect.DeepEqual(nf, sql.NullFloat64{0, false}) {
t.Errorf("nullfloat %v != 0,false", nf)
}
nf = selectNullFloat(dbmap, "select Float64 from TableWithNull where Str='abc'")
if !reflect.DeepEqual(nf, sql.NullFloat64{32.2, true}) {
t.Errorf("nullfloat %v != 32.2, true", nf)
}
// SelectStr
s := selectStr(dbmap, "select Str from TableWithNull where Int64="+bindVar, 78)
if s != "abc" {
t.Errorf("s %s != abc", s)
}
s = selectStr(dbmap, "select Str from TableWithNull where Str='asdfasdf'")
if s != "" {
t.Errorf("s no rows %s != ''", s)
}
// SelectNullStr
ns := selectNullStr(dbmap, "select Str from TableWithNull where Int64="+bindVar, 78)
if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) {
t.Errorf("nullstr %v != abc,true", ns)
}
ns = selectNullStr(dbmap, "select Str from TableWithNull where Str='asdfasdf'")
if !reflect.DeepEqual(ns, sql.NullString{"", false}) {
t.Errorf("nullstr no rows %v != '',false", ns)
}
// SelectInt/Str with named parameters
i64 = selectInt(dbmap, "select Int64 from TableWithNull where Str=:abc", map[string]string{"abc": "abc"})
if i64 != 78 {
t.Errorf("int64 %d != 78", i64)
}
ns = selectNullStr(dbmap, "select Str from TableWithNull where Int64=:num", map[string]int{"num": 78})
if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) {
t.Errorf("nullstr %v != abc,true", ns)
}
}
func TestVersionMultipleRows(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
persons := []*Person{
&Person{0, 0, 0, "Bob", "Smith", 0},
&Person{0, 0, 0, "Jane", "Smith", 0},
&Person{0, 0, 0, "Mike", "Smith", 0},
}
_insert(dbmap, persons[0], persons[1], persons[2])
for x, p := range persons {
if p.Version != 1 {
t.Errorf("person[%d].Version != 1: %d", x, p.Version)
}
}
}
func TestWithStringPk(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTableWithName(WithStringPk{}, "string_pk_test").SetKeys(true, "Id")
_, err := dbmap.Exec("create table string_pk_test (Id varchar(255), Name varchar(255));")
if err != nil {
t.Errorf("couldn't create string_pk_test: %v", err)
}
defer dropAndClose(dbmap)
row := &WithStringPk{"1", "foo"}
err = dbmap.Insert(row)
if err == nil {
t.Errorf("Expected error when inserting into table w/non Int PK and autoincr set true")
}
}
// TestSqlExecutorInterfaceSelects ensures that all DbMap methods starting with Select...
// are also exposed in the SqlExecutor interface. Select... functions can always
// run on Pre/Post hooks.
func TestSqlExecutorInterfaceSelects(t *testing.T) {
dbMapType := reflect.TypeOf(&DbMap{})
sqlExecutorType := reflect.TypeOf((*SqlExecutor)(nil)).Elem()
numDbMapMethods := dbMapType.NumMethod()
for i := 0; i < numDbMapMethods; i += 1 {
dbMapMethod := dbMapType.Method(i)
if !strings.HasPrefix(dbMapMethod.Name, "Select") {
continue
}
if _, found := sqlExecutorType.MethodByName(dbMapMethod.Name); !found {
t.Errorf("Method %s is defined on DbMap but not implemented in SqlExecutor",
dbMapMethod.Name)
}
}
}
type WithTime struct {
Id int64
Time time.Time
}
type Times struct {
One time.Time
Two time.Time
}
type EmbeddedTime struct {
Id string
Times
}
func parseTimeOrPanic(format, date string) time.Time {
t1, err := time.Parse(format, date)
if err != nil {
panic(err)
}
return t1
}
// TODO: re-enable next two tests when this is merged:
// https://github.com/ziutek/mymysql/pull/77
//
// This test currently fails w/MySQL b/c tz info is lost
func testWithTime(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
t1 := parseTimeOrPanic("2006-01-02 15:04:05 -0700 MST",
"2013-08-09 21:30:43 +0800 CST")
w1 := WithTime{1, t1}
_insert(dbmap, &w1)
obj := _get(dbmap, WithTime{}, w1.Id)
w2 := obj.(*WithTime)
if w1.Time.UnixNano() != w2.Time.UnixNano() {
t.Errorf("%v != %v", w1, w2)
}
}
// See: https://github.com/coopernurse/gorp/issues/86
func testEmbeddedTime(t *testing.T) {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(EmbeddedTime{}).SetKeys(false, "Id")
defer dropAndClose(dbmap)
err := dbmap.CreateTables()
if err != nil {
t.Fatal(err)
}
time1 := parseTimeOrPanic("2006-01-02 15:04:05", "2013-08-09 21:30:43")
t1 := &EmbeddedTime{Id: "abc", Times: Times{One: time1, Two: time1.Add(10 * time.Second)}}
_insert(dbmap, t1)
x := _get(dbmap, EmbeddedTime{}, t1.Id)
t2, _ := x.(*EmbeddedTime)
if t1.One.UnixNano() != t2.One.UnixNano() || t1.Two.UnixNano() != t2.Two.UnixNano() {
t.Errorf("%v != %v", t1, t2)
}
}
func TestWithTimeSelect(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
halfhourago := time.Now().UTC().Add(-30 * time.Minute)
w1 := WithTime{1, halfhourago.Add(time.Minute * -1)}
w2 := WithTime{2, halfhourago.Add(time.Second)}
_insert(dbmap, &w1, &w2)
var caseIds []int64
_, err := dbmap.Select(&caseIds, "SELECT id FROM time_test WHERE Time < "+dbmap.Dialect.BindVar(0), halfhourago)
if err != nil {
t.Error(err)
}
if len(caseIds) != 1 {
t.Errorf("%d != 1", len(caseIds))
}
if caseIds[0] != w1.Id {
t.Errorf("%d != %d", caseIds[0], w1.Id)
}
}
func TestInvoicePersonView(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
// Create some rows
p1 := &Person{0, 0, 0, "bob", "smith", 0}
dbmap.Insert(p1)
// notice how we can wire up p1.Id to the invoice easily
inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, false}
dbmap.Insert(inv1)
// Run your query
query := "select i.Id InvoiceId, p.Id PersonId, i.Memo, p.FName " +
"from invoice_test i, person_test p " +
"where i.PersonId = p.Id"
// pass a slice of pointers to Select()
// this avoids the need to type assert after the query is run
var list []*InvoicePersonView
_, err := dbmap.Select(&list, query)
if err != nil {
panic(err)
}
// this should test true
expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0}
if !reflect.DeepEqual(list[0], expected) {
t.Errorf("%v != %v", list[0], expected)
}
}
func TestQuoteTableNames(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
quotedTableName := dbmap.Dialect.QuoteField("person_test")
// Use a buffer to hold the log to check generated queries
logBuffer := &bytes.Buffer{}
dbmap.TraceOn("", log.New(logBuffer, "gorptest:", log.Lmicroseconds))
// Create some rows
p1 := &Person{0, 0, 0, "bob", "smith", 0}
errorTemplate := "Expected quoted table name %v in query but didn't find it"
// Check if Insert quotes the table name
id := dbmap.Insert(p1)
if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) {
t.Errorf(errorTemplate, quotedTableName)
}
logBuffer.Reset()
// Check if Get quotes the table name
dbmap.Get(Person{}, id)
if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) {
t.Errorf(errorTemplate, quotedTableName)
}
logBuffer.Reset()
}
func TestSelectTooManyCols(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &Person{0, 0, 0, "bob", "smith", 0}
p2 := &Person{0, 0, 0, "jane", "doe", 0}
_insert(dbmap, p1)
_insert(dbmap, p2)
obj := _get(dbmap, Person{}, p1.Id)
p1 = obj.(*Person)
obj = _get(dbmap, Person{}, p2.Id)
p2 = obj.(*Person)
params := map[string]interface{}{
"Id": p1.Id,
}
var p3 FNameOnly
err := dbmap.SelectOne(&p3, "select * from person_test where Id=:Id", params)
if err != nil {
if !NonFatalError(err) {
t.Error(err)
}
} else {
t.Errorf("Non-fatal error expected")
}
if p1.FName != p3.FName {
t.Errorf("%v != %v", p1.FName, p3.FName)
}
var pSlice []FNameOnly
_, err = dbmap.Select(&pSlice, "select * from person_test order by fname asc")
if err != nil {
if !NonFatalError(err) {
t.Error(err)
}
} else {
t.Errorf("Non-fatal error expected")
}
if p1.FName != pSlice[0].FName {
t.Errorf("%v != %v", p1.FName, pSlice[0].FName)
}
if p2.FName != pSlice[1].FName {
t.Errorf("%v != %v", p2.FName, pSlice[1].FName)
}
}
func TestSelectSingleVal(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &Person{0, 0, 0, "bob", "smith", 0}
_insert(dbmap, p1)
obj := _get(dbmap, Person{}, p1.Id)
p1 = obj.(*Person)
params := map[string]interface{}{
"Id": p1.Id,
}
var p2 Person
err := dbmap.SelectOne(&p2, "select * from person_test where Id=:Id", params)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(p1, &p2) {
t.Errorf("%v != %v", p1, &p2)
}
// verify SelectOne allows non-struct holders
var s string
err = dbmap.SelectOne(&s, "select FName from person_test where Id=:Id", params)
if err != nil {
t.Error(err)
}
if s != "bob" {
t.Error("Expected bob but got: " + s)
}
// verify SelectOne requires pointer receiver
err = dbmap.SelectOne(s, "select FName from person_test where Id=:Id", params)
if err == nil {
t.Error("SelectOne should have returned error for non-pointer holder")
}
// verify SelectOne works with uninitialized pointers
var p3 *Person
err = dbmap.SelectOne(&p3, "select * from person_test where Id=:Id", params)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(p1, p3) {
t.Errorf("%v != %v", p1, p3)
}
// verify that the receiver is still nil if nothing was found
var p4 *Person
dbmap.SelectOne(&p3, "select * from person_test where 2<1 AND Id=:Id", params)
if p4 != nil {
t.Error("SelectOne should not have changed a nil receiver when no rows were found")
}
// verify that the error is set to sql.ErrNoRows if not found
err = dbmap.SelectOne(&p2, "select * from person_test where Id=:Id", map[string]interface{}{
"Id": -2222,
})
if err == nil || err != sql.ErrNoRows {
t.Error("SelectOne should have returned an sql.ErrNoRows")
}
_insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0})
err = dbmap.SelectOne(&p2, "select * from person_test where Fname='bob'")
if err == nil {
t.Error("Expected error when two rows found")
}
// tests for #150
var tInt int64
var tStr string
var tBool bool
var tFloat float64
primVals := []interface{}{tInt, tStr, tBool, tFloat}
for _, prim := range primVals {
err = dbmap.SelectOne(&prim, "select * from person_test where Id=-123")
if err == nil || err != sql.ErrNoRows {
t.Error("primVals: SelectOne should have returned sql.ErrNoRows")
}
}
}
func TestSelectAlias(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
p1 := &IdCreatedExternal{IdCreated: IdCreated{Id: 1, Created: 3}, External: 2}
// Insert using embedded IdCreated, which reflects the structure of the table
_insert(dbmap, &p1.IdCreated)
// Select into IdCreatedExternal type, which includes some fields not present
// in id_created_test
var p2 IdCreatedExternal
err := dbmap.SelectOne(&p2, "select * from id_created_test where Id=1")
if err != nil {
t.Error(err)
}
if p2.Id != 1 || p2.Created != 3 || p2.External != 0 {
t.Error("Expected ignored field defaults to not set")
}
// Prove that we can supply an aliased value in the select, and that it will
// automatically map to IdCreatedExternal.External
err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM id_created_test")
if err != nil {
t.Error(err)
}
if p2.External != 1 {
t.Error("Expected select as can map to exported field.")
}
var rows *sql.Rows
var cols []string
rows, err = dbmap.Db.Query("SELECT * FROM id_created_test")
cols, err = rows.Columns()
if err != nil || len(cols) != 2 {
t.Error("Expected ignored column not created")
}
}
func TestMysqlPanicIfDialectNotInitialized(t *testing.T) {
_, driver := dialectAndDriver()
// this test only applies to MySQL
if os.Getenv("GORP_TEST_DIALECT") != "mysql" {
return
}
// The expected behaviour is to catch a panic.
// Here is the deferred function which will check if a panic has indeed occurred :
defer func() {
r := recover()
if r == nil {
t.Error("db.CreateTables() should panic if db is initialized with an incorrect MySQLDialect")
}
}()
// invalid MySQLDialect : does not contain Engine or Encoding specification
dialect := MySQLDialect{}
db := &DbMap{Db: connect(driver), Dialect: dialect}
db.AddTableWithName(Invoice{}, "invoice")
// the following call should panic :
db.CreateTables()
}
func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId")
err := dbmap.DropTablesIfExists()
if err != nil {
t.Error("Drop tables failed")
}
err = dbmap.CreateTablesIfNotExists()
if err != nil {
t.Error("Create tables failed")
}
err = dbmap.TruncateTables()
if err != nil {
t.Error("Truncate tables failed")
}
sct := SingleColumnTable{
SomeId: "A Unique Id String",
}
count, err := dbmap.Update(&sct)
if err != nil {
t.Error(err)
}
if count != 0 {
t.Errorf("Expected 0 updated rows, got %d", count)
}
}
func TestPrepare(t *testing.T) {
dbmap := initDbMap()
defer dropAndClose(dbmap)
inv1 := &Invoice{0, 100, 200, "prepare-foo", 0, false}
inv2 := &Invoice{0, 100, 200, "prepare-bar", 0, false}
_insert(dbmap, inv1, inv2)
bindVar0 := dbmap.Dialect.BindVar(0)
bindVar1 := dbmap.Dialect.BindVar(1)
stmt, err := dbmap.Prepare(fmt.Sprintf("UPDATE invoice_test SET Memo=%s WHERE Id=%s", bindVar0, bindVar1))
if err != nil {
t.Error(err)
}
defer stmt.Close()
_, err = stmt.Exec("prepare-baz", inv1.Id)
if err != nil {
t.Error(err)
}
err = dbmap.SelectOne(inv1, "SELECT * from invoice_test WHERE Memo='prepare-baz'")
if err != nil {
t.Error(err)
}
trans, err := dbmap.Begin()
if err != nil {
t.Error(err)
}
transStmt, err := trans.Prepare(fmt.Sprintf("UPDATE invoice_test SET IsPaid=%s WHERE Id=%s", bindVar0, bindVar1))
if err != nil {
t.Error(err)
}
defer transStmt.Close()
_, err = transStmt.Exec(true, inv2.Id)
if err != nil {
t.Error(err)
}
err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true)
if err == nil || err != sql.ErrNoRows {
t.Error("SelectOne should have returned an sql.ErrNoRows")
}
err = trans.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true)
if err != nil {
t.Error(err)
}
err = trans.Commit()
if err != nil {
t.Error(err)
}
err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE IsPaid=%s", bindVar0), true)
if err != nil {
t.Error(err)
}
}
func BenchmarkNativeCrud(b *testing.B) {
b.StopTimer()
dbmap := initDbMapBench()
defer dropAndClose(dbmap)
b.StartTimer()
insert := "insert into invoice_test (Created, Updated, Memo, PersonId) values (?, ?, ?, ?)"
sel := "select Id, Created, Updated, Memo, PersonId from invoice_test where Id=?"
update := "update invoice_test set Created=?, Updated=?, Memo=?, PersonId=? where Id=?"
delete := "delete from invoice_test where Id=?"
inv := &Invoice{0, 100, 200, "my memo", 0, false}
for i := 0; i < b.N; i++ {
res, err := dbmap.Db.Exec(insert, inv.Created, inv.Updated,
inv.Memo, inv.PersonId)
if err != nil {
panic(err)
}
newid, err := res.LastInsertId()
if err != nil {
panic(err)
}
inv.Id = newid
row := dbmap.Db.QueryRow(sel, inv.Id)
err = row.Scan(&inv.Id, &inv.Created, &inv.Updated, &inv.Memo,
&inv.PersonId)
if err != nil {
panic(err)
}
inv.Created = 1000
inv.Updated = 2000
inv.Memo = "my memo 2"
inv.PersonId = 3000
_, err = dbmap.Db.Exec(update, inv.Created, inv.Updated, inv.Memo,
inv.PersonId, inv.Id)
if err != nil {
panic(err)
}
_, err = dbmap.Db.Exec(delete, inv.Id)
if err != nil {
panic(err)
}
}
}
func BenchmarkGorpCrud(b *testing.B) {
b.StopTimer()
dbmap := initDbMapBench()
defer dropAndClose(dbmap)
b.StartTimer()
inv := &Invoice{0, 100, 200, "my memo", 0, true}
for i := 0; i < b.N; i++ {
err := dbmap.Insert(inv)
if err != nil {
panic(err)
}
obj, err := dbmap.Get(Invoice{}, inv.Id)
if err != nil {
panic(err)
}
inv2, ok := obj.(*Invoice)
if !ok {
panic(fmt.Sprintf("expected *Invoice, got: %v", obj))
}
inv2.Created = 1000
inv2.Updated = 2000
inv2.Memo = "my memo 2"
inv2.PersonId = 3000
_, err = dbmap.Update(inv2)
if err != nil {
panic(err)
}
_, err = dbmap.Delete(inv2)
if err != nil {
panic(err)
}
}
}
func initDbMapBench() *DbMap {
dbmap := newDbMap()
dbmap.Db.Exec("drop table if exists invoice_test")
dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
return dbmap
}
func initDbMap() *DbMap {
dbmap := newDbMap()
dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id")
dbmap.AddTableWithName(InvoiceTag{}, "invoice_tag_test").SetKeys(true, "myid")
dbmap.AddTableWithName(AliasTransientField{}, "alias_trans_field_test").SetKeys(true, "id")
dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id")
dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id")
dbmap.AddTableWithName(IdCreated{}, "id_created_test").SetKeys(true, "Id")
dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithEmbeddedAutoincr{}, "embedded_autoincr_test").SetKeys(true, "Id")
dbmap.AddTableWithName(WithTime{}, "time_test").SetKeys(true, "Id")
dbmap.TypeConverter = testTypeConverter{}
err := dbmap.DropTablesIfExists()
if err != nil {
panic(err)
}
err = dbmap.CreateTables()
if err != nil {
panic(err)
}
// See #146 and TestSelectAlias - this type is mapped to the same
// table as IdCreated, but includes an extra field that isn't in the table
dbmap.AddTableWithName(IdCreatedExternal{}, "id_created_test").SetKeys(true, "Id")
return dbmap
}
func initDbMapNulls() *DbMap {
dbmap := newDbMap()
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
return dbmap
}
func newDbMap() *DbMap {
dialect, driver := dialectAndDriver()
dbmap := &DbMap{Db: connect(driver), Dialect: dialect}
dbmap.TraceOn("", log.New(os.Stdout, "gorptest: ", log.Lmicroseconds))
return dbmap
}
func dropAndClose(dbmap *DbMap) {
dbmap.DropTablesIfExists()
dbmap.Db.Close()
}
func connect(driver string) *sql.DB {
dsn := os.Getenv("GORP_TEST_DSN")
if dsn == "" {
panic("GORP_TEST_DSN env variable is not set. Please see README.md")
}
db, err := sql.Open(driver, dsn)
if err != nil {
panic("Error connecting to db: " + err.Error())
}
return db
}
func dialectAndDriver() (Dialect, string) {
switch os.Getenv("GORP_TEST_DIALECT") {
case "mysql":
return MySQLDialect{"InnoDB", "UTF8"}, "mymysql"
case "gomysql":
return MySQLDialect{"InnoDB", "UTF8"}, "mysql"
case "postgres":
return PostgresDialect{}, "postgres"
case "sqlite":
return SqliteDialect{}, "sqlite3"
}
panic("GORP_TEST_DIALECT env variable is not set or is invalid. Please see README.md")
}
func _insert(dbmap *DbMap, list ...interface{}) {
err := dbmap.Insert(list...)
if err != nil {
panic(err)
}
}
func _update(dbmap *DbMap, list ...interface{}) int64 {
count, err := dbmap.Update(list...)
if err != nil {
panic(err)
}
return count
}
func _del(dbmap *DbMap, list ...interface{}) int64 {
count, err := dbmap.Delete(list...)
if err != nil {
panic(err)
}
return count
}
func _get(dbmap *DbMap, i interface{}, keys ...interface{}) interface{} {
obj, err := dbmap.Get(i, keys...)
if err != nil {
panic(err)
}
return obj
}
func selectInt(dbmap *DbMap, query string, args ...interface{}) int64 {
i64, err := SelectInt(dbmap, query, args...)
if err != nil {
panic(err)
}
return i64
}
func selectNullInt(dbmap *DbMap, query string, args ...interface{}) sql.NullInt64 {
i64, err := SelectNullInt(dbmap, query, args...)
if err != nil {
panic(err)
}
return i64
}
func selectFloat(dbmap *DbMap, query string, args ...interface{}) float64 {
f64, err := SelectFloat(dbmap, query, args...)
if err != nil {
panic(err)
}
return f64
}
func selectNullFloat(dbmap *DbMap, query string, args ...interface{}) sql.NullFloat64 {
f64, err := SelectNullFloat(dbmap, query, args...)
if err != nil {
panic(err)
}
return f64
}
func selectStr(dbmap *DbMap, query string, args ...interface{}) string {
s, err := SelectStr(dbmap, query, args...)
if err != nil {
panic(err)
}
return s
}
func selectNullStr(dbmap *DbMap, query string, args ...interface{}) sql.NullString {
s, err := SelectNullStr(dbmap, query, args...)
if err != nil {
panic(err)
}
return s
}
func _rawexec(dbmap *DbMap, query string, args ...interface{}) sql.Result {
res, err := dbmap.Exec(query, args...)
if err != nil {
panic(err)
}
return res
}
func _rawselect(dbmap *DbMap, i interface{}, query string, args ...interface{}) []interface{} {
list, err := dbmap.Select(i, query, args...)
if err != nil {
panic(err)
}
return list
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment