Commit 3496093b by Bogdan Ungureanu

GPSD First release

parents
all:
GOPATH=$(shell pwd) go build
\ No newline at end of file
[server]
port=8080
httpdocs=/home/bogdan/Projects/gpsd/public_html/
[database]
hostname=46.102.175.70
database=gpsd
username=gpsd
password=IzpZsFjCxeiKWndOu90w
package main
import (
"bufio"
"os"
)
func parse_logfile(file string) error {
fd, err := os.Open(file)
if err != nil {
return err
}
defer fd.Close()
scanner := bufio.NewScanner(fd)
scanner.Split(bufio.ScanLines)
adapter := &TK103{}
adapter.onPing = parseonPing
for scanner.Scan() {
line := scanner.Text()
line = line[44 : len(line)-1]
if line[:7] == "Message" {
data := line[9:]
adapter.Handle(data)
}
}
return nil
}
func parseonPing(devstring string, event EventData) {
if event.Active {
if event.Latitude > 0 && event.Longitude > 0 {
logdata := &Log{
Time: event.Time,
Did: 1,
Active: true,
Latitude: event.Latitude,
Longitude: event.Longitude,
Speed: int(event.Speed),
Angle: event.Angle,
}
server.db.Create(&logdata)
}
}
}
body {
padding-top: 70px;
}
#map-canvas {
border: 1px solid #000;
width: 100%;
height: 1024px;
}
.tmap {
width: 640px;
height: 480px;
margin: 0;
padding: 0;
}
\ No newline at end of file
{{ define "layout" }}
<!DOCTYPE html>
<head>
<title>{{ .Title }}</title>
{{range $css_link := .CSS}}
<link rel="stylesheet" href="{{ $css_link }}">{{ end }}
{{range $js_link := .JS}}
<script type="text/javascript" src="{{ $js_link }}"></script>{{ end }}
</head>
<body>
{{ template "content" .Content }}
</body>
</html>
{{ end }}
{{ define "content" }} {{ end}}
\ No newline at end of file
{{ define "content" }}
<!-- Navigation -->
<nav class="navbar navbar-inverse navbar-fixed-top" role="navigation">
<div class="container">
<!-- Brand and toggle get grouped for better mobile display -->
<div class="navbar-header">
<button type="button" class="navbar-toggle" data-toggle="collapse" data-target="#bs-example-navbar-collapse-1">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a class="navbar-brand" href="#">Start Bootstrap</a>
</div>
<!-- Collect the nav links, forms, and other content for toggling -->
<div class="collapse navbar-collapse" id="bs-example-navbar-collapse-1">
<ul class="nav navbar-nav">
<li>
<a href="#">About</a>
</li>
<li>
<a href="#">Services</a>
</li>
<li>
<a href="/logout">Logout</a>
</li>
</ul>
</div>
<!-- /.navbar-collapse -->
</div>
<!-- /.container -->
</nav>
<div class="container">
<div class="row">
<div id="map-canvas"></div>
</div>
<div class="row">
<select name="date" id="event-map">
{{range $date := .LogEvents}} <option value="{{ $date }}">{{ $date }}</option> {{end}}
</select>
<button type="button" class="btn btn-success" onclick="show_track(1)">Track</button>
</div>
<script type="text/javascript">
var map;
var poly;
var bounds ;
function init() {
var mapCanvas = document.getElementById('map-canvas')
var mapOptions = {
center: new google.maps.LatLng(44.4333333, 26.1),
zoom: 7,
mapTypeId: google.maps.MapTypeId.ROADMAP,
};
map = new google.maps.Map(mapCanvas,mapOptions);
var polyOptions = {
strokeColor: '#000000',
strokeOpacity: 1.0,
strokeWeight: 3
};
poly = new google.maps.Polyline(polyOptions);
bounds = new google.maps.LatLngBounds();
poly.setMap(map);
//track(1);
/*
var flightPlanCoordinates = [
new google.maps.LatLng(37.772323, -122.214897),
new google.maps.LatLng(21.291982, -157.821856),
new google.maps.LatLng(-18.142599, 178.431),
new google.maps.LatLng(-27.46758, 153.027892)
];
var flightPath = new google.maps.Polyline({
path: flightPlanCoordinates,
geodesic: true,
strokeColor: '#FF0000',
strokeOpacity: 1.0,
strokeWeight: 2
});
flightPath.setMap(map);
}
*/
}
function show_track(id) {
var event = document.getElementById('event-map')
//var x = document.getElementById("mySelect").selectedIndex;
//var y = document.getElementById("mySelect").options;
//alert("Index: " + y[x].index + " is " + y[x].text);
track(id,event.options[event.selectedIndex].value)
}
function track (id,events) {
$.ajax({
url: "/track/" + id + "/" + events,
dataType: "json",
// if the data is succesfully found
success: function (data) {
var path = poly.getPath();
// Reset poly and bounds
for (var i = 0; i < path.length; i++) {
path.pop();
}
bounds = new google.maps.LatLngBounds();
for (i = 0, len = data.length; i < len; ++i) {
//console.log('got' + data[i].Latitude + data[i].Longitude);
var LatLng = new google.maps.LatLng(data[i].Latitude , data[i].Longitude);
path.push(LatLng);
bounds.extend(LatLng);
/*
var marker = new google.maps.Marker({
position: LatLng,
title: '#' + data[i].Id + "-"+ path.getLength() + " speed "+ data[i].Speed + "km/h",
map: map
});
*/
}
map.fitBounds(bounds);
}
})
}
google.maps.event.addDomListener(window, 'load', init);
</script>
<!-- /.container -->
{{ end}}
{{ define "content" }}
<div class="container">
<div id="loginbox" style="margin-top:50px;" class="mainbox col-md-6 col-md-offset-3 col-sm-8 col-sm-offset-2">
<div class="panel panel-info" >
<div class="panel-heading">
<div class="panel-title">Sign In</div>
<div style="float:right; font-size: 80%; position: relative; top:-10px"><a href="#">Forgot password?</a></div>
</div>
<div style="padding-top:30px" class="panel-body" >
<div style="display:none" id="login-alert" class="alert alert-danger col-sm-12"></div>
<form id="loginform" class="form-horizontal" role="form" method="POST" action="/login">
<div style="margin-bottom: 25px" class="input-group">
<span class="input-group-addon"><i class="glyphicon glyphicon-user"></i></span>
<input id="login-username" type="text" class="form-control" name="username" value="" placeholder="username or email">
</div>
<div style="margin-bottom: 25px" class="input-group">
<span class="input-group-addon"><i class="glyphicon glyphicon-lock"></i></span>
<input id="login-password" type="password" class="form-control" name="password" placeholder="password">
</div>
<div class="input-group">
<div class="checkbox">
<label>
<input id="login-remember" type="checkbox" name="remember" value="1"> Remember me
</label>
</div>
</div>
<div style="margin-top:10px" class="form-group">
<div class="col-sm-12 controls">
<button class="btn btn-success" type="submit">Login</button>
</div>
</div>
</form>
</div>
</div>
</div>
</div>
{{ end }}
\ No newline at end of file
package main
import (
_ "database/driver/mysql"
"database/orm/gorm"
"flag"
"net"
"net/httpd"
"os"
"os/daemon"
log "os/logger"
"sync"
"time"
)
var (
server = &GPSServer{
c: &GPSServerConfig{ // Default config
cfgfile: "/etc/gpsd.conf",
db_host: "localhost",
db_port: 3306,
db_db: "gpsd",
db_user: "root",
db_pass: "",
http_root: "public_html/",
http_port: 8080,
},
clients: make(map[string]*TDeviceInfo),
}
html = NewHTML()
sidname = "sid"
sidexpire = 1 * time.Hour
logfd *os.File
)
type (
// GPS Server - handle client Connections
GPSServer struct {
clients map[string]*TDeviceInfo
c *GPSServerConfig // server_config.go
parselog string
logfile string
db gorm.DB
mutex sync.Mutex
wait sync.WaitGroup
shutdown bool
}
// GPS Logger - handle data logging
EventData struct {
Active bool
AlarmCode int
AlarmMsg string
Time time.Time
Latitude float64
Longitude float64
Speed float64
Miles int64
Angle float64
Power bool // GPS Power
Acc bool // Motor Ignition
}
// Some Handlers
LoginHandler func(string) bool
EventHandler func(string, EventData)
// Tracking device - GPS
TDeviceInfo struct {
Id string
IpAddr string
LastSeen *time.Time // Last seen online
adapter DeviceAdapter
}
// SQL database Users
User struct {
Id int32 `sql:"type:int AUTO_INCREMENT",gorm:"primary_key"`
Name string `sql:"type:varchar(30)"`
Password string `sql:"type:varchar(60)"`
}
// SQL database Devices
Device struct {
Id int64 `sql:"type:int AUTO_INCREMENT;",gorm:"primary_key"`
Uid int64 `sql:"type:int"`
Imei string `sql:"type:varchar(30)"`
}
Presence struct {
id int64
time int64
}
// SQL database Logs
Log struct {
Id int64 `sql:"type:int AUTO_INCREMENT;",gorm:"primary_key"`
Time time.Time `sql:"DEFAULT:NULL"`
Did int64 `sql:"type:int;not null;"`
Active bool `sql:"type:boolean"`
Latitude float64 `sql:"type:float(9,6);not null;"`
Longitude float64 `sql:"type:float(9,6);not null;"`
Speed int `sql:"type:int;not null"`
Angle float64 `sql:"type:float(9,6);not null;"`
}
Session struct {
user string
}
)
func init() {
log.SetLevel(log.DebugLevel)
daemon.Setup(&daemon.Config{
//Uid: 1000,
//Gid: 1000,
//Chroot: "/home/gspd/",
Pidfile: "/var/run/gpsd.pid",
IntHandler: statsHandler,
//TermHandler: signals,
//QuitHandler: signals,
//Usr1Handler: signals,
//Usr2Handler: signals,
HupHandler: hupHandler,
})
flag.StringVar(&server.parselog, "p", "", "Description")
flag.Parse()
if server.parselog == "" {
err := daemon.Background()
if err != nil {
log.Error(err)
}
logfd, err := log.OpenLog("/var/log/gpsd.log")
if err != nil {
log.Errorf("Open log file error: %s", err)
return
}
log.SetOutput(logfd)
}
}
// Process SIGHUP Handler
func hupHandler(sig int) bool {
// Cancel standard behavior
return true
}
// Process SIGINT Handler
func statsHandler(sig int) bool {
log.Info("SIGINT received")
server.Shutdown()
// Cancel standard behavior
server.wait.Wait()
return false
}
func main() {
err := server.c.LoadConfig()
if err != nil {
log.Fatalf("LoadConfig error : %s", err)
}
err = server.InitDatabase()
if err != nil {
log.Fatalf("Database config error : %s", err)
}
if server.parselog != "" {
err := parse_logfile(server.parselog)
if err != nil {
log.Errorf("parse_logfile error: %s", err)
}
return
}
// Start Listening on Ports
err = server.ListenTCP(":9090")
if err != nil {
log.Fatalf("ListenTCP error : %s", err)
}
router, err := gpsd_router()
if err != nil {
log.Error(err)
return
}
err = httpd.ListenAndServe(server.c.httpAddr(), router)
if err != nil {
log.Error(err)
return
}
}
// Server Listen TCP
func (s *GPSServer) ListenTCP(addr string) error {
l, err := net.Listen("tcp", addr)
if err != nil {
log.Error("Error listening: %s", err)
return err
}
log.Info("GPS service listening on " + addr)
go func(l net.Listener, addr string) {
s.wait.Add(1)
for {
// Listen for an incoming connection.
l.(*net.TCPListener).SetDeadline(time.Now().Add(100 * time.Millisecond))
conn, err := l.Accept()
// Shutdown request ?
if server.shutdown == true {
l.Close()
log.Infof("Shutdown listener socket %s", addr)
s.wait.Done()
return
}
if err != nil {
netErr, ok := err.(net.Error)
if ok && netErr.Timeout() && netErr.Temporary() {
continue
}
return
}
go server.handleTCP(conn)
}
}(l, addr)
return nil
}
// Handle New TCP connection
func (s *GPSServer) handleTCP(c net.Conn) {
d := &TDeviceInfo{}
for {
buf := make([]byte, 1024)
mlen, err := c.Read(buf)
if err != nil {
log.Errorf("Error reading: %s", err)
return
}
rcvd := string(buf)[:mlen]
if d.Id == "" {
log.Infof("New Connection from %s", c.RemoteAddr().String())
d, err = GetDeviceAdapter(rcvd, c)
if err != nil {
c.Close()
log.Errorf("No adapter found for device !")
return
}
// Give ServerHandlers to Device
d.onLogin(s.onLogin)
d.onAlarm(s.onAlarm)
d.onPing(s.onPing)
// Save to clients list
s.mutex.Lock()
s.clients[d.Id] = d
s.mutex.Unlock()
}
// Make sure addapter can Handle message
err = d.Handle(rcvd)
if err != nil {
log.Errorf("Unable to parse message %s", rcvd)
c.Close()
return
}
//
}
}
func (s *GPSServer) sqlDevice(device string) *Device {
devsql := &Device{}
server.db.Where("imei = ? ", device).First(&devsql)
if devsql.Id == 0 {
devsql.Imei = device
devsql.Uid = 1 // Default
server.db.Create(&devsql)
server.db.Where("imei = ? ", device).First(&devsql)
}
return devsql
}
// On device Login Request
func (s *GPSServer) onLogin(device string) bool {
return true
}
// On device Allarm Request
func (s *GPSServer) onAlarm(device string, event EventData) {
}
// Normal GPS Data
func (s *GPSServer) onPing(device string, event EventData) {
devsql := s.sqlDevice(device)
if event.Active {
// Send log to Database
logdata := &Log{
Time: event.Time,
Did: devsql.Id,
Active: true,
Latitude: event.Latitude,
Longitude: event.Longitude,
Speed: int(event.Speed),
Angle: event.Angle,
}
server.db.Create(&logdata)
}
}
// Shutdown Server
func (s *GPSServer) Shutdown() {
log.Info("Shutdown request via SIGTERM")
s.shutdown = true
s.wait.Wait()
}
// Initialize Database
func (s *GPSServer) InitDatabase() error {
db, err := gorm.Open("mysql", s.c.DbString())
if err != nil {
return err
}
s.db = db
db.AutoMigrate(&User{}, &Device{}, &Log{})
// Multiple column index
//db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age")
return nil
}
package main
import (
"net"
log "os/logger"
"strings"
)
type (
//
// Interface for GPS adapters
//
DeviceAdapter interface {
Handle(string) error
Send(string, string)
OnLogin(LoginHandler)
OnAlarm(EventHandler)
OnPing(EventHandler)
}
// Implement DeviceAdapter interface
TK103 struct {
DeviceId string
c net.Conn
onLogin LoginHandler
onAlarm EventHandler
onPing EventHandler
}
)
// Handles incoming requests.
func (a *TK103) Handle(rcvd string) error {
cmd_start := strings.Index(rcvd, "B")
if cmd_start > 13 {
log.Error("DeviceId too long")
return nil
}
a.DeviceId = rcvd[1:cmd_start]
cmd := rcvd[cmd_start : cmd_start+4]
switch cmd {
case "BP05": // Login Request
if a.onLogin != nil {
if a.onLogin(a.DeviceId) == true {
a.Send("AP05", "")
}
return nil
}
// Auto accept login
a.Send("AP05", "")
case "BP00": // Handshake Request
a.Send("AP01", "HSO")
case "BR00": // Ping Request
res := a._parse(rcvd[cmd_start+4 : len(rcvd)-1])
if a.onPing != nil {
a.onPing(a.DeviceId, *res)
}
case "BO01": // Alarm Request
alrm := rcvd[cmd_start+4 : cmd_start+5]
res := a._parse(rcvd[cmd_start+5 : len(rcvd)-1])
code := -1
message := ""
switch alrm {
case "0":
code = 0
message = "Vehicle Power Off"
case "1":
code = 1
message = "The vehicle suffers an acciden"
case "2":
code = 2
message = "Driver sends a S.O.S."
case "3":
code = 3
message = "The alarm of the vehicle is activated"
case "4":
code = 4
message = "Vehicle is below the min speed setted"
case "5":
code = 5
message = "Vehicle is over the max speed setted"
case "6":
code = 6
message = "Out of geo fence"
}
res.AlarmCode = code
res.AlarmMsg = message
if a.onAlarm != nil {
a.onAlarm(a.DeviceId, *res)
}
a.Send("AS01", alrm)
default:
log.Infof("Unknown message from %s", a.DeviceId)
log.Printf("Message :%s", rcvd)
}
return nil
}
// Setup Login Handler
func (a *TK103) OnLogin(handler LoginHandler) {
a.onLogin = handler
}
// Setup Alarm Hanlder
func (a *TK103) OnAlarm(handler EventHandler) {
a.onAlarm = handler
}
// Setup Ping Handler
func (a *TK103) OnPing(handler EventHandler) {
a.onPing = handler
}
func (tk *TK103) _parse(data string) *EventData {
yy := data[0:2] // YYMMDD - Year
mm := data[2:4] // YYMMDD - Month
dd := data[4:6] // YYMMDD - Month
active := data[6:7] // The availability of GPS data
latstr := data[7:9] // Latitude - Minute
latmstr := data[9:16] // latitude - Seconds
latind := data[16:17] // Latitude indicator "N" or "S"
longstr := data[17:20] // Longitue - Minute
longmstr := data[20:27] // Longitude - Seconds
longind := data[27:28] // Longture indicator "E" or "V"
speedstr := data[28:33] // The unit is km/h
timeh := data[33:35] //HHMMSS - hours
timem := data[35:37] //HHMMSS - minutes
times := data[37:39] //HHMMSS - seconds
anglestr := data[39:44] // Orientation
//io := data[44:52] // IO State
//mil := data[52:53] // Milepost
//milstr := data[53:61] // Km
// Process Latitude
latitude, err := ToDecimalDegrees(latstr, latmstr)
if err != nil {
log.Errorf("Val: %s'%s,Err: %s", latstr, latmstr, err)
}
if latind == "S" {
latitude = -latitude
}
// Process Longitude
longitude, err := ToDecimalDegrees(longstr, longmstr)
if err != nil {
log.Errorf("Val: %s'%s,Err: %s", longstr, longmstr, err)
}
if longind == "W" {
longitude = -longitude
}
// Process Speed
speed, err := ToDecimal(speedstr)
if err != nil {
log.Errorf("Val: %s,Err: %s", speedstr, err)
}
// Process Time
time, err := FormatTimeShort(yy, mm, dd, timeh, timem, times)
if err != nil {
log.Errorf("Val: 20%s/%s/%s %s:%s:%s, Err: %s",
yy, mm, dd, timeh, timem, times, err)
}
// Process Orientation
angle, err := ToDecimal(anglestr)
if err != nil {
log.Errorf("Val: %s,Err: %s", anglestr, err)
}
return &EventData{
Active: (active == "A"),
Time: time,
Latitude: latitude,
Longitude: longitude,
Speed: speed,
Angle: angle,
}
}
func (tk *TK103) Send(cmd, data string) {
message := "(" + strings.Join([]string{tk.DeviceId, cmd, data}, "") + ")"
_, err := tk.c.Write([]byte(message))
if err != nil {
log.Error("Failed to send cmd " + cmd + " to " + tk.DeviceId)
}
}
package main
import (
"fmt"
"util/cfg"
)
//
// Server Configuration
//
type GPSServerConfig struct { // GPS Server Config
cfgfile string
db_user string
db_pass string
db_host string
db_port int
db_db string
http_addr string
http_port int
http_root string
}
// Load and parse config File
func (c *GPSServerConfig) LoadConfig() error {
conf, err := cfg.ReadFile(c.cfgfile)
if err != nil {
return err
}
if conf.HasOption("database", "hostname") {
c.db_host, err = conf.GetString("database", "hostname")
}
if conf.HasOption("database", "database") {
c.db_db, err = conf.GetString("database", "database")
}
if conf.HasOption("database", "username") {
c.db_user, err = conf.GetString("database", "username")
}
if conf.HasOption("database", "password") {
c.db_pass, err = conf.GetString("database", "password")
}
if conf.HasOption("server", "address") {
c.http_addr, err = conf.GetString("server", "address")
}
if conf.HasOption("server", "httpdocs") {
c.http_root, err = conf.GetString("server", "httpdocs")
}
return nil
}
func (c *GPSServerConfig) httpAddr() string {
return fmt.Sprintf("%s:%d", c.http_addr, c.http_port)
}
func (c *GPSServerConfig) DbString() string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&parseTime=true",
c.db_user,
c.db_pass,
c.db_host,
c.db_port,
c.db_db,
)
}
package main
import (
"errors"
"net"
)
// Device Helpers
func (d *TDeviceInfo) Handle(msg string) error {
return d.adapter.Handle(msg)
}
func (d *TDeviceInfo) onLogin(handler LoginHandler) {
d.adapter.OnLogin(handler)
}
func (d *TDeviceInfo) onAlarm(handler EventHandler) {
d.adapter.OnAlarm(handler)
}
func (d *TDeviceInfo) onPing(handler EventHandler) {
d.adapter.OnPing(handler)
}
// Define all Known Parsers
func GetDeviceAdapter(msg string, conn net.Conn) (*TDeviceInfo, error) {
// TK102
if msg[:1] == "(" && msg[len(msg)-1:] == ")" {
if len(msg) > 13 && msg[13:14] == "B" {
device := &TDeviceInfo{
Id: msg[1:13],
IpAddr: conn.RemoteAddr().String(),
adapter: &TK103{c: conn},
}
return device, nil
}
}
return nil, errors.New("GetDeviceAdapter: No Adapter found")
}
package main
import (
"net/httpd"
"net/httpd/session"
log "os/logger"
"time"
)
func doLogin(username, password string) (*Session, bool) {
user := &User{}
server.db.Where("name = ? and password = ?", username, password).First(&user)
log.Printf("db : %x ", user)
if user.Id == 0 {
return nil, false
} else {
return &Session{
user: user.Name,
}, true
}
}
func AuthHandler(next httpd.Handler) httpd.Handler {
return httpd.HandlerFunc(func(c *httpd.Context) {
log.Printf("USER auth handler requested !")
s, err := c.Request.Cookie(sidname)
if err != nil {
//Setup new Session
ses := session.New()
session.Set(ses)
c.SetCookie(sidname, ses.Id(), sidexpire)
c.Redirect("/login", 302)
return
}
ses := session.Get(s.Value)
if ses == nil {
ses := session.New()
session.Set(ses)
c.SetCookie(sidname, ses.Id(), sidexpire)
c.Redirect("/login", 302)
return
}
// Handle login
if c.RequestURI == "/login" && c.RequestMethod == "POST" {
usr := c.Request.FormValue("username")
pwd := c.Request.FormValue("password")
sdata, login := doLogin(usr, pwd)
if login == false {
c.Redirect("/login", 302)
return
}
// Store the session
ses.Data = sdata
session.Set(ses)
c.Redirect("/", 302)
return
}
if ses.Data == nil && c.RequestURI != "/login" {
c.Redirect("/login", 302)
return
}
c.Session = ses
next.Handle(c)
})
}
func indexPage(c *httpd.Context) {
data := struct {
Test string
LogEvents []string
ProxyTotal int64
}{}
//SELECT DATE_FORMAT(`time`, '%Y-%m-%d') FROM `logs` WHERE 1 GROUP BY DATE_FORMAT(`time`,'%Y%m%d')
rows, err := server.db.Raw("SELECT DATE_FORMAT(`time`, '%Y-%m-%d') FROM `logs` WHERE 1 GROUP BY DATE_FORMAT( `time` , '%Y%m%d' ) ").Rows()
if err != nil {
log.Errorf("%s", err)
return
}
defer rows.Close()
for rows.Next() {
tm := ""
rows.Scan(&tm)
data.LogEvents = append(data.LogEvents, tm)
}
// Setup new Page
page := NewPage("Real Time GPS Tracker", data)
page.AddJS("https://maps.googleapis.com/maps/api/js?key=AIzaSyBn7yp_UcuCe5U1IZHKcfyJ5wJTMS8YIIM")
page.AddJS("https://ajax.googleapis.com/ajax/libs/jquery/2.1.4/jquery.min.js")
html.Execute("index",
c.Response,
page,
)
}
func loginPage(c *httpd.Context) {
err := html.Execute(
"login",
c.Response,
NewPage("Login", nil),
)
if err != nil {
log.Error(err)
}
}
// Track device
func trackHandler(c *httpd.Context) {
var logs []Log
device := c.Param("device")
date, err := ParseDate(c.Param("date"))
if err != nil {
log.Printf("trackHandler error parsing time %s", err)
c.Error("Invalid date", 400)
return
}
log := &Log{Time: date}
server.db.Order("time asc").Where("did = ? AND time > ? AND time < ? ", device, log.Time, log.Time.Add(24*time.Hour)).Find(&logs)
c.JSON(200, logs)
}
// Logout
func logoutPage(c *httpd.Context) {
sess := c.Session.(*session.Session)
sess.Data = nil
session.Set(sess)
log.Printf("logging out from sid %s", sess.Id())
c.Redirect("/login", 302)
}
// GPSD ROUTER
func gpsd_router() (*httpd.Router, error) {
http := httpd.NewRouter()
http.BeforeHandle(AuthHandler)
http.Get("/", indexPage)
http.Get("/track/:device/:date", trackHandler)
http.Get("/login", loginPage)
http.Get("/logout", logoutPage)
// Assets Handlers
assets := http.Subrouter("/assets")
/*
pwd, err := os.Getwd()
if err != nil {
//fmt.Println(err)
os.Exit(1)
}
*/
assets.ServeFiles("/*filepath", server.c.http_root+"assets/")
return http, nil
}
package main
import (
// log "os/logger"
"strconv"
"strings"
"time"
)
// Speed to Deciamal
// speed = float(speed)*1.852
//strconv.FormatFloat(input_num, 'f', 6, 64)
// Degree to Decimal
func ToDecimalDegrees(minute, seconds string) (float64, error) {
min, err := strconv.ParseFloat(minute, 64)
if err != nil {
return float64(0), err
}
sec, err := strconv.ParseFloat(seconds, 64)
if err != nil {
return min, err
}
// Format to 6 char precision
format := strconv.FormatFloat(min+sec/60, 'f', 6, 64)
decimal, err := strconv.ParseFloat(format, 64)
if err != nil {
return min, err
}
return decimal, nil
}
// Convert To Decimal
func ToDecimal(speed string) (float64, error) {
s, err := strconv.ParseFloat(speed, 64)
if err != nil {
return float64(0), err
}
return s, nil
}
func ParseDate(date string) (time.Time, error) {
dstr := strings.Split(date, "-")
loc, err := time.LoadLocation("Europe/Bucharest")
if err != nil {
return time.Now(), err
}
year, err := strconv.ParseInt(dstr[0], 10, 64)
if err != nil {
//return , err
}
month, err := strconv.ParseInt(dstr[1], 10, 64)
if err != nil {
//return , err
}
day, err := strconv.ParseInt(dstr[2], 10, 64)
if err != nil {
//return , err
}
return time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, loc), nil
}
// Format Time
func FormatTimeShort(yy, mm, dd, h, m, s string) (time.Time, error) {
year, err := strconv.ParseInt("20"+yy, 10, 64)
if err != nil {
//return , err
}
month, err := strconv.ParseInt(mm, 10, 64)
if err != nil {
//return , err
}
day, err := strconv.ParseInt(dd, 10, 64)
if err != nil {
//return , err
}
hour, err := strconv.ParseInt(h, 10, 64)
if err != nil {
//return , err
}
minute, err := strconv.ParseInt(m, 10, 64)
if err != nil {
//return , err
}
second, err := strconv.ParseInt(s, 10, 64)
if err != nil {
//return , err
}
t := time.Date(
int(year),
time.Month(int(month)),
int(day),
int(hour),
int(minute),
int(second),
0,
time.UTC).Add(time.Hour)
return t, nil
}
## Version 1.1 (2013-11-02)
Changes:
- Go-MySQL-Driver now requires Go 1.1
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
- `byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
- Optimized the buffer for reading
- stmt.Query now caches column metadata
- New Logo
- Changed the copyright header to include all contributors
- Improved the LOAD INFILE documentation
- The driver struct is now exported to make the driver directly accessible
- Refactored the driver tests
- Added more benchmarks and moved all to a separate file
- Other small refactoring
New Features:
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values
- Fixed sign byte for positive TIME fields
## Version 1.0 (2013-05-14)
Initial Release
# Contributing Guidelines
## Reporting Issues
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed).
Please provide the following minimum information:
* Your Go-MySQL-Driver version (or git SHA)
* Your Go version (run `go version` in your console)
* A detailed issue description
* Error Log if present
* If possible, a short example
## Contributing Code
By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file.
Don't forget to add yourself to the AUTHORS file.
### Pull Requests Checklist
Please check the following points before submitting your pull request:
- [x] Code compiles correctly
- [x] Created tests, if possible
- [x] All tests pass
- [x] Extended the README / documentation, if necessary
- [x] Added yourself to the AUTHORS file
### Code Review
Everyone is invited to review and comment on pull requests.
If it looks fine to you, comment with "LGTM" (Looks good to me).
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM".
## Development Ideas
If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page.
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.
# Go-MySQL-Driver
A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) package
![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin")
**Version 1.1** (November 02, 2013)
---------------------------------------
* [Features](#features)
* [Requirements](#requirements)
* [Installation](#installation)
* [Usage](#usage)
* [DSN (Data Source Name)](#dsn-data-source-name)
* [Password](#password)
* [Protocol](#protocol)
* [Address](#address)
* [Parameters](#parameters)
* [Examples](#examples)
* [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
* [time.Time support](#timetime-support)
* [Unicode support](#unicode-support)
* [Testing / Development](#testing--development)
* [License](#license)
---------------------------------------
## Features
* Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance")
* Native Go implementation. No C-bindings, just pure Go
* Connections over TCP/IPv4, TCP/IPv6 or Unix domain sockets
* Automatic handling of broken connections
* Automatic Connection Pooling *(by database/sql package)*
* Supports queries larger than 16MB
* Full [`sql.RawBytes`](http://golang.org/pkg/database/sql/#RawBytes) support.
* Intelligent `LONG DATA` handling in prepared statements
* Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
* Optional `time.Time` parsing
## Requirements
* Go 1.1 or higher (use [v1.0](https://github.com/go-sql-driver/mysql/tags) for Go 1.0.x)
* MySQL (Version 4.1 or higher), MariaDB or Percona Server
---------------------------------------
## Installation
Simple install the package to your [$GOPATH](http://code.google.com/p/go-wiki/wiki/GOPATH "GOPATH") with the [go tool](http://golang.org/cmd/go/ "go command") from shell:
```bash
$ go get github.com/go-sql-driver/mysql
```
Make sure [Git is installed](http://git-scm.com/downloads) on your machine and in your system's `PATH`.
## Usage
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](http://golang.org/pkg/database/sql) API then.
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
```go
import "database/sql"
import _ "github.com/go-sql-driver/mysql"
db, err := sql.Open("mysql", "user:password@/dbname")
```
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
### DSN (Data Source Name)
The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
```
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...&paramN=valueN]
```
A DSN in its fullest form:
```
username:password@protocol(address)/dbname?param=value
```
Except for the databasename, all values are optional. So the minimal DSN is:
```
/dbname
```
If you do not want to preselect a database, leave `dbname` empty:
```
/
```
This has the same effect as an empty DSN string:
```
```
#### Password
Passwords can consist of any character. Escaping is **not** necessary.
#### Protocol
See [net.Dial](http://golang.org/pkg/net/#Dial) for more information which networks are available.
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
#### Address
For TCP and UDP networks, addresses have the form `host:port`.
If `host` is a literal IPv6 address, it must be enclosed in square brackets.
The functions [net.JoinHostPort](http://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](http://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
#### Parameters
*Parameters are case-sensitive!*
##### `allowAllFiles`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
##### `allowOldPasswords`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
##### `charset`
```
Type: string
Valid Values: <name>
Default: none
```
Sets the charset used for client-server interaction (`"SET NAMES <value>"`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
##### `clientFoundRows`
```
Type: bool
Valid Values: true, false
Default: false
```
`clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
##### `loc`
```
Type: string
Valid Values: <escaped name>
Default: UTC
```
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
##### `parseTime`
```
Type: bool
Valid Values: true, false
Default: false
```
`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
##### `strict`
```
Type: bool
Valid Values: true, false
Default: false
```
`strict=true` enables strict mode. MySQL warnings are treated as errors.
##### `timeout`
```
Type: decimal number
Default: OS default
```
*Driver* side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
##### `tls`
```
Type: bool / string
Valid Values: true, false, skip-verify, <name>
Default: false
```
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](http://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### System Variables
All other parameters are interpreted as system variables:
* `autocommit`: `"SET autocommit=<value>"`
* `time_zone`: `"SET time_zone=<value>"`
* [`tx_isolation`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `"SET tx_isolation=<value>"`
* `param`: `"SET <param>=<value>"`
*The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!*
#### Examples
```
user@unix(/path/to/socket)/dbname
```
```
root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local
```
```
user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true
```
TCP via IPv6:
```
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s
```
TCP on a remote host, e.g. Amazon RDS:
```
id:password@tcp(your-amazonaws-uri.com:3306)/dbname
```
TCP using default port (3306) on localhost:
```
user:password@tcp/dbname&charset=utf8mb4,utf8&sys_var=esc%40ped
```
Use the default protocol (tcp) and host (localhost:3306):
```
user:password@/dbname
```
No Database preselected:
```
user:password@/
```
### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
```go
import "github.com/go-sql-driver/mysql"
```
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
See the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
### `time.Time` support
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
Alternatively you can use the [`NullTime`](http://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`.
### Unicode support
Since version 1.1 Go-MySQL-Driver automatically uses the collation `utf8_general_ci` by default. Adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN is not necessary anymore in most cases.
See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support.
## Testing / Development
To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
See the [Contributing Guidelines](https://github.com/go-sql-driver/mysql/blob/master/CHANGELOG.md) for details.
---------------------------------------
## License
Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
Mozilla summarizes the license scope as follows:
> MPL: The copyleft applies to any files containing MPLed code.
That means:
* You can **use** the **unchanged** source code both in private as also commercial
* You **needn't publish** the source code of your library as long the files licensed under the MPL 2.0 are **unchanged**
* You **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0)
Please read the [MPL 2.0 FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html) if you have further questions regarding the license.
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"database/sql"
"strings"
"sync"
"sync/atomic"
"testing"
)
type TB testing.B
func (tb *TB) check(err error) {
if err != nil {
tb.Fatal(err)
}
}
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
tb.check(err)
return db
}
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
tb.check(err)
return rows
}
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
tb.check(err)
return stmt
}
func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("Error on %q: %v", query, err)
}
}
return db
}
const concurrencyLevel = 10
func BenchmarkQuery(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
var got string
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Errorf("query = %q; want one", got)
wg.Done()
return
}
}
}()
}
}
func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("DO 1"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
if _, err := stmt.Exec(); err != nil {
b.Fatal(err.Error())
}
}
}()
}
}
// data, but no db writes
var roundtripSample []byte
func initRoundtripBenchmarks() ([]byte, int, int) {
if roundtripSample == nil {
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
}
return roundtripSample, 16, len(roundtripSample)
}
func BenchmarkRoundtripTxt(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
b.StartTimer()
var result string
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sampleString[0:length]
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if result != test {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkRoundtripBin(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
b.StartTimer()
var result sql.RawBytes
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if !bytes.Equal(result, test) {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import "io"
const defaultBufSize = 4096
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
rd io.Reader
idx int
length int
}
func newBuffer(rd io.Reader) *buffer {
var b [defaultBufSize]byte
return &buffer{
buf: b[:],
rd: rd,
}
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
// move existing data to the beginning
if b.length > 0 && b.idx > 0 {
copy(b.buf[0:b.length], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
b.idx = 0
for {
n, err := b.rd.Read(b.buf[b.length:])
b.length += n
if err == nil {
if b.length < need {
continue
}
return nil
}
if b.length >= need && err == io.EOF {
return nil
}
return err
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
}
return make([]byte, length)
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length == 0 {
return b.buf
}
return nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql/driver"
"errors"
"net"
"strings"
"time"
)
type mysqlConn struct {
buf *buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *config
maxPacketAllowed int
maxWriteSize int
flags clientFlag
sequence uint8
parseTime bool
strict bool
}
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
// Handles parameters set in DSN
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// time.Time parsing
case "parseTime":
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Strict mode
case "strict":
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Compression
case "compress":
err = errors.New("Compression not implemented yet")
return
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err
}
return nil, err
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
mc.writeCommandPacket(comQuit)
mc.netConn.Close()
mc.netConn = nil
}
mc.cfg = nil
mc.buf = nil
return
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, err
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
return err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
if err = mc.readUntilEOF(); err != nil {
return err
}
err = mc.readUntilEOF()
}
return err
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
rows.columns, err = mc.readColumns(resLen)
}
return rows, err
}
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const (
minProtocolVersion byte = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
const (
iOK byte = 0x00
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
type clientFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
)
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegiserSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
const (
fieldTypeDecimal byte = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
const (
fieldTypeNewDecimal byte = iota + 0xf6
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
type fieldFlag uint16
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
const (
collation_ascii_general_ci byte = 11
collation_utf8_general_ci byte = 33
collation_utf8mb4_general_ci byte = 45
collation_utf8mb4_bin byte = 46
collation_latin1_general_ci byte = 48
collation_binary byte = 63
collation_utf8mb4_unicode_ci byte = 224
)
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql"
"database/sql/driver"
"net"
)
// This struct is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = parseDSN(dsn)
if err != nil {
return nil, err
}
// Connect to Server
nd := net.Dialer{Timeout: mc.cfg.timeout}
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
if err != nil {
return nil, err
}
mc.buf = newBuffer(mc.netConn)
// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
mc.Close()
return nil, err
}
// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
// Read Result Packet
err = mc.readResultOK()
if err != nil {
// Retry with old authentication method, if allowed
if mc.cfg.allowOldPasswords && err == errOldPassword {
if err = mc.writeOldAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err
}
}
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxPacketAllowed = stringToInt(maxap) - 1
if mc.maxPacketAllowed < maxPacketSize {
mc.maxWriteSize = mc.maxPacketAllowed
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"io/ioutil"
"net"
"net/url"
"os"
"strings"
"testing"
"time"
)
var (
dsn string
netAddr string
available bool
)
var (
tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
sDate = "2012-06-14"
tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
sDateTime = "2011-11-20 21:27:37"
tDate0 = time.Time{}
sDate0 = "0000-00-00"
sDateTime0 = "0000-00-00 00:00:00"
)
// See https://github.com/go-sql-driver/mysql/wiki/Testing
func init() {
env := func(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
user := env("MYSQL_TEST_USER", "root")
pass := env("MYSQL_TEST_PASS", "")
prot := env("MYSQL_TEST_PROT", "tcp")
addr := env("MYSQL_TEST_ADDR", "localhost:3306")
dbname := env("MYSQL_TEST_DBNAME", "gotest")
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
c, err := net.Dial(prot, addr)
if err == nil {
available = true
c.Close()
}
}
type DBTest struct {
*testing.T
db *sql.DB
}
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
defer db.Close()
db.Exec("DROP TABLE IF EXISTS test")
dbt := &DBTest{t, db}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
}
}
func (dbt *DBTest) fail(method, query string, err error) {
if len(query) > 300 {
query = "[query too large to print]"
}
dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
}
func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
res, err := dbt.db.Exec(query, args...)
if err != nil {
dbt.fail("Exec", query, err)
}
return res
}
func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
rows, err := dbt.db.Query(query, args...)
if err != nil {
dbt.fail("Query", query, err)
}
return rows
}
func TestCRUD(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// Create Table
dbt.mustExec("CREATE TABLE test (value BOOL)")
// Test for unexpected data
var out bool
rows := dbt.mustQuery("SELECT * FROM test")
if rows.Next() {
dbt.Error("unexpected data in empty table")
}
// Create Data
res := dbt.mustExec("INSERT INTO test VALUES (1)")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
id, err := res.LastInsertId()
if err != nil {
dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
}
if id != 0 {
dbt.Fatalf("Expected InsertID 0, got %d", id)
}
// Read
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if true != out {
dbt.Errorf("true != %t", out)
}
if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
// Update
res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check Update
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if false != out {
dbt.Errorf("false != %t", out)
}
if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
// Delete
res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check for unexpected rows
res = dbt.mustExec("DELETE FROM test")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 0 {
dbt.Fatalf("Expected 0 affected row, got %d", count)
}
})
}
func TestInt(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
in := int64(42)
var out int64
var rows *sql.Rows
// SIGNED
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %d != %d", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
// UNSIGNED ZEROFILL
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out)
}
} else {
dbt.Errorf("%s ZEROFILL: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestFloat(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
in := float32(42.23)
var out float32
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %g != %g", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestString(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย"
var out string
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Errorf("%s: %s != %s", v, in, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
// BLOB
dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
id := 2
in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " +
"Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " +
"sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " +
"sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " +
"Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet."
dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)
err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
if err != nil {
dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
} else if out != in {
dbt.Errorf("BLOB: %s != %s", in, out)
}
})
}
func TestDateTime(t *testing.T) {
type testmode struct {
selectSuffix string
args []interface{}
}
type timetest struct {
in interface{}
sOut string
tOut time.Time
tIsZero bool
}
type tester func(dbt *DBTest, rows *sql.Rows,
test *timetest, sqltype, resulttype, mode string)
type setup struct {
vartype string
dsnSuffix string
test tester
}
var (
modes = map[string]*testmode{
"text": &testmode{},
"binary": &testmode{" WHERE 1 = ?", []interface{}{1}},
}
timetests = map[string][]*timetest{
"DATE": {
{sDate, sDate, tDate, false},
{sDate0, sDate0, tDate0, true},
{tDate, sDate, tDate, false},
{tDate0, sDate0, tDate0, true},
},
"DATETIME": {
{sDateTime, sDateTime, tDateTime, false},
{sDateTime0, sDateTime0, tDate0, true},
{tDateTime, sDateTime, tDateTime, false},
{tDate0, sDateTime0, tDate0, true},
},
}
setups = []*setup{
{"string", "&parseTime=false", func(
dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
var sOut string
if err := rows.Scan(&sOut); err != nil {
dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
} else if test.sOut != sOut {
dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut)
}
}},
{"time.Time", "&parseTime=true", func(
dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
var tOut time.Time
if err := rows.Scan(&tOut); err != nil {
dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
} else if test.tOut != tOut || test.tIsZero != tOut.IsZero() {
dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero())
}
}},
}
)
var s *setup
testTime := func(dbt *DBTest) {
var rows *sql.Rows
for sqltype, tests := range timetests {
dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
for _, test := range tests {
for mode, q := range modes {
dbt.mustExec("TRUNCATE test")
dbt.mustExec("INSERT INTO test VALUES (?)", test.in)
rows = dbt.mustQuery("SELECT value FROM test"+q.selectSuffix, q.args...)
if rows.Next() {
s.test(dbt, rows, test, sqltype, s.vartype, mode)
} else {
if err := rows.Err(); err != nil {
dbt.Errorf("%s (%s %s): %s",
sqltype, s.vartype, mode, err.Error())
} else {
dbt.Errorf("%s (%s %s): no data",
sqltype, s.vartype, mode)
}
}
}
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
}
timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
for _, v := range setups {
s = v
runTests(t, timeDsn+s.dsnSuffix, testTime)
}
}
func TestNULL(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
nullStmt, err := dbt.db.Prepare("SELECT NULL")
if err != nil {
dbt.Fatal(err)
}
defer nullStmt.Close()
nonNullStmt, err := dbt.db.Prepare("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
defer nonNullStmt.Close()
// NullBool
var nb sql.NullBool
// Invalid
if err = nullStmt.QueryRow().Scan(&nb); err != nil {
dbt.Fatal(err)
}
if nb.Valid {
dbt.Error("Valid NullBool which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
dbt.Fatal(err)
}
if !nb.Valid {
dbt.Error("Invalid NullBool which should be valid")
} else if nb.Bool != true {
dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
}
// NullFloat64
var nf sql.NullFloat64
// Invalid
if err = nullStmt.QueryRow().Scan(&nf); err != nil {
dbt.Fatal(err)
}
if nf.Valid {
dbt.Error("Valid NullFloat64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
dbt.Fatal(err)
}
if !nf.Valid {
dbt.Error("Invalid NullFloat64 which should be valid")
} else if nf.Float64 != float64(1) {
dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
}
// NullInt64
var ni sql.NullInt64
// Invalid
if err = nullStmt.QueryRow().Scan(&ni); err != nil {
dbt.Fatal(err)
}
if ni.Valid {
dbt.Error("Valid NullInt64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
dbt.Fatal(err)
}
if !ni.Valid {
dbt.Error("Invalid NullInt64 which should be valid")
} else if ni.Int64 != int64(1) {
dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64)
}
// NullString
var ns sql.NullString
// Invalid
if err = nullStmt.QueryRow().Scan(&ns); err != nil {
dbt.Fatal(err)
}
if ns.Valid {
dbt.Error("Valid NullString which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
dbt.Fatal(err)
}
if !ns.Valid {
dbt.Error("Invalid NullString which should be valid")
} else if ns.String != `1` {
dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
}
// nil-bytes
var b []byte
// Read nil
if err = nullStmt.QueryRow().Scan(&b); err != nil {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("Non-nil []byte wich should be nil")
}
// Read non-nil
if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
dbt.Fatal(err)
}
if b == nil {
dbt.Error("Nil []byte wich should be non-nil")
}
// Insert nil
b = nil
success := false
if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
dbt.Fatal(err)
}
if !success {
dbt.Error("Inserting []byte(nil) as NULL failed")
}
// Check input==output with input==nil
b = nil
if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("Non-nil echo from nil input")
}
// Check input==output with input!=nil
b = []byte("")
if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
dbt.Fatal(err)
}
if b == nil {
dbt.Error("nil echo from non-nil input")
}
// Insert NULL
dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
var out interface{}
rows := dbt.mustQuery("SELECT * FROM test")
if rows.Next() {
rows.Scan(&out)
if out != nil {
dbt.Errorf("%v != nil", out)
}
} else {
dbt.Error("no data")
}
})
}
func TestLongData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var maxAllowedPacketSize int
err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
if err != nil {
dbt.Fatal(err)
}
maxAllowedPacketSize--
// don't get too ambitious
if maxAllowedPacketSize > 1<<25 {
maxAllowedPacketSize = 1 << 25
}
dbt.mustExec("CREATE TABLE test (value LONGBLOB)")
in := strings.Repeat(`a`, maxAllowedPacketSize+1)
var out string
var rows *sql.Rows
// Long text data
const nonDataQueryLen = 28 // length query w/o value
inS := in[:maxAllowedPacketSize-nonDataQueryLen]
dbt.mustExec("INSERT INTO test VALUES('" + inS + "')")
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if inS != out {
dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out))
}
if rows.Next() {
dbt.Error("LONGBLOB: unexpexted row")
}
} else {
dbt.Fatalf("LONGBLOB: no data")
}
// Empty table
dbt.mustExec("TRUNCATE TABLE test")
// Long binary data
dbt.mustExec("INSERT INTO test VALUES(?)", in)
rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1)
if rows.Next() {
rows.Scan(&out)
if in != out {
dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out))
}
if rows.Next() {
dbt.Error("LONGBLOB: unexpexted row")
}
} else {
if err = rows.Err(); err != nil {
dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error())
} else {
dbt.Fatal("LONGBLOB: no data (err: <nil>)")
}
}
})
}
func TestLoadData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
verifyLoadDataResult := func() {
rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
dbt.Fatal(err.Error())
}
i := 0
values := [4]string{
"a string",
"a string containing a \t",
"a string containing a \n",
"a string containing both \t\n",
}
var id int
var value string
for rows.Next() {
i++
err = rows.Scan(&id, &value)
if err != nil {
dbt.Fatal(err.Error())
}
if i != id {
dbt.Fatalf("%d != %d", i, id)
}
if values[i-1] != value {
dbt.Fatalf("%s != %s", values[i-1], value)
}
}
err = rows.Err()
if err != nil {
dbt.Fatal(err.Error())
}
if i != 4 {
dbt.Fatalf("Rows count mismatch. Got %d, want 4", i)
}
}
file, err := ioutil.TempFile("", "gotest")
defer os.Remove(file.Name())
if err != nil {
dbt.Fatal(err)
}
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
file.Close()
dbt.db.Exec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
// Local File
RegisterLocalFile(file.Name())
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
verifyLoadDataResult()
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("Load non-existent file didn't fail")
} else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" {
dbt.Fatal(err.Error())
}
// Empty table
dbt.mustExec("TRUNCATE TABLE test")
// Reader
RegisterReaderHandler("test", func() io.Reader {
file, err = os.Open(file.Name())
if err != nil {
dbt.Fatal(err)
}
return file
})
dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test")
verifyLoadDataResult()
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("Load non-existent Reader didn't fail")
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
dbt.Fatal(err.Error())
}
})
}
func TestFoundRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 affected rows, got %d", count)
}
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 affected rows, got %d", count)
}
})
runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 2 {
dbt.Fatalf("Expected 2 matched rows, got %d", count)
}
res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 3 {
dbt.Fatalf("Expected 3 matched rows, got %d", count)
}
})
}
func TestStrict(t *testing.T) {
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
runTests(t, relaxedDsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
var queries = [...]struct {
in string
codes []string
}{
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
}
var err error
var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
}
if warnings, ok := err.(MySQLWarnings); ok {
var codes = make([]string, len(warnings))
for i := range warnings {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}
for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}
} else {
dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}
// text protocol
for i := range queries {
_, err = dbt.db.Exec(queries[i].in)
checkWarnings(err, "text", i)
}
var stmt *sql.Stmt
// binary protocol
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error())
}
_, err = stmt.Exec()
checkWarnings(err, "binary", i)
err = stmt.Close()
if err != nil {
dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error())
}
}
})
}
func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
if err == errNoTLS {
dbt.Skip("Server does not support TLS")
} else {
dbt.Fatalf("Error on Ping: %s", err.Error())
}
}
rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
var variable, value *sql.RawBytes
for rows.Next() {
if err := rows.Scan(&variable, &value); err != nil {
dbt.Fatal(err.Error())
}
if value == nil {
dbt.Fatal("No Cipher")
}
}
}
runTests(t, dsn+"&tls=skip-verify", tlsTest)
// Verify that registering / using a custom cfg works
RegisterTLSConfig("custom-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
}
func TestReuseClosedConnection(t *testing.T) {
// this test does not use sql.database, it uses the driver directly
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
md := &MySQLDriver{}
conn, err := md.Open(dsn)
if err != nil {
t.Fatalf("Error connecting: %s", err.Error())
}
stmt, err := conn.Prepare("DO 1")
if err != nil {
t.Fatalf("Error preparing statement: %s", err.Error())
}
_, err = stmt.Exec(nil)
if err != nil {
t.Fatalf("Error executing statement: %s", err.Error())
}
err = conn.Close()
if err != nil {
t.Fatalf("Error closing connection: %s", err.Error())
}
defer func() {
if err := recover(); err != nil {
t.Errorf("Panic after reusing a closed connection: %v", err)
}
}()
_, err = stmt.Exec(nil)
if err != nil && err != driver.ErrBadConn {
t.Errorf("Unexpected error '%s', expected '%s'",
err.Error(), driver.ErrBadConn.Error())
}
}
func TestCharset(t *testing.T) {
if !available {
t.Skipf("MySQL-Server not running on %s", netAddr)
}
mustSetCharset := func(charsetParam, expected string) {
runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
rows := dbt.mustQuery("SELECT @@character_set_connection")
defer rows.Close()
if !rows.Next() {
dbt.Fatalf("Error getting connection charset: %s", rows.Err())
}
var got string
rows.Scan(&got)
if got != expected {
dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
}
})
}
// non utf8 test
mustSetCharset("charset=ascii", "ascii")
// when the first charset is invalid, use the second
mustSetCharset("charset=none,utf8", "utf8")
// when the first charset is valid, use it
mustSetCharset("charset=ascii,utf8", "ascii")
mustSetCharset("charset=utf8,ascii", "utf8")
}
func TestFailingCharset(t *testing.T) {
runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
// run query to really establish connection...
_, err := dbt.db.Exec("SELECT 1")
if err == nil {
dbt.db.Close()
t.Fatalf("Connection must not succeed without a valid charset")
}
})
}
func TestRawBytesResultExceedsBuffer(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// defaultBufSize from buffer.go
expected := strings.Repeat("abc", defaultBufSize)
rows := dbt.mustQuery("SELECT '" + expected + "'")
defer rows.Close()
if !rows.Next() {
dbt.Error("expected result, got none")
}
var result sql.RawBytes
rows.Scan(&result)
if expected != string(result) {
dbt.Error("result did not match expected value")
}
})
}
func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
// Regression test for timezone handling
tzTest := func(dbt *DBTest) {
// Create table
dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
// Insert local time into database (should be converted)
usCentral, _ := time.LoadLocation("US/Central")
now := time.Now().In(usCentral)
dbt.mustExec("INSERT INTO test VALUE (?)", now)
// Retrieve time from DB
rows := dbt.mustQuery("SELECT ts FROM test")
if !rows.Next() {
dbt.Fatal("Didn't get any rows out")
}
var nowDB time.Time
err := rows.Scan(&nowDB)
if err != nil {
dbt.Fatal("Err", err)
}
// Check that dates match
if now.Unix() != nowDB.Unix() {
dbt.Errorf("Times don't match.\n")
dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
dbt.Errorf(" Now(UTC)=%v\n", nowDB)
}
}
for _, tz := range zones {
runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
}
}
// This tests for https://github.com/go-sql-driver/mysql/pull/139
//
// An extra (invisible) nil byte was being added to the beginning of positive
// time strings.
func TestTimeSign(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var sTimes = []struct {
value string
fieldType string
}{
{"12:34:56", "TIME"},
{"-12:34:56", "TIME"},
// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
// they *should* work, but only in 5.6+.
// { "12:34:56.789", "TIME(3)" },
// { "-12:34:56.789", "TIME(3)" },
}
for _, sTime := range sTimes {
dbt.db.Exec("DROP TABLE IF EXISTS test")
dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
if rows.Next() {
var oTime string
rows.Scan(&oTime)
if oTime != sTime.value {
dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
}
} else {
dbt.Error("expecting at least one row.")
}
}
})
}
// Special cases
func TestRowsClose(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
rows, err := dbt.db.Query("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
err = rows.Close()
if err != nil {
dbt.Fatal(err)
}
if rows.Next() {
dbt.Fatal("Unexpected row after rows.Close()")
}
err = rows.Err()
if err != nil {
dbt.Fatal(err)
}
})
}
// dangling statements
// http://code.google.com/p/go/issues/detail?id=3865
func TestCloseStmtBeforeRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare("SELECT 1")
if err != nil {
dbt.Fatal(err)
}
rows, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows.Close()
err = stmt.Close()
if err != nil {
dbt.Fatal(err)
}
if !rows.Next() {
dbt.Fatal("Getting row failed")
} else {
err = rows.Err()
if err != nil {
dbt.Fatal(err)
}
var out bool
err = rows.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
})
}
// It is valid to have multiple Rows for the same Stmt
// http://code.google.com/p/go/issues/detail?id=3734
func TestStmtMultiRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
if err != nil {
dbt.Fatal(err)
}
rows1, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows1.Close()
rows2, err := stmt.Query()
if err != nil {
stmt.Close()
dbt.Fatal(err)
}
defer rows2.Close()
var out bool
// 1
if !rows1.Next() {
dbt.Fatal("1st rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
if !rows2.Next() {
dbt.Fatal("1st rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
}
}
// 2
if !rows1.Next() {
dbt.Fatal("2nd rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows1.Next() {
dbt.Fatal("Unexpected row on rows1")
}
err = rows1.Close()
if err != nil {
dbt.Fatal(err)
}
}
if !rows2.Next() {
dbt.Fatal("2nd rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
dbt.Fatal(err)
}
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows2.Next() {
dbt.Fatal("Unexpected row on rows2")
}
err = rows2.Close()
if err != nil {
dbt.Fatal(err)
}
}
})
}
func TestConcurrent(t *testing.T) {
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
t.Skip("MYSQL_TEST_CONCURRENT env var not set")
}
runTests(t, dsn, func(dbt *DBTest) {
var max int
err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
if err != nil {
dbt.Fatalf("%s", err.Error())
}
dbt.Logf("Testing up to %d concurrent connections \r\n", max)
canStop := false
c := make(chan struct{}, max)
for i := 0; i < max; i++ {
go func(id int) {
tx, err := dbt.db.Begin()
if err != nil {
canStop = true
if err.Error() == "Error 1040: Too many connections" {
max--
return
} else {
dbt.Fatalf("Error on Con %d: %s", id, err.Error())
}
}
c <- struct{}{}
for !canStop {
_, err = tx.Exec("SELECT 1")
if err != nil {
canStop = true
dbt.Fatalf("Error on Con %d: %s", id, err.Error())
}
}
err = tx.Commit()
if err != nil {
canStop = true
dbt.Fatalf("Error on Con %d: %s", id, err.Error())
}
}(i)
}
for i := 0; i < max; i++ {
<-c
}
canStop = true
dbt.Logf("Reached %d concurrent connections \r\n", max)
})
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"errors"
"fmt"
"io"
)
var (
errInvalidConn = errors.New("Invalid Connection")
errMalformPkt = errors.New("Malformed Packet")
errNoTLS = errors.New("TLS encryption requested but server does not support TLS")
errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
errPktSync = errors.New("Commands out of sync. You can't run this command now")
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
)
// error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}
// error type which represents a group of one or more MySQL warnings
type MySQLWarnings []mysqlWarning
func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message)
}
return msg
}
// error type which represents a single MySQL warning
type mysqlWarning struct {
Level string
Code string
Message string
}
func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
if err != nil {
return
}
var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)
var warning mysqlWarning
var raw []byte
var ok bool
for {
err = rows.Next(values)
switch err {
case nil:
warning = mysqlWarning{}
if raw, ok = values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok = values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok = values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}
warnings = append(warnings, warning)
case io.EOF:
return warnings
default:
rows.Close()
return
}
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"fmt"
"io"
"os"
"strings"
)
var (
fileRegister map[string]bool
readerRegister map[string]func() io.Reader
)
func init() {
fileRegister = make(map[string]bool)
readerRegister = make(map[string]func() io.Reader)
}
// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
//
// filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterLocalFile(filePath string) {
fileRegister[strings.Trim(filePath, `"`)] = true
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
delete(fileRegister, strings.Trim(filePath, `"`))
}
// RegisterReaderHandler registers a handler function which is used
// to receive a io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns a io.ReadCloser Close() is called when the
// request is finished.
//
// mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here
// return csvReader
// })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
readerRegister[name] = handler
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
delete(readerRegister, name)
}
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
data := make([]byte, 4+mc.maxWriteSize)
if strings.HasPrefix(name, "Reader::") { // io.Reader
name = name[8:]
handler, inMap := readerRegister[name]
if handler != nil {
rdr = handler()
}
if rdr == nil {
if !inMap {
err = fmt.Errorf("Reader '%s' is not registered", name)
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
}
}
} else { // File
name = strings.Trim(name, `"`)
if mc.cfg.allowAllFiles || fileRegister[name] {
rdr, err = os.Open(name)
} else {
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
}
}
if rdc, ok := rdr.(io.ReadCloser); ok {
defer func() {
if err == nil {
err = rdc.Close()
} else {
rdc.Close()
}
}()
}
// send content packets
var ioErr error
if err == nil {
var n int
for err == nil && ioErr == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
data[0] = byte(n)
data[1] = byte(n >> 8)
data[2] = byte(n >> 16)
data[3] = mc.sequence
ioErr = mc.writePacket(data[:4+n])
}
}
if err == io.EOF {
err = nil
}
if ioErr != nil {
errLog.Print(ioErr.Error())
return driver.ErrBadConn
}
}
// send empty packet (termination)
ioErr = mc.writePacket([]byte{
0x00,
0x00,
0x00,
mc.sequence,
})
if ioErr != nil {
errLog.Print(ioErr.Error())
return driver.ErrBadConn
}
// read OK packet
if err == nil {
return mc.readResultOK()
} else {
mc.readPacket()
}
return err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"math"
"time"
)
// Packets documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
// Read packet header
data, err := mc.buf.readNext(4)
if err != nil {
errLog.Print(err.Error())
mc.Close()
return nil, driver.ErrBadConn
}
// Packet Length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
if pktLen < 1 {
errLog.Print(errMalformPkt.Error())
mc.Close()
return nil, driver.ErrBadConn
}
// Check Packet Sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, errPktSyncMul
} else {
return nil, errPktSync
}
}
mc.sequence++
// Read packet body [pktLen bytes]
if data, err = mc.buf.readNext(pktLen); err == nil {
if pktLen < maxPacketSize {
return data, nil
}
// Make a copy since data becomes invalid with the next read
buf := make([]byte, len(data))
copy(buf, data)
// More data
data, err = mc.readPacket()
if err == nil {
return append(buf, data...), nil
}
}
// err case
mc.Close()
errLog.Print(err.Error())
return nil, driver.ErrBadConn
}
// Write packet buffer 'data'
// The packet header must be already included
func (mc *mysqlConn) writePacket(data []byte) error {
if len(data)-4 <= mc.maxWriteSize { // Can send data at once
// Write packet
n, err := mc.netConn.Write(data)
if err == nil && n == len(data) {
mc.sequence++
return nil
}
// Handle error
if err == nil { // n != len(data)
errLog.Print(errMalformPkt.Error())
} else {
errLog.Print(err.Error())
}
return driver.ErrBadConn
}
// Must split packet
return mc.splitPacket(data)
}
func (mc *mysqlConn) splitPacket(data []byte) error {
pktLen := len(data) - 4
if pktLen > mc.maxPacketAllowed {
return errPktTooLarge
}
for pktLen >= maxPacketSize {
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
data[3] = mc.sequence
// Write packet
n, err := mc.netConn.Write(data[:4+maxPacketSize])
if err == nil && n == 4+maxPacketSize {
mc.sequence++
data = data[maxPacketSize:]
pktLen -= maxPacketSize
continue
}
// Handle error
if err == nil { // n != len(data)
errLog.Print(errMalformPkt.Error())
} else {
errLog.Print(err.Error())
}
return driver.ErrBadConn
}
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = mc.sequence
return mc.writePacket(data)
}
/******************************************************************************
* Initialisation Process *
******************************************************************************/
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
if data[0] == iERR {
return nil, mc.handleErrorPacket(data)
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
data[0],
minProtocolVersion,
)
}
// server version [null terminated string]
// connection id [4 bytes]
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
// first part of the password cipher [8 bytes]
cipher := data[pos : pos+8]
// (filler) always 0x00 [1 byte]
pos += 8 + 1
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, errOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, errNoTLS
}
pos += 2
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
// capability flags (upper 2 bytes) [2 bytes]
// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
// second part of the password cipher [12? bytes]
// The documentation is ambiguous about the length.
// The official Python library uses the fixed length 12
// which is not documented but seems to work.
cipher = append(cipher, data[pos:pos+12]...)
// TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// \NUL otherwise
//
//if data[len(data)-1] == 0 {
// return
//}
//return errMalformPkt
}
return cipher, nil
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientTransactions |
clientLocalFiles |
mc.flags&clientLongFlag
if mc.cfg.clientFoundRows {
clientFlags |= clientFoundRows
}
// To enable TLS / SSL
if mc.cfg.tls != nil {
clientFlags |= clientSSL
}
// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
// To specify a db name
if n := len(mc.cfg.dbname); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// ClientFlags [32 bit]
data[4] = byte(clientFlags)
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)
// MaxPacketSize [32 bit] (none)
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00
// Charset [1 byte]
data[12] = collation_utf8_general_ci
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil {
// Packet header [24bit length + 1 byte sequence]
data[0] = byte((4 + 4 + 1 + 23))
data[1] = byte((4 + 4 + 1 + 23) >> 8)
data[2] = byte((4 + 4 + 1 + 23) >> 16)
data[3] = mc.sequence
// Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
// Switch to TLS
tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
if err := tlsConn.Handshake(); err != nil {
return err
}
mc.netConn = tlsConn
mc.buf.rd = tlsConn
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = mc.sequence
// Filler [23 bytes] (all 0x00)
pos := 13 + 23
// User [null terminated string]
if len(mc.cfg.user) > 0 {
pos += copy(data[pos:], mc.cfg.user)
}
data[pos] = 0x00
pos++
// ScrambleBuffer [length encoded integer]
data[pos] = byte(len(scrambleBuff))
pos += 1 + copy(data[pos+1:], scrambleBuff)
// Databasename [null terminated string]
if len(mc.cfg.dbname) > 0 {
pos += copy(data[pos:], mc.cfg.dbname)
data[pos] = 0x00
}
// Send Auth packet
return mc.writePacket(data)
}
// Client old authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
// Calculate the packet lenght and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = mc.sequence
// Add the scrambled password [null terminated string]
copy(data[4:], scrambleBuff)
return mc.writePacket(data)
}
/******************************************************************************
* Command Packets *
******************************************************************************/
func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = 0x01 // 1 byte long
data[1] = 0x00
data[2] = 0x00
data[3] = 0x00 // new command, sequence id is always 0
// Add command byte
data[4] = command
// Send CMD packet
return mc.writePacket(data)
}
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
// Reset Packet Sequence
mc.sequence = 0
pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = 0x00 // new command, sequence id is always 0
// Add command byte
data[4] = command
// Add arg
copy(data[5:], arg)
// Send CMD packet
return mc.writePacket(data)
}
func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = 0x05 // 5 bytes long
data[1] = 0x00
data[2] = 0x00
data[3] = 0x00 // new command, sequence id is always 0
// Add command byte
data[4] = command
// Add arg [32 bit]
data[5] = byte(arg)
data[6] = byte(arg >> 8)
data[7] = byte(arg >> 16)
data[8] = byte(arg >> 24)
// Send CMD packet
return mc.writePacket(data)
}
/******************************************************************************
* Result Packets *
******************************************************************************/
// Returns error if Packet is not an 'Result OK'-Packet
func (mc *mysqlConn) readResultOK() error {
data, err := mc.readPacket()
if err == nil {
// packet indicator
switch data[0] {
case iOK:
return mc.handleOkPacket(data)
case iEOF:
// someone is using old_passwords
return errOldPassword
default: // Error otherwise
return mc.handleErrorPacket(data)
}
}
return err
}
// Result Set Header Packet
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
data, err := mc.readPacket()
if err == nil {
switch data[0] {
case iOK:
return 0, mc.handleOkPacket(data)
case iERR:
return 0, mc.handleErrorPacket(data)
case iLocalInFile:
return 0, mc.handleInFileRequest(string(data[1:]))
}
// column count
num, _, n := readLengthEncodedInteger(data)
if n-len(data) == 0 {
return int(num), nil
}
return 0, errMalformPkt
}
return 0, err
}
// Error Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
if data[0] != iERR {
return errMalformPkt
}
// 0xff [1 byte]
// Error Number [16 bit uint]
errno := binary.LittleEndian.Uint16(data[1:3])
pos := 3
// SQL State [optional: # + 5bytes string]
if data[3] == 0x23 {
//sqlstate := string(data[4 : 4+5])
pos = 9
}
// Error Message [string]
return &MySQLError{
Number: errno,
Message: string(data[pos:]),
}
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int
// 0x00 [1 byte]
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
// server_status [2 bytes]
// warning count [2 bytes]
if !mc.strict {
return nil
} else {
pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)
for i := 0; ; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}
// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if i == count {
return columns, nil
}
return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
}
// Catalog
pos, err := skipLengthEnodedString(data)
if err != nil {
return nil, err
}
// Database [len coded string]
n, err := skipLengthEnodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Table [len coded string]
n, err = skipLengthEnodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Original table [len coded string]
n, err = skipLengthEnodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
// Name [len coded string]
name, _, n, err := readLengthEnodedString(data[pos:])
if err != nil {
return nil, err
}
columns[i].name = string(name)
pos += n
// Original name [len coded string]
n, err = skipLengthEnodedString(data[pos:])
if err != nil {
return nil, err
}
// Filler [1 byte]
// Charset [16 bit uint]
// Length [32 bit uint]
pos += n + 1 + 2 + 4
// Field type [byte]
columns[i].fieldType = data[pos]
pos++
// Flags [16 bit uint]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
//pos += 2
// Decimals [8 bit uint]
//pos++
// Default value [len coded binary]
//if pos < len(data) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
data, err := mc.readPacket()
if err != nil {
return err
}
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
return io.EOF
}
// RowSet Packet
var n int
var isNull bool
pos := 0
for i := range dest {
// Read bytes and convert to string
dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
pos += n
if err == nil {
if !isNull {
if !mc.parseTime {
continue
} else {
switch rows.columns[i].fieldType {
case fieldTypeTimestamp, fieldTypeDateTime,
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
mc.cfg.loc,
)
if err == nil {
continue
}
default:
continue
}
}
} else {
dest[i] = nil
continue
}
}
return err // err != nil
}
return nil
}
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
func (mc *mysqlConn) readUntilEOF() error {
for {
data, err := mc.readPacket()
// No Err and no EOF Packet
if err == nil && data[0] != iEOF {
continue
}
return err // Err or EOF
}
}
/******************************************************************************
* Prepared Statements *
******************************************************************************/
// Prepare Result Packets
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
data, err := stmt.mc.readPacket()
if err == nil {
// packet indicator [1 byte]
if data[0] != iOK {
return 0, stmt.mc.handleErrorPacket(data)
}
// statement id [4 bytes]
stmt.id = binary.LittleEndian.Uint32(data[1:5])
// Column count [16 bit uint]
columnCount := binary.LittleEndian.Uint16(data[5:7])
// Param count [16 bit uint]
stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
// Reserved [8 bit]
// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
} else {
// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
}
return 0, err
}
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
maxLen := stmt.mc.maxPacketAllowed - 1
pktLen := maxLen
// After the header (bytes 0-3) follows before the data:
// 1 byte command
// 4 bytes stmtID
// 2 bytes paramID
const dataOffset = 1 + 4 + 2
// Can not use the write buffer since
// a) the buffer is too small
// b) it is in use
data := make([]byte, 4+1+4+2+len(arg))
copy(data[4+dataOffset:], arg)
for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
if dataOffset+argLen < maxLen {
pktLen = dataOffset + argLen
}
// Add the packet header [24bit length + 1 byte sequence]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = 0x00 // mc.sequence
// Add command byte [1 byte]
data[4] = comStmtSendLongData
// Add stmtID [32 bit]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// Add paramID [16 bit]
data[9] = byte(paramID)
data[10] = byte(paramID >> 8)
// Send CMD packet
err := stmt.mc.writePacket(data[:4+pktLen])
if err == nil {
data = data[pktLen-dataOffset:]
continue
}
return err
}
// Reset Packet Sequence
stmt.mc.sequence = 0
return nil
}
// Execute Prepared Statement
// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
"Arguments count mismatch (Got: %d Has: %d)",
len(args),
stmt.paramCount,
)
}
mc := stmt.mc
// Reset packet-sequence
mc.sequence = 0
var data []byte
if len(args) == 0 {
const pktLen = 1 + 4 + 1 + 4
data = mc.buf.takeBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// packet header [4 bytes]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = 0x00 // new command, sequence id is always 0
} else {
data = mc.buf.takeCompleteBuffer()
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print("Busy buffer")
return driver.ErrBadConn
}
// header (bytes 0-3) is added after we know the packet size
}
// command [1 byte]
data[4] = comStmtExecute
// statement_id [4 bytes]
data[5] = byte(stmt.id)
data[6] = byte(stmt.id >> 8)
data[7] = byte(stmt.id >> 16)
data[8] = byte(stmt.id >> 24)
// flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte]
data[9] = 0x00
// iteration_count (uint32(1)) [4 bytes]
data[10] = 0x01
data[11] = 0x00
data[12] = 0x00
data[13] = 0x00
if len(args) > 0 {
// NULL-bitmap [(len(args)+7)/8 bytes]
nullMask := uint64(0)
pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
// newParameterBoundFlag 1 [1 byte]
data[pos] = 0x01
pos++
// type of each parameter [len(args)*2 bytes]
paramTypes := data[pos:]
pos += (len(args) << 1)
// value of each parameter [n bytes]
paramValues := data[pos:pos]
valuesCap := cap(paramValues)
for i := range args {
// build NULL-bitmap
if args[i] == nil {
nullMask |= 1 << uint(i)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i+1] = 0x00
continue
}
// cache types and values
switch v := args[i].(type) {
case int64:
paramTypes[i+i] = fieldTypeLongLong
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64:
paramTypes[i+i] = fieldTypeDouble
paramTypes[i+i+1] = 0x00
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
math.Float64bits(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(math.Float64bits(v))...,
)
}
case bool:
paramTypes[i+i] = fieldTypeTiny
paramTypes[i+i+1] = 0x00
if v {
paramValues = append(paramValues, 0x01)
} else {
paramValues = append(paramValues, 0x00)
}
case []byte:
// Common case (non-nil value) first
if v != nil {
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, v); err != nil {
return err
}
}
continue
}
// Handle []byte(nil) as a NULL value
nullMask |= 1 << uint(i)
paramTypes[i+i] = fieldTypeNULL
paramTypes[i+i+1] = 0x00
case string:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
paramValues = append(paramValues, v...)
} else {
if err := stmt.writeCommandLongData(i, []byte(v)); err != nil {
return err
}
}
case time.Time:
paramTypes[i+i] = fieldTypeString
paramTypes[i+i+1] = 0x00
var val []byte
if v.IsZero() {
val = []byte("0000-00-00")
} else {
val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
}
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(val)),
)
paramValues = append(paramValues, val...)
default:
return fmt.Errorf("Can't convert type: %T", args[i])
}
}
// Check if param values exceeded the available buffer
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
}
pos += len(paramValues)
data = data[:pos]
pktLen := pos - 4
// packet header [4 bytes]
data[0] = byte(pktLen)
data[1] = byte(pktLen >> 8)
data[2] = byte(pktLen >> 16)
data[3] = mc.sequence
// Convert nullMask to bytes
for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
data[i+14] = byte(nullMask >> uint(i<<3))
}
}
return mc.writePacket(data)
}
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
if err != nil {
return err
}
// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
return io.EOF
}
// Error otherwise
return rows.mc.handleErrorPacket(data)
}
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
pos := 1 + (len(dest)+7+2)>>3
nullMask := data[1:pos]
for i := range dest {
// Field is NULL
// (byte >> bit-pos) % 2 == 1
if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
dest[i] = nil
continue
}
// Convert to byte-coded string
switch rows.columns[i].fieldType {
case fieldTypeNULL:
dest[i] = nil
continue
// Numeric Types
case fieldTypeTiny:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(data[pos])
} else {
dest[i] = int64(int8(data[pos]))
}
pos++
continue
case fieldTypeShort, fieldTypeYear:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
} else {
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
}
pos += 2
continue
case fieldTypeInt24, fieldTypeLong:
if rows.columns[i].flags&flagUnsigned != 0 {
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
} else {
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
}
pos += 4
continue
case fieldTypeLongLong:
if rows.columns[i].flags&flagUnsigned != 0 {
val := binary.LittleEndian.Uint64(data[pos : pos+8])
if val > math.MaxInt64 {
dest[i] = uint64ToString(val)
} else {
dest[i] = int64(val)
}
} else {
dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
}
pos += 8
continue
case fieldTypeFloat:
dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
pos += 4
continue
case fieldTypeDouble:
dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
pos += 8
continue
// Length coded Binary Strings
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
pos += n
if err == nil {
if !isNull {
continue
} else {
dest[i] = nil
continue
}
}
return err
// Date YYYY-MM-DD
case fieldTypeDate, fieldTypeNewDate:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if isNull {
dest[i] = nil
continue
}
if rows.mc.parseTime {
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
} else {
dest[i], err = formatBinaryDate(num, data[pos:])
}
if err == nil {
pos += int(num)
continue
} else {
return err
}
// Time [-][H]HH:MM:SS[.fractal]
case fieldTypeTime:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if num == 0 {
if isNull {
dest[i] = nil
continue
} else {
dest[i] = []byte("00:00:00")
continue
}
}
var sign string
if data[pos] == 1 {
sign = "-"
}
switch num {
case 8:
dest[i] = []byte(fmt.Sprintf(
sign+"%02d:%02d:%02d",
uint16(data[pos+1])*24+uint16(data[pos+5]),
data[pos+6],
data[pos+7],
))
pos += 8
continue
case 12:
dest[i] = []byte(fmt.Sprintf(
sign+"%02d:%02d:%02d.%06d",
uint16(data[pos+1])*24+uint16(data[pos+5]),
data[pos+6],
data[pos+7],
binary.LittleEndian.Uint32(data[pos+8:pos+12]),
))
pos += 12
continue
default:
return fmt.Errorf("Invalid TIME-packet length %d", num)
}
// Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
case fieldTypeTimestamp, fieldTypeDateTime:
num, isNull, n := readLengthEncodedInteger(data[pos:])
pos += n
if isNull {
dest[i] = nil
continue
}
if rows.mc.parseTime {
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
} else {
dest[i], err = formatBinaryDateTime(num, data[pos:])
}
if err == nil {
pos += int(num)
continue
} else {
return err
}
// Please report if this happens!
default:
return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType)
}
}
return nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlResult struct {
affectedRows int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
)
type mysqlField struct {
fieldType byte
flags fieldFlag
name string
}
type mysqlRows struct {
mc *mysqlConn
columns []mysqlField
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
for i := range columns {
columns[i] = rows.columns[i].name
}
return columns
}
func (rows *mysqlRows) Close() error {
mc := rows.mc
if mc == nil {
return nil
}
if mc.netConn == nil {
return errInvalidConn
}
// Remove unread packets from stream
err := mc.readUntilEOF()
rows.mc = nil
return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return errInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return errInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField // cached from the first query
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
if resLen > 0 {
// Columns
err = mc.readUntilEOF()
if err != nil {
return nil, err
}
// Rows
err = mc.readUntilEOF()
}
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
}
return nil, err
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
if stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
rows.mc = mc
if resLen > 0 {
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
rows.columns, err = mc.readColumns(resLen)
stmt.columns = rows.columns
} else {
rows.columns = stmt.columns
err = mc.readUntilEOF()
}
}
return rows, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return errInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return errInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/sha1"
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net/url"
"os"
"strings"
"time"
)
var (
errLog *log.Logger // Error Logger
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
)
func init() {
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
tlsConfigRegister = make(map[string]*tls.Config)
}
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
// rootCertPool := x509.NewCertPool()
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
// if err != nil {
// log.Fatal(err)
// }
// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
// log.Fatal("Failed to append PEM.")
// }
// clientCert := make([]tls.Certificate, 0, 1)
// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem")
// if err != nil {
// log.Fatal(err)
// }
// clientCert = append(clientCert, certs)
// mysql.RegisterTLSConfig("custom", &tls.Config{
// RootCAs: rootCertPool,
// Certificates: clientCert,
// })
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("Key '%s' is reserved", key)
}
tlsConfigRegister[key] = config
return nil
}
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
delete(tlsConfigRegister, key)
}
// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (cfg *config, err error) {
cfg = new(config)
// TODO: use strings.IndexByte when we can depend on Go 1.2
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an adress is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.dbname = dsn[i+1 : j]
break
}
}
// Set default network if empty
if cfg.net == "" {
cfg.net = "tcp"
}
// Set default adress if empty
if cfg.addr == "" {
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}
}
// Set default location if empty
if cfg.loc == nil {
cfg.loc = time.UTC
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}
// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
cfg.tls = tlsConfig
} else {
return fmt.Errorf("Invalid value / unknown config name: %s", value)
}
}
default:
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
}
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
/******************************************************************************
* Authentication *
******************************************************************************/
// Encrypt password using 4.1+ method
func scramblePassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
// Encrypt password using pre 4.1 (old password) method
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
type myRnd struct {
seed1, seed2 uint32
}
const myRndMaxVal = 0x3FFFFFFF
// Pseudo random number generator
func newMyRnd(seed1, seed2 uint32) *myRnd {
return &myRnd{
seed1: seed1 % myRndMaxVal,
seed2: seed2 % myRndMaxVal,
}
}
// Tested to be equivalent to MariaDB's floating point variant
// http://play.golang.org/p/QHvhd4qved
// http://play.golang.org/p/RG0q4ElWDx
func (r *myRnd) NextByte() byte {
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
}
// Generate binary hash from byte string using insecure pre 4.1 method
func pwHash(password []byte) (result [2]uint32) {
var add uint32 = 7
var tmp uint32
result[0] = 1345345333
result[1] = 0x12345671
for _, c := range password {
// skip spaces and tabs in password
if c == ' ' || c == '\t' {
continue
}
tmp = uint32(c)
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
result[1] += (result[1] << 8) ^ result[0]
add += tmp
}
// Remove sign bit (1<<31)-1)
result[0] &= 0x7FFFFFFF
result[1] &= 0x7FFFFFFF
return
}
// Encrypt password using insecure pre 4.1 method
func scrambleOldPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
scramble = scramble[:8]
hashPw := pwHash(password)
hashSc := pwHash(scramble)
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
var out [8]byte
for i := range out {
out[i] = r.NextByte() + 64
}
mask := r.NextByte()
for i := range out {
out[i] ^= mask
}
return out[:]
}
/******************************************************************************
* Time related utils *
******************************************************************************/
// NullTime represents a time.Time that may be NULL.
// NullTime implements the Scanner interface so
// it can be used as a scan destination:
//
// var nt NullTime
// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
// ...
// if nt.Valid {
// // use nt.Time
// } else {
// // NULL value
// }
//
// This NullTime implementation is not driver-specific
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
}
switch v := value.(type) {
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, err = parseDateTime(string(v), time.UTC)
nt.Valid = (err == nil)
return
case string:
nt.Time, err = parseDateTime(v, time.UTC)
nt.Valid = (err == nil)
return
}
nt.Valid = false
return fmt.Errorf("Can't convert %T to time.Time", value)
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
switch len(str) {
case 10: // YYYY-MM-DD
if str == "0000-00-00" {
return
}
t, err = time.Parse(timeFormat[:10], str)
case 19: // YYYY-MM-DD HH:MM:SS
if str == "0000-00-00 00:00:00" {
return
}
t, err = time.Parse(timeFormat, str)
default:
err = fmt.Errorf("Invalid Time-String: %s", str)
return
}
// Adjust location
if err == nil && loc != time.UTC {
y, mo, d := t.Date()
h, mi, s := t.Clock()
t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
}
return
}
func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {
switch num {
case 0:
return time.Time{}, nil
case 4:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
0, 0, 0, 0,
loc,
), nil
case 7:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
0,
loc,
), nil
case 11:
return time.Date(
int(binary.LittleEndian.Uint16(data[:2])), // year
time.Month(data[2]), // month
int(data[3]), // day
int(data[4]), // hour
int(data[5]), // minutes
int(data[6]), // seconds
int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds
loc,
), nil
}
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}
func formatBinaryDate(num uint64, data []byte) (driver.Value, error) {
switch num {
case 0:
return []byte("0000-00-00"), nil
case 4:
return []byte(fmt.Sprintf(
"%04d-%02d-%02d",
binary.LittleEndian.Uint16(data[:2]),
data[2],
data[3],
)), nil
}
return nil, fmt.Errorf("Invalid DATE-packet length %d", num)
}
func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) {
switch num {
case 0:
return []byte("0000-00-00 00:00:00"), nil
case 4:
return []byte(fmt.Sprintf(
"%04d-%02d-%02d 00:00:00",
binary.LittleEndian.Uint16(data[:2]),
data[2],
data[3],
)), nil
case 7:
return []byte(fmt.Sprintf(
"%04d-%02d-%02d %02d:%02d:%02d",
binary.LittleEndian.Uint16(data[:2]),
data[2],
data[3],
data[4],
data[5],
data[6],
)), nil
case 11:
return []byte(fmt.Sprintf(
"%04d-%02d-%02d %02d:%02d:%02d.%06d",
binary.LittleEndian.Uint16(data[:2]),
data[2],
data[3],
data[4],
data[5],
data[6],
binary.LittleEndian.Uint32(data[7:11]),
)), nil
}
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}
/******************************************************************************
* Convert from and to bytes *
******************************************************************************/
func uint64ToBytes(n uint64) []byte {
return []byte{
byte(n),
byte(n >> 8),
byte(n >> 16),
byte(n >> 24),
byte(n >> 32),
byte(n >> 40),
byte(n >> 48),
byte(n >> 56),
}
}
func uint64ToString(n uint64) []byte {
var a [20]byte
i := 20
// U+0030 = 0
// ...
// U+0039 = 9
var q uint64
for n >= 10 {
i--
q = n / 10
a[i] = uint8(n-q*10) + 0x30
n = q
}
i--
a[i] = uint8(n) + 0x30
return a[i:]
}
// treats string value as unsigned integer representation
func stringToInt(b []byte) int {
val := 0
for i := range b {
val *= 10
val += int(b[i] - 0x30)
}
return val
}
// returns the string read as a bytes slice, wheter the value is NULL,
// the number of bytes read and an error, in case the string is longer than
// the input slice
func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
// Get length
num, isNull, n := readLengthEncodedInteger(b)
if num < 1 {
return b[n:n], isNull, n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return b[n-int(num) : n], false, n, nil
}
return nil, false, n, io.EOF
}
// returns the number of bytes skipped and an error, in case the string is
// longer than the input slice
func skipLengthEnodedString(b []byte) (int, error) {
// Get length
num, _, n := readLengthEncodedInteger(b)
if num < 1 {
return n, nil
}
n += int(num)
// Check data length
if len(b) >= n {
return n, nil
}
return n, io.EOF
}
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
switch b[0] {
// 251: NULL
case 0xfb:
return 0, true, 1
// 252: value of following 2
case 0xfc:
return uint64(b[1]) | uint64(b[2])<<8, false, 3
// 253: value of following 3
case 0xfd:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
// 254: value of following 8
case 0xfe:
return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
uint64(b[7])<<48 | uint64(b[8])<<54,
false, 9
}
// 0-250: value of first byte
return uint64(b[0]), false, 1
}
// encodes a uint64 value and appends it to the given bytes slice
func appendLengthEncodedInteger(b []byte, n uint64) []byte {
switch {
case n <= 250:
return append(b, byte(n))
case n <= 0xffff:
return append(b, 0xfc, byte(n), byte(n>>8))
case n <= 0xffffff:
return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
}
return b
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"testing"
"time"
)
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}
func TestDSNParser(t *testing.T) {
var cfg *config
var err error
var res string
for i, tst := range testDSNs {
cfg, err = parseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
res = fmt.Sprintf("%+v", cfg)
if res != fmt.Sprintf(tst.out, tst.loc) {
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := parseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := parseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
error bool
valid bool
time time.Time
}{
{tDate, false, true, tDate},
{sDate, false, true, tDate},
{[]byte(sDate), false, true, tDate},
{tDateTime, false, true, tDateTime},
{sDateTime, false, true, tDateTime},
{[]byte(sDateTime), false, true, tDateTime},
{tDate0, false, true, tDate0},
{sDate0, false, true, tDate0},
{[]byte(sDate0), false, true, tDate0},
{sDateTime0, false, true, tDate0},
{[]byte(sDateTime0), false, true, tDate0},
{"", true, false, tDate0},
{"1234", true, false, tDate0},
{0, true, false, tDate0},
}
var nt = NullTime{}
var err error
for _, tst := range scanTests {
err = nt.Scan(tst.in)
if (err != nil) != tst.error {
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
}
if nt.Valid != tst.valid {
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
}
if nt.Time != tst.time {
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
}
}
}
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# GORM
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
The fantastic ORM library for Golang, aims to be developer friendly.
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
## Overview
* Full-Featured ORM (almost)
* Chainable API
* Auto Migrations
* Relations (Has One, Has Many, Belongs To, Many To Many, [Polymorphism](#polymorphism))
* Callbacks (Before/After Create/Save/Update/Delete/Find)
* Preloading (eager loading)
* Transactions
* Embed Anonymous Struct
* Soft Deletes
* Customizable Logger
* Iteration Support via [Rows](#row--rows)
* Every feature comes with tests
* Developer Friendly
# Getting Started
## Install
```
go get -u github.com/jinzhu/gorm
```
## Define Models (Structs)
```go
type User struct {
ID int
Birthday time.Time
Age int
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
Num int `sql:"AUTO_INCREMENT"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
Emails []Email // One-To-Many relationship (has many)
BillingAddress Address // One-To-One relationship (has one)
BillingAddressID sql.NullInt64 // Foreign key of BillingAddress
ShippingAddress Address // One-To-One relationship (has one)
ShippingAddressID int // Foreign key of ShippingAddress
IgnoreMe int `sql:"-"` // Ignore this field
Languages []Language `gorm:"many2many:user_languages;"` // Many-To-Many relationship, 'user_languages' is join table
}
type Email struct {
ID int
UserID int `sql:"index"` // Foreign key (belongs to), tag `index` will create index for this field when using AutoMigrate
Email string `sql:"type:varchar(100);unique_index"` // Set field's sql type, tag `unique_index` will create unique index
Subscribed bool
}
type Address struct {
ID int
Address1 string `sql:"not null;unique"` // Set field as not nullable and unique
Address2 string `sql:"type:varchar(100);unique"`
Post sql.NullString `sql:"not null"`
}
type Language struct {
ID int
Name string `sql:"index:idx_name_code"` // Create index with name, and will create combined index if find other fields defined same name
Code string `sql:"index:idx_name_code"` // `unique_index` also works
}
```
## Conventions
* Table name is the plural of struct name's snake case, you can disable pluralization with `db.SingularTable(true)`, or [Specifying The Table Name For A Struct Permanently With TableName](#specifying-the-table-name-for-a-struct-permanently-with-tablename)
```go
type User struct{} // struct User's database table name is "users" by default, will be "user" if you disabled pluralisation
```
* Column name is the snake case of field's name
* Use `ID` field as primary key
* Use `CreatedAt` to store record's created time if field exists
* Use `UpdatedAt` to store record's updated time if field exists
* Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete)
* Gorm provide a default model struct, you could embed it in your struct
```go
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type User struct {
gorm.Model
Name string
}
```
## Initialize Database
```go
import (
"github.com/jinzhu/gorm"
_ "github.com/lib/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
)
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB.
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
// You can also use an existing database connection handle
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, _ := gorm.Open("postgres", dbSql)
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
db.DB()
// Then you could invoke `*sql.DB`'s functions with it
db.DB().Ping()
db.DB().SetMaxIdleConns(10)
db.DB().SetMaxOpenConns(100)
// Disable table name's pluralization
db.SingularTable(true)
```
## Migration
```go
// Create table
db.CreateTable(&User{})
// Drop table
db.DropTable(&User{})
// Automating Migration
db.AutoMigrate(&User{})
db.AutoMigrate(&User{}, &Product{}, &Order{})
// Feel free to change your struct, AutoMigrate will keep your database up-to-date.
// AutoMigrate will ONLY add *new columns* and *new indexes*,
// WON'T update current column's type or delete unused columns, to protect your data.
// If the table is not existing, AutoMigrate will create the table automatically.
```
# Basic CRUD
## Create Record
```go
user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()}
db.NewRecord(user) // => returns `true` if primary key is blank
db.Create(&user)
db.NewRecord(user) // => return `false` after `user` created
// Associations will be inserted automatically when save the record
user := User{
Name: "jinzhu",
BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
Languages: []Language{{Name: "ZH"}, {Name: "EN"}},
}
db.Create(&user)
//// BEGIN TRANSACTION;
//// INSERT INTO "addresses" (address1) VALUES ("Billing Address - Address 1");
//// INSERT INTO "addresses" (address1) VALUES ("Shipping Address - Address 1");
//// INSERT INTO "users" (name,billing_address_id,shipping_address_id) VALUES ("jinzhu", 1, 2);
//// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu@example.com");
//// INSERT INTO "emails" (user_id,email) VALUES (111, "jinzhu-2@example.com");
//// INSERT INTO "languages" ("name") VALUES ('ZH');
//// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 1);
//// INSERT INTO "languages" ("name") VALUES ('EN');
//// INSERT INTO user_languages ("user_id","language_id") VALUES (111, 2);
//// COMMIT;
```
Refer [Associations](#associations) for more details
## Query
```go
// Get the first record
db.First(&user)
//// SELECT * FROM users ORDER BY id LIMIT 1;
// Get the last record
db.Last(&user)
//// SELECT * FROM users ORDER BY id DESC LIMIT 1;
// Get all records
db.Find(&users)
//// SELECT * FROM users;
// Get record with primary key
db.First(&user, 10)
//// SELECT * FROM users WHERE id = 10;
```
### Query With Where (Plain SQL)
```go
// Get the first matched record
db.Where("name = ?", "jinzhu").First(&user)
//// SELECT * FROM users WHERE name = 'jinzhu' limit 1;
// Get all matched records
db.Where("name = ?", "jinzhu").Find(&users)
//// SELECT * FROM users WHERE name = 'jinzhu';
db.Where("name <> ?", "jinzhu").Find(&users)
// IN
db.Where("name in (?)", []string{"jinzhu", "jinzhu 2"}).Find(&users)
// LIKE
db.Where("name LIKE ?", "%jin%").Find(&users)
// AND
db.Where("name = ? and age >= ?", "jinzhu", "22").Find(&users)
// Time
db.Where("updated_at > ?", lastWeek).Find(&users)
db.Where("created_at BETWEEN ? AND ?", lastWeek, today).Find(&users)
```
### Query With Where (Struct & Map)
```go
// Struct
db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
//// SELECT * FROM users WHERE name = "jinzhu" AND age = 20 LIMIT 1;
// Map
db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users)
//// SELECT * FROM users WHERE name = "jinzhu" AND age = 20;
// Slice of primary keys
db.Where([]int64{20, 21, 22}).Find(&users)
//// SELECT * FROM users WHERE id IN (20, 21, 22);
```
### Query With Not
```go
db.Not("name", "jinzhu").First(&user)
//// SELECT * FROM users WHERE name <> "jinzhu" LIMIT 1;
// Not In
db.Not("name", []string{"jinzhu", "jinzhu 2"}).Find(&users)
//// SELECT * FROM users WHERE name NOT IN ("jinzhu", "jinzhu 2");
// Not In slice of primary keys
db.Not([]int64{1,2,3}).First(&user)
//// SELECT * FROM users WHERE id NOT IN (1,2,3);
db.Not([]int64{}).First(&user)
//// SELECT * FROM users;
// Plain SQL
db.Not("name = ?", "jinzhu").First(&user)
//// SELECT * FROM users WHERE NOT(name = "jinzhu");
// Struct
db.Not(User{Name: "jinzhu"}).First(&user)
//// SELECT * FROM users WHERE name <> "jinzhu";
```
### Query With Inline Condition
```go
// Get by primary key
db.First(&user, 23)
//// SELECT * FROM users WHERE id = 23 LIMIT 1;
// Plain SQL
db.Find(&user, "name = ?", "jinzhu")
//// SELECT * FROM users WHERE name = "jinzhu";
db.Find(&users, "name <> ? AND age > ?", "jinzhu", 20)
//// SELECT * FROM users WHERE name <> "jinzhu" AND age > 20;
// Struct
db.Find(&users, User{Age: 20})
//// SELECT * FROM users WHERE age = 20;
// Map
db.Find(&users, map[string]interface{}{"age": 20})
//// SELECT * FROM users WHERE age = 20;
```
### Query With Or
```go
db.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&users)
//// SELECT * FROM users WHERE role = 'admin' OR role = 'super_admin';
// Struct
db.Where("name = 'jinzhu'").Or(User{Name: "jinzhu 2"}).Find(&users)
//// SELECT * FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2';
// Map
db.Where("name = 'jinzhu'").Or(map[string]interface{}{"name": "jinzhu 2"}).Find(&users)
```
### Query Chains
Gorm has a chainable API, you could use it like this
```go
db.Where("name <> ?","jinzhu").Where("age >= ? and role <> ?",20,"admin").Find(&users)
//// SELECT * FROM users WHERE name <> 'jinzhu' AND age >= 20 AND role <> 'admin';
db.Where("role = ?", "admin").Or("role = ?", "super_admin").Not("name = ?", "jinzhu").Find(&users)
```
### Preloading (Eager loading)
```go
db.Preload("Orders").Find(&users)
//// SELECT * FROM users;
//// SELECT * FROM orders WHERE user_id IN (1,2,3,4);
db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
//// SELECT * FROM users;
//// SELECT * FROM orders WHERE user_id IN (1,2,3,4) AND state NOT IN ('cancelled');
db.Where("state = ?", "active").Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
//// SELECT * FROM users WHERE state = 'active';
//// SELECT * FROM orders WHERE user_id IN (1,2) AND state NOT IN ('cancelled');
db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
//// SELECT * FROM users;
//// SELECT * FROM orders WHERE user_id IN (1,2,3,4); // has many
//// SELECT * FROM profiles WHERE user_id IN (1,2,3,4); // has one
//// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
```
#### Nested Preloading
```go
db.Preload("Orders.OrderItems").Find(&users)
db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)
```
## Update
```go
// Update an existing struct
db.First(&user)
user.Name = "jinzhu 2"
user.Age = 100
db.Save(&user)
//// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111;
db.Where("active = ?", true).Save(&user)
//// UPDATE users SET name='jinzhu 2', age=100, updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
// Update an attribute if it is changed
db.Model(&user).Update("name", "hello")
//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
db.Model(&user).Where("active = ?", true).Update("name", "hello")
//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111 AND active = true;
db.First(&user, 111).Update("name", "hello")
//// SELECT * FROM users LIMIT 1;
//// UPDATE users SET name='hello', updated_at = '2013-11-17 21:34:10' WHERE id=111;
// Update multiple attributes if they are changed
db.Model(&user).Updates(map[string]interface{}{"name": "hello", "age": 18, "actived": false})
// Update multiple attributes if they are changed (update with struct only works with none zero values)
db.Model(&user).Updates(User{Name: "hello", Age: 18})
//// UPDATE users SET name='hello', age=18, updated_at = '2013-11-17 21:34:10' WHERE id = 111;
```
### Update Without Callbacks
By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations:
```go
db.Model(&user).UpdateColumn("name", "hello")
//// UPDATE users SET name='hello' WHERE id = 111;
// Update with struct only works with none zero values, or use map[string]interface{}
db.Model(&user).UpdateColumns(User{Name: "hello", Age: 18})
//// UPDATE users SET name='hello', age=18 WHERE id = 111;
```
### Batch Updates
```go
db.Table("users").Where("id = ?", 10).Updates(map[string]interface{}{"name": "hello", "age": 18})
//// UPDATE users SET name='hello', age=18 WHERE id = 10;
// Update with struct only works with none zero values, or use map[string]interface{}
db.Model(User{}).Updates(User{Name: "hello", Age: 18})
//// UPDATE users SET name='hello', age=18;
// Callbacks won't run when do batch updates
// Use `RowsAffected` to get the count of affected records
db.Model(User{}).Updates(User{Name: "hello", Age: 18}).RowsAffected
```
### Update with SQL Expression
```go
DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
//// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
DB.Model(&product).Updates(map[string]interface{}{"price": gorm.Expr("price * ? + ?", 2, 100)})
//// UPDATE "products" SET "code" = 'L1212', "price" = price * '2' + '100', "updated_at" = '2013-11-17 21:34:10' WHERE "id" = '2';
DB.Model(&product).UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
//// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2';
DB.Model(&product).Where("quantity > 1").UpdateColumn("quantity", gorm.Expr("quantity - ?", 1))
//// UPDATE "products" SET "quantity" = quantity - 1 WHERE "id" = '2' AND quantity > 1;
```
## Delete
```go
// Delete an existing record
db.Delete(&email)
//// DELETE from emails where id=10;
```
### Batch Delete
```go
db.Where("email LIKE ?", "%jinzhu%").Delete(Email{})
//// DELETE from emails where email LIKE "%jinhu%";
```
### Soft Delete
If struct has `DeletedAt` field, it will get soft delete ability automatically!
Then it won't be deleted from database permanently when call `Delete`.
```go
db.Delete(&user)
//// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE id = 111;
// Batch Delete
db.Where("age = ?", 20).Delete(&User{})
//// UPDATE users SET deleted_at="2013-10-29 10:23" WHERE age = 20;
// Soft deleted records will be ignored when query them
db.Where("age = 20").Find(&user)
//// SELECT * FROM users WHERE age = 20 AND (deleted_at IS NULL OR deleted_at <= '0001-01-02');
// Find soft deleted records with Unscoped
db.Unscoped().Where("age = 20").Find(&users)
//// SELECT * FROM users WHERE age = 20;
// Delete record permanently with Unscoped
db.Unscoped().Delete(&order)
//// DELETE FROM orders WHERE id=10;
```
## Associations
### Has One
```go
// User has one address
db.Model(&user).Related(&address)
//// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key AddressId
// Specify the foreign key
db.Model(&user).Related(&address1, "BillingAddressId")
//// SELECT * FROM addresses WHERE id = 123; // 123 is user's foreign key BillingAddressId
```
### Belongs To
```go
// Email belongs to user
db.Model(&email).Related(&user)
//// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key UserId
// Specify the foreign key
db.Model(&email).Related(&user, "ProfileId")
//// SELECT * FROM users WHERE id = 111; // 111 is email's foreign key ProfileId
```
### Has Many
```go
// User has many emails
db.Model(&user).Related(&emails)
//// SELECT * FROM emails WHERE user_id = 111;
// user_id is the foreign key, 111 is user's primary key's value
// Specify the foreign key
db.Model(&user).Related(&emails, "ProfileId")
//// SELECT * FROM emails WHERE profile_id = 111;
// profile_id is the foreign key, 111 is user's primary key's value
```
### Many To Many
```go
// User has many languages and belongs to many languages
db.Model(&user).Related(&languages, "Languages")
//// SELECT * FROM "languages" INNER JOIN "user_languages" ON "user_languages"."language_id" = "languages"."id" WHERE "user_languages"."user_id" = 111
// `Languages` is user's column name, this column's tag defined join table like this `gorm:"many2many:user_languages;"`
```
There is also a mode used to handle many to many relations easily
```go
// Query
db.Model(&user).Association("Languages").Find(&languages)
// same as `db.Model(&user).Related(&languages, "Languages")`
db.Where("name = ?", "ZH").First(&languageZH)
db.Where("name = ?", "EN").First(&languageEN)
// Append
db.Model(&user).Association("Languages").Append([]Language{languageZH, languageEN})
db.Model(&user).Association("Languages").Append([]Language{{Name: "DE"}})
db.Model(&user).Association("Languages").Append(Language{Name: "DE"})
// Delete
db.Model(&user).Association("Languages").Delete([]Language{languageZH, languageEN})
db.Model(&user).Association("Languages").Delete(languageZH, languageEN)
// Replace
db.Model(&user).Association("Languages").Replace([]Language{languageZH, languageEN})
db.Model(&user).Association("Languages").Replace(Language{Name: "DE"}, languageEN)
// Count
db.Model(&user).Association("Languages").Count()
// Return the count of languages the user has
// Clear
db.Model(&user).Association("Languages").Clear()
// Remove all relations between the user and languages
```
### Polymorphism
Supports polymorphic has-many and has-one associations.
```go
type Cat struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Dog struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Toy struct {
Id int
Name string
OwnerId int
OwnerType string
}
```
Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
## Advanced Usage
## FirstOrInit
Get the first matched record, or initialize a record with search conditions.
```go
// Unfound
db.FirstOrInit(&user, User{Name: "non_existing"})
//// user -> User{Name: "non_existing"}
// Found
db.Where(User{Name: "Jinzhu"}).FirstOrInit(&user)
//// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
db.FirstOrInit(&user, map[string]interface{}{"name": "jinzhu"})
//// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
```
### Attrs
Ignore some values when searching, but use them to initialize the struct if record is not found.
```go
// Unfound
db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrInit(&user)
//// SELECT * FROM USERS WHERE name = 'non_existing';
//// user -> User{Name: "non_existing", Age: 20}
db.Where(User{Name: "noexisting_user"}).Attrs("age", 20).FirstOrInit(&user)
//// SELECT * FROM USERS WHERE name = 'non_existing';
//// user -> User{Name: "non_existing", Age: 20}
// Found
db.Where(User{Name: "Jinzhu"}).Attrs(User{Age: 30}).FirstOrInit(&user)
//// SELECT * FROM USERS WHERE name = jinzhu';
//// user -> User{Id: 111, Name: "Jinzhu", Age: 20}
```
### Assign
Ignore some values when searching, but assign it to the result regardless it is found or not.
```go
// Unfound
db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrInit(&user)
//// user -> User{Name: "non_existing", Age: 20}
// Found
db.Where(User{Name: "Jinzhu"}).Assign(User{Age: 30}).FirstOrInit(&user)
//// SELECT * FROM USERS WHERE name = jinzhu';
//// user -> User{Id: 111, Name: "Jinzhu", Age: 30}
```
## FirstOrCreate
Get the first matched record, or create with search conditions.
```go
// Unfound
db.FirstOrCreate(&user, User{Name: "non_existing"})
//// INSERT INTO "users" (name) VALUES ("non_existing");
//// user -> User{Id: 112, Name: "non_existing"}
// Found
db.Where(User{Name: "Jinzhu"}).FirstOrCreate(&user)
//// user -> User{Id: 111, Name: "Jinzhu"}
```
### Attrs
Ignore some values when searching, but use them to create the struct if record is not found. like `FirstOrInit`
```go
// Unfound
db.Where(User{Name: "non_existing"}).Attrs(User{Age: 20}).FirstOrCreate(&user)
//// SELECT * FROM users WHERE name = 'non_existing';
//// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
//// user -> User{Id: 112, Name: "non_existing", Age: 20}
// Found
db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 30}).FirstOrCreate(&user)
//// SELECT * FROM users WHERE name = 'jinzhu';
//// user -> User{Id: 111, Name: "jinzhu", Age: 20}
```
### Assign
Ignore some values when searching, but assign it to the record regardless it is found or not, then save back to database. like `FirstOrInit`
```go
// Unfound
db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user)
//// SELECT * FROM users WHERE name = 'non_existing';
//// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
//// user -> User{Id: 112, Name: "non_existing", Age: 20}
// Found
db.Where(User{Name: "jinzhu"}).Assign(User{Age: 30}).FirstOrCreate(&user)
//// SELECT * FROM users WHERE name = 'jinzhu';
//// UPDATE users SET age=30 WHERE id = 111;
//// user -> User{Id: 111, Name: "jinzhu", Age: 30}
```
## Select
```go
db.Select("name, age").Find(&users)
//// SELECT name, age FROM users;
db.Select([]string{"name", "age"}).Find(&users)
//// SELECT name, age FROM users;
db.Table("users").Select("COALESCE(age,?)", 42).Rows()
//// SELECT COALESCE(age,'42') FROM users;
```
## Order
```go
db.Order("age desc, name").Find(&users)
//// SELECT * FROM users ORDER BY age desc, name;
// Multiple orders
db.Order("age desc").Order("name").Find(&users)
//// SELECT * FROM users ORDER BY age desc, name;
// ReOrder
db.Order("age desc").Find(&users1).Order("age", true).Find(&users2)
//// SELECT * FROM users ORDER BY age desc; (users1)
//// SELECT * FROM users ORDER BY age; (users2)
```
## Limit
```go
db.Limit(3).Find(&users)
//// SELECT * FROM users LIMIT 3;
// Cancel limit condition with -1
db.Limit(10).Find(&users1).Limit(-1).Find(&users2)
//// SELECT * FROM users LIMIT 10; (users1)
//// SELECT * FROM users; (users2)
```
## Offset
```go
db.Offset(3).Find(&users)
//// SELECT * FROM users OFFSET 3;
// Cancel offset condition with -1
db.Offset(10).Find(&users1).Offset(-1).Find(&users2)
//// SELECT * FROM users OFFSET 10; (users1)
//// SELECT * FROM users; (users2)
```
## Count
```go
db.Where("name = ?", "jinzhu").Or("name = ?", "jinzhu 2").Find(&users).Count(&count)
//// SELECT * from USERS WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (users)
//// SELECT count(*) FROM users WHERE name = 'jinzhu' OR name = 'jinzhu 2'; (count)
db.Model(User{}).Where("name = ?", "jinzhu").Count(&count)
//// SELECT count(*) FROM users WHERE name = 'jinzhu'; (count)
db.Table("deleted_users").Count(&count)
//// SELECT count(*) FROM deleted_users;
```
## Pluck
Get selected attributes as map
```go
var ages []int64
db.Find(&users).Pluck("age", &ages)
var names []string
db.Model(&User{}).Pluck("name", &names)
db.Table("deleted_users").Pluck("name", &names)
// Requesting more than one column? Do it like this:
db.Select("name, age").Find(&users)
```
## Raw SQL
```go
db.Exec("DROP TABLE users;")
db.Exec("UPDATE orders SET shipped_at=? WHERE id IN (?)", time.Now, []int64{11,22,33})
```
## Row & Rows
It is even possible to get query result as `*sql.Row` or `*sql.Rows`
```go
row := db.Table("users").Where("name = ?", "jinzhu").Select("name, age").Row() // (*sql.Row)
row.Scan(&name, &age)
rows, err := db.Model(User{}).Where("name = ?", "jinzhu").Select("name, age, email").Rows() // (*sql.Rows, error)
defer rows.Close()
for rows.Next() {
...
rows.Scan(&name, &age, &email)
...
}
// Raw SQL
rows, err := db.Raw("select name, age, email from users where name = ?", "jinzhu").Rows() // (*sql.Rows, error)
defer rows.Close()
for rows.Next() {
...
rows.Scan(&name, &age, &email)
...
}
```
## Scan
Scan results into another struct.
```go
type Result struct {
Name string
Age int
}
var result Result
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result)
// Raw SQL
db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
```
## Group & Having
```go
rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Rows()
for rows.Next() {
...
}
rows, err := db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Rows()
for rows.Next() {
...
}
type Result struct {
Date time.Time
Total int64
}
db.Table("orders").Select("date(created_at) as date, sum(amount) as total").Group("date(created_at)").Having("sum(amount) > ?", 100).Scan(&results)
```
## Joins
```go
rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows()
for rows.Next() {
...
}
db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)
// find a user by email address
db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user)
// find all email addresses for a user
db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails)
```
## Transactions
To perform a set of operations within a transaction, the general flow is as below.
The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction.
(Note that all individual save and delete operations are run in a transaction by default.)
```go
// begin
tx := db.Begin()
// do some database operations (use 'tx' from this point, not 'db')
tx.Create(...)
...
// rollback in case of error
tx.Rollback()
// Or commit if all is ok
tx.Commit()
```
### A Specific Example
```
func CreateAnimals(db *gorm.DB) err {
tx := db.Begin()
// Note the use of tx as the database handle once you are within a transaction
if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
```
## Scopes
```go
func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
return db.Where("amount > ?", 1000)
}
func PaidWithCreditCard(db *gorm.DB) *gorm.DB {
return db.Where("pay_mode_sign = ?", "C")
}
func PaidWithCod(db *gorm.DB) *gorm.DB {
return db.Where("pay_mode_sign = ?", "C")
}
func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
return func (db *gorm.DB) *gorm.DB {
return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
}
}
db.Scopes(AmountGreaterThan1000, PaidWithCreditCard).Find(&orders)
// Find all credit card orders and amount greater than 1000
db.Scopes(AmountGreaterThan1000, PaidWithCod).Find(&orders)
// Find all COD orders and amount greater than 1000
db.Scopes(OrderStatus([]string{"paid", "shipped"})).Find(&orders)
// Find all paid, shipped orders
```
## Callbacks
Callbacks are methods defined on the pointer of struct.
If any callback returns an error, gorm will stop future operations and rollback all changes.
Here is the list of all available callbacks:
(listed in the same order in which they will get called during the respective operations)
### Creating An Object
```go
BeforeSave
BeforeCreate
// save before associations
// save self
// save after associations
AfterCreate
AfterSave
```
### Updating An Object
```go
BeforeSave
BeforeUpdate
// save before associations
// save self
// save after associations
AfterUpdate
AfterSave
```
### Destroying An Object
```go
BeforeDelete
// delete self
AfterDelete
```
### After Find
```go
// load data from database
AfterFind
```
### Example
```go
func (u *User) BeforeUpdate() (err error) {
if u.readonly() {
err = errors.New("read only user")
}
return
}
// Rollback the insertion if user's id greater than 1000
func (u *User) AfterCreate() (err error) {
if (u.Id > 1000) {
err = errors.New("user id is already greater than 1000")
}
return
}
```
As you know, save/delete operations in gorm are running in a transaction,
This is means if changes made in the transaction is not visiable unless it is commited,
So if you want to use those changes in your callbacks, you need to run SQL in same transaction.
Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this:
```go
func (u *User) AfterCreate(tx *gorm.DB) (err error) {
tx.Model(u).Update("role", "admin")
return
}
```
## Specifying The Table Name
```go
// Create `deleted_users` table with struct User's definition
db.Table("deleted_users").CreateTable(&User{})
var deleted_users []User
db.Table("deleted_users").Find(&deleted_users)
//// SELECT * FROM deleted_users;
db.Table("deleted_users").Where("name = ?", "jinzhu").Delete()
//// DELETE FROM deleted_users WHERE name = 'jinzhu';
```
### Specifying The Table Name For A Struct Permanently with TableName
```go
type Cart struct {
}
func (c Cart) TableName() string {
return "shopping_cart"
}
func (u User) TableName() string {
if u.Role == "admin" {
return "admin_users"
} else {
return "users"
}
}
```
## Error Handling
```go
query := db.Where("name = ?", "jinzhu").First(&user)
query := db.First(&user).Limit(10).Find(&users)
// query.Error will return the last happened error
// So you could do error handing in your application like this:
if err := db.Where("name = ?", "jinzhu").First(&user).Error; err != nil {
// error handling...
}
// RecordNotFound
// If no record found when you query data, gorm will return RecordNotFound error, you could check it like this:
db.Where("name = ?", "hello world").First(&User{}).Error == gorm.RecordNotFound
// Or use the shortcut method
db.Where("name = ?", "hello world").First(&user).RecordNotFound()
if db.Model(&user).Related(&credit_card).RecordNotFound() {
// no credit card found error handling
}
```
## Logger
Gorm has built-in logger support
```go
// Enable Logger
db.LogMode(true)
// Diable Logger
db.LogMode(false)
// Debug a single operation
db.Debug().Where("name = ?", "jinzhu").First(&User{})
```
![logger](https://raw.github.com/jinzhu/gorm/master/images/logger.png)
### Customize Logger
```go
// Refer gorm's default logger for how to: https://github.com/jinzhu/gorm/blob/master/logger.go#files
db.SetLogger(gorm.Logger{revel.TRACE})
db.SetLogger(log.New(os.Stdout, "\r\n", 0))
```
## Existing Schema
If you have an existing database schema, and the primary key field is different from `id`, you can add a tag to the field structure to specify that this field is a primary key.
```go
type Animal struct {
AnimalId int64 `gorm:"primary_key"`
Birthday time.Time `sql:"DEFAULT:current_timestamp"`
Name string `sql:"default:'galeone'"`
Age int64
}
```
If your column names differ from the struct fields, you can specify them like this:
```go
type Animal struct {
AnimalId int64 `gorm:"column:beast_id;primary_key"`
Birthday time.Time `gorm:"column:day_of_the_beast"`
Age int64 `gorm:"column:age_of_the_beast"`
}
```
## Composite Primary Key
```go
type Product struct {
ID string `gorm:"primary_key"`
LanguageCode string `gorm:"primary_key"`
}
```
## Database Indexes & Foreign Key
```go
// Add foreign key
// 1st param : foreignkey field
// 2nd param : destination table(id)
// 3rd param : ONDELETE
// 4th param : ONUPDATE
db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT")
// Add index
db.Model(&User{}).AddIndex("idx_user_name", "name")
// Multiple column index
db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age")
// Add unique index
db.Model(&User{}).AddUniqueIndex("idx_user_name", "name")
// Multiple column unique index
db.Model(&User{}).AddUniqueIndex("idx_user_name_age", "name", "age")
// Remove index
db.Model(&User{}).RemoveIndex("idx_user_name")
```
## Default values
If you have defined a default value in the `sql` tag (see the struct Animal above) the generated create/update SQl will ignore these fields if is set blank data.
Eg.
```go
db.Create(&Animal{Age: 99, Name: ""})
```
The generated query will be:
```sql
INSERT INTO animals("age") values('99');
```
The same thing occurs in update statements.
## More examples with query chain
```go
db.First(&first_article).Count(&total_count).Limit(10).Find(&first_page_articles).Offset(10).Find(&second_page_articles)
//// SELECT * FROM articles LIMIT 1; (first_article)
//// SELECT count(*) FROM articles; (total_count)
//// SELECT * FROM articles LIMIT 10; (first_page_articles)
//// SELECT * FROM articles LIMIT 10 OFFSET 10; (second_page_articles)
db.Where("created_at > ?", "2013-10-10").Find(&cancelled_orders, "state = ?", "cancelled").Find(&shipped_orders, "state = ?", "shipped")
//// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'cancelled'; (cancelled_orders)
//// SELECT * FROM orders WHERE created_at > '2013/10/10' AND state = 'shipped'; (shipped_orders)
// Use variables to keep query chain
todays_orders := db.Where("created_at > ?", "2013-10-29")
cancelled_orders := todays_orders.Where("state = ?", "cancelled")
shipped_orders := todays_orders.Where("state = ?", "shipped")
// Search with shared conditions for different tables
db.Where("product_name = ?", "fancy_product").Find(&orders).Find(&shopping_carts)
//// SELECT * FROM orders WHERE product_name = 'fancy_product'; (orders)
//// SELECT * FROM carts WHERE product_name = 'fancy_product'; (shopping_carts)
// Search with shared conditions from different tables with specified table
db.Where("mail_type = ?", "TEXT").Find(&users1).Table("deleted_users").Find(&users2)
//// SELECT * FROM users WHERE mail_type = 'TEXT'; (users1)
//// SELECT * FROM deleted_users WHERE mail_type = 'TEXT'; (users2)
// FirstOrCreate example
db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111"}).FirstOrCreate(&user)
//// SELECT * FROM users WHERE email = 'x@example.org';
//// INSERT INTO "users" (email,registered_ip) VALUES ("x@example.org", "111.111.111.111") // if record not found
```
## TODO
* db.Select("Languages", "Name").Update(&user)
db.Omit("Languages").Update(&user)
* Auto migrate indexes
* Github Pages
* AlertColumn, DropColumn
* R/W Splitting, Validation
# Author
**jinzhu**
* <http://github.com/jinzhu>
* <wosmvp@gmail.com>
* <http://twitter.com/zhangjinzhu>
## License
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License).
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.png)](http://godoc.org/github.com/jinzhu/gorm)
package gorm
import (
"errors"
"fmt"
"reflect"
)
type Association struct {
Scope *Scope
PrimaryKey interface{}
Column string
Error error
Field *Field
}
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
}
func (association *Association) Find(value interface{}) *Association {
association.Scope.related(value, association.Column)
return association.setErr(association.Scope.db.Error)
}
func (association *Association) Append(values ...interface{}) *Association {
scope := association.Scope
field := association.Field
for _, value := range values {
reflectvalue := reflect.Indirect(reflect.ValueOf(value))
if reflectvalue.Kind() == reflect.Struct {
field.Set(reflect.Append(field.Field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Slice {
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else {
association.setErr(errors.New("invalid association type"))
}
}
scope.Search.Select(association.Column)
scope.callCallbacks(scope.db.parent.callback.updates)
return association.setErr(scope.db.Error)
}
func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} {
primaryKeys := []interface{}{}
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
} else if reflectValue.Kind() == reflect.Struct {
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
}
return primaryKeys
}
func (association *Association) Delete(values ...interface{}) *Association {
primaryKeys := association.getPrimaryKeys(values...)
if len(primaryKeys) == 0 {
association.setErr(errors.New("no primary key found"))
} else {
scope := association.Scope
relationship := association.Field.Relationship
// many to many
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i)
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
var included = false
for _, primaryKey := range primaryKeys {
if equalAsString(primaryKey, primaryField.Field.Interface()) {
included = true
}
}
if !included {
leftValues = reflect.Append(leftValues, value)
}
}
}
association.Field.Set(leftValues)
}
} else {
association.setErr(errors.New("delete only support many to many"))
}
}
return association
}
func (association *Association) Replace(values ...interface{}) *Association {
relationship := association.Field.Relationship
scope := association.Scope
if relationship.Kind == "many_to_many" {
field := association.Field.Field
oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.Append(values...)
newPrimaryKeys := association.getPrimaryKeys(field.Interface())
var addedPrimaryKeys = []interface{}{}
for _, newKey := range newPrimaryKeys {
hasEqual := false
for _, oldKey := range oldPrimaryKeys {
if reflect.DeepEqual(newKey, oldKey) {
hasEqual = true
break
}
}
if !hasEqual {
addedPrimaryKeys = append(addedPrimaryKeys, newKey)
}
}
for _, primaryKey := range association.getPrimaryKeys(values...) {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
}
if len(addedPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
}
} else {
association.setErr(errors.New("replace only support many to many"))
}
return association
}
func (association *Association) Clear() *Association {
relationship := association.Field.Relationship
scope := association.Scope
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey)
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else {
association.setErr(err)
}
} else {
association.setErr(errors.New("clear only support many to many"))
}
return association
}
func (association *Association) Count() int {
count := -1
relationship := association.Field.Relationship
scope := association.Scope
newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
if relationship.PolymorphicType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
countScope.Count(&count)
} else if relationship.Kind == "belongs_to" {
if v, ok := scope.FieldByName(association.Column); ok {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count)
}
}
return count
}
package gorm_test
import (
"fmt"
"testing"
)
func TestHasOneAndHasManyAssociation(t *testing.T) {
DB.DropTable(Category{})
DB.DropTable(Post{})
DB.DropTable(Comment{})
DB.CreateTable(Category{})
DB.CreateTable(Post{})
DB.CreateTable(Comment{})
post := Post{
Title: "post 1",
Body: "body 1",
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
Category: Category{Name: "Category 1"},
MainCategory: Category{Name: "Main Category 1"},
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post")
}
if DB.First(&Category{}, "name = ?", "Category 1").Error != nil {
t.Errorf("Category should be saved")
}
var p Post
DB.First(&p, post.Id)
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
t.Errorf("Category Id should exist")
}
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved")
}
if post.Comments[0].PostId == 0 {
t.Errorf("Comment Should have post id")
}
var comment Comment
if DB.First(&comment, "content = ?", "Comment 2").Error != nil {
t.Errorf("Comment 2 should be saved")
}
if comment.PostId == 0 {
t.Errorf("Comment 2 Should have post id")
}
comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
DB.Save(&comment3)
}
func TestRelated(t *testing.T) {
user := User{
Name: "jinzhu",
BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
CreditCard: CreditCard{Number: "1234567890"},
Company: Company{Name: "company1"},
}
DB.Save(&user)
if user.CreditCard.ID == 0 {
t.Errorf("After user save, credit card should have id")
}
if user.BillingAddress.ID == 0 {
t.Errorf("After user save, billing address should have id")
}
if user.Emails[0].Id == 0 {
t.Errorf("After user save, billing address should have id")
}
var emails []Email
DB.Model(&user).Related(&emails)
if len(emails) != 2 {
t.Errorf("Should have two emails")
}
var emails2 []Email
DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
if len(emails2) != 1 {
t.Errorf("Should have two emails")
}
var user1 User
DB.Model(&user).Related(&user1.Emails)
if len(user1.Emails) != 2 {
t.Errorf("Should have only one email match related condition")
}
var address1 Address
DB.Model(&user).Related(&address1, "BillingAddressId")
if address1.Address1 != "Billing Address - Address 1" {
t.Errorf("Should get billing address from user correctly")
}
user1 = User{}
DB.Model(&address1).Related(&user1, "BillingAddressId")
if DB.NewRecord(user1) {
t.Errorf("Should get user from address correctly")
}
var user2 User
DB.Model(&emails[0]).Related(&user2)
if user2.Id != user.Id || user2.Name != user.Name {
t.Errorf("Should get user from email correctly")
}
var creditcard CreditCard
var user3 User
DB.First(&creditcard, "number = ?", "1234567890")
DB.Model(&creditcard).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}
if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
t.Errorf("RecordNotFound for Related")
}
var company Company
if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
t.Errorf("RecordNotFound for Related")
}
}
func TestManyToMany(t *testing.T) {
DB.Raw("delete from languages")
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
user := User{Name: "Many2Many", Languages: languages}
DB.Save(&user)
// Query
var newLanguages []Language
DB.Model(&user).Related(&newLanguages, "Languages")
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Query many to many relations")
}
DB.Model(&user).Association("Languages").Find(&newLanguages)
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Should be able to find many to many relations")
}
if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
t.Errorf("Count should return correct result")
}
// Append
DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
t.Errorf("New record should be saved when append")
}
languageA := Language{Name: "AA"}
DB.Save(&languageA)
DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)
languageC := Language{Name: "CC"}
DB.Save(&languageC)
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
t.Errorf("All appended languages should be saved")
}
// Delete
user.Languages = []Language{}
DB.Model(&user).Association("Languages").Find(&user.Languages)
var language Language
DB.Where("name = ?", "EE").First(&language)
DB.Model(&user).Association("Languages").Delete(language, &language)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
t.Errorf("Language EE should not be deleted")
}
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
user2 := User{Name: "Many2Many_User2", Languages: languages}
DB.Save(&user2)
DB.Model(&user).Association("Languages").Delete(languages, &languages)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Model(&user2).Association("Languages").Count() == 0 {
t.Errorf("Other user's relations should not be deleted")
}
// Replace
var languageB Language
DB.Where("name = ?", "BB").First(&languageB)
DB.Model(&user).Association("Languages").Replace(languageB)
if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
t.Errorf("Relations should be replaced")
}
DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
t.Errorf("Relations should be replaced")
}
// Clear
DB.Model(&user).Association("Languages").Clear()
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
t.Errorf("Relations should be cleared")
}
}
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
for _, foreignKey := range []string{"UserId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
for _, foreignKey := range []string{"PostId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
}
package gorm
import (
"fmt"
)
type callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*callbackProcessor
}
type callbackProcessor struct {
name string
before string
after string
replace bool
remove bool
typ string
processor *func(scope *Scope)
callback *callback
}
func (c *callback) addProcessor(typ string) *callbackProcessor {
cp := &callbackProcessor{typ: typ, callback: c}
c.processors = append(c.processors, cp)
return cp
}
func (c *callback) clone() *callback {
return &callback{
creates: c.creates,
updates: c.updates,
deletes: c.deletes,
queries: c.queries,
processors: c.processors,
}
}
func (c *callback) Create() *callbackProcessor {
return c.addProcessor("create")
}
func (c *callback) Update() *callbackProcessor {
return c.addProcessor("update")
}
func (c *callback) Delete() *callbackProcessor {
return c.addProcessor("delete")
}
func (c *callback) Query() *callbackProcessor {
return c.addProcessor("query")
}
func (c *callback) RowQuery() *callbackProcessor {
return c.addProcessor("row_query")
}
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
cp.before = name
return cp
}
func (cp *callbackProcessor) After(name string) *callbackProcessor {
cp.after = name
return cp
}
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.callback.sort()
}
func (cp *callbackProcessor) Remove(name string) {
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
cp.remove = true
cp.callback.sort()
}
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
cp.processor = &fc
cp.replace = true
cp.callback.sort()
}
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callbackProcessor)
var names, sortedNames = []string{}, []string{}
for _, cp := range cps {
if index := getRIndex(names, cp.name); index > -1 {
if !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
}
names = append(names, cp.name)
}
sortCallbackProcessor = func(c *callbackProcessor) {
if getRIndex(sortedNames, c.name) > -1 {
return
}
if len(c.before) > 0 {
if index := getRIndex(sortedNames, c.before); index > -1 {
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(names, c.before); index > -1 {
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
} else {
sortedNames = append(sortedNames, c.name)
}
}
if len(c.after) > 0 {
if index := getRIndex(sortedNames, c.after); index > -1 {
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(names, c.after); index > -1 {
cp := cps[index]
if len(cp.before) == 0 {
cp.before = c.name
}
sortCallbackProcessor(cp)
} else {
sortedNames = append(sortedNames, c.name)
}
}
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var funcs = []*func(scope *Scope){}
var sortedFuncs = []*func(scope *Scope){}
for _, name := range sortedNames {
index := getRIndex(names, name)
if !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
for _, cp := range cps {
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
if !cp.remove {
funcs = append(funcs, cp.processor)
}
}
}
return append(sortedFuncs, funcs...)
}
func (c *callback) sort() {
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
for _, processor := range c.processors {
switch processor.typ {
case "create":
creates = append(creates, processor)
case "update":
updates = append(updates, processor)
case "delete":
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
package gorm
import (
"fmt"
"strings"
)
func BeforeCreate(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeCreate")
}
func UpdateTimeStampWhenCreate(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
scope.SetColumn("CreatedAt", now)
scope.SetColumn("UpdatedAt", now)
}
}
func Create(scope *Scope) {
defer scope.Trace(NowFunc())
if !scope.HasError() {
// set create sql
var sqls, columns []string
fields := scope.Fields()
for _, field := range fields {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
}
}
}
}
returningKey := "*"
primaryField := scope.PrimaryField()
if primaryField != nil {
returningKey = scope.Quote(primaryField.DBName)
}
if len(columns) == 0 {
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
scope.QuotedTableName(),
scope.Dialect().ReturningStr(scope.TableName(), returningKey),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v) %v",
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
scope.Dialect().ReturningStr(scope.TableName(), returningKey),
))
}
// execute create sql
if scope.Dialect().SupportLastInsertId() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
}
}
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
}
}
}
}
func AfterCreate(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave")
}
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("gorm:create", Create)
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm
import "fmt"
func BeforeDelete(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeDelete")
}
func Delete(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
scope.QuotedTableName(),
scope.AddToVars(NowFunc()),
scope.CombinedConditionSql(),
))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
}
scope.Exec()
}
}
func AfterDelete(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterDelete")
}
func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
DefaultCallback.Delete().Register("gorm:delete", Delete)
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm
import (
"errors"
"fmt"
"reflect"
)
func Query(scope *Scope) {
defer scope.Trace(NowFunc())
var (
isSlice bool
isPtr bool
anyRecordFound bool
destType reflect.Type
)
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
}
}
var dest = scope.IndirectValue()
if value, ok := scope.Get("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value))
}
if kind := dest.Kind(); kind == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType))))
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
} else if kind != reflect.Struct {
scope.Err(errors.New("unsupported destination, should be slice or struct"))
return
}
scope.prepareQuerySql()
if !scope.HasError() {
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
scope.db.RowsAffected++
anyRecordFound = true
elem := dest
if isSlice {
elem = reflect.New(destType).Elem()
}
var values = make([]interface{}, len(columns))
fields := scope.New(elem.Addr().Interface()).Fields()
for index, column := range columns {
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
} else {
var value interface{}
values[index] = &value
}
}
scope.Err(rows.Scan(values...))
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
}
if isSlice {
if isPtr {
dest.Set(reflect.Append(dest, elem.Addr()))
} else {
dest.Set(reflect.Append(dest, elem))
}
}
}
if !anyRecordFound && !isSlice {
scope.Err(RecordNotFound)
}
}
}
func AfterQuery(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterFind")
}
func init() {
DefaultCallback.Query().Register("gorm:query", Query)
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
DefaultCallback.Query().Register("gorm:preload", Preload)
}
package gorm
import "reflect"
func BeginTransaction(scope *Scope) {
scope.Begin()
}
func CommitOrRollbackTransaction(scope *Scope) {
scope.CommitOrRollback()
}
func SaveBeforeAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if relationship.ForeignFieldName != "" {
scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue()))
}
}
}
}
}
func SaveAfterAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
}
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(scope.NewDB().Save(elem).Error)
}
}
}
}
}
package gorm
import (
"reflect"
"runtime"
"strings"
"testing"
)
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1])
}
return reflect.DeepEqual(names, fnames)
}
func create(s *Scope) {}
func beforeCreate1(s *Scope) {}
func beforeCreate2(s *Scope) {}
func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
}
func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
}
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
}
func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
}
func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
}
package gorm
import (
"fmt"
"strings"
)
func AssignUpdateAttributes(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs")
_, updateColumn := scope.Get("gorm:update_column")
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
if updateColumn {
scope.InstanceSet("gorm:update_attrs", maps)
} else if len(updateAttrs) > 0 {
scope.InstanceSet("gorm:update_attrs", updateAttrs)
} else if !hasUpdate {
scope.SkipLeft()
return
}
}
}
}
func BeforeUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeUpdate")
}
}
func UpdateTimeStampWhenUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
func Update(scope *Scope) {
if !scope.HasError() {
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) {
if scope.changeableDBColumn(key) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
}
} else {
fields := scope.Fields()
for _, field := range fields {
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
if !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
if !relationField.IsBlank {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())))
}
}
}
}
}
if len(sqls) > 0 {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
}
}
}
func AfterUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate")
scope.CallMethodWithErrorCheck("AfterSave")
}
}
func init() {
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
DefaultCallback.Update().Register("gorm:update", Update)
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm_test
import (
"errors"
"github.com/jinzhu/gorm"
"reflect"
"testing"
)
func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}
func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}
func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}
func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}
func (s *Product) AfterCreate(tx *gorm.DB) {
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}
func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type commonDialect struct{}
func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) HasTop() bool {
return false
}
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("VARCHAR(%d)", size)
}
return "VARCHAR(65532)"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size)
}
return "BINARY(65532)"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}
package gorm_test
import (
"reflect"
"testing"
"time"
)
func TestCreate(t *testing.T) {
float := 35.03554004971999
user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
t.Error("User should be new record before create")
}
if count := DB.Save(&user).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
if DB.NewRecord(user) || DB.NewRecord(&user) {
t.Error("User should not new record after save")
}
var newUser User
DB.First(&newUser, user.Id)
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
t.Errorf("User's PasswordHash should be saved ([]byte)")
}
if newUser.Age != 18 {
t.Errorf("User's Age should be saved (int)")
}
if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
}
if newUser.Latitude != float {
t.Errorf("Float64 should not be changed after save")
}
if user.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
if newUser.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
DB.Model(user).Update("name", "create_user_new_name")
DB.First(&user, user.Id)
if user.CreatedAt != newUser.CreatedAt {
t.Errorf("CreatedAt should not be changed after update")
}
}
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil {
t.Errorf("No error should happen when create a record without std primary key")
}
if animal.Counter == 0 {
t.Errorf("No std primary key should be filled value after create")
}
if animal.Name != "Ferdinand" {
t.Errorf("Default value should be overrided")
}
// Test create with default value not overrided
an := Animal{From: "nerdz"}
if DB.Save(&an).Error != nil {
t.Errorf("No error should happen when create an record without std primary key")
}
// We must fetch the value again, to have the default fields updated
// (We can't do this in the update statements, since sql default can be expressions
// And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
if an.Name != "galeone" {
t.Errorf("Default value should fill the field. But got %v", an.Name)
}
}
func TestAnonymousScanner(t *testing.T) {
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_scanner")
if user2.Role.Name != "admin" {
t.Errorf("Should be able to get anonymous scanner")
}
if !user2.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner")
}
}
func TestAnonymousField(t *testing.T) {
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_field")
DB.Model(&user2).Related(&user2.Company)
if user2.Company.Name != "company" {
t.Errorf("Should be able to get anonymous field")
}
}
func TestSelectWithCreate(t *testing.T) {
user := getPreparedUser("select_user", "select_with_create")
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name != user.Name || queryuser.Age == user.Age {
t.Errorf("Should only create users with name column")
}
if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
t.Errorf("Should only create selected relationships")
}
}
func TestOmitWithCreate(t *testing.T) {
user := getPreparedUser("omit_user", "omit_with_create")
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name == user.Name || queryuser.Age != user.Age {
t.Errorf("Should only create users with age column")
}
if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
t.Errorf("Should not create omited relationships")
}
}
package gorm_test
import (
"testing"
"time"
)
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date time.Time `gorm:"column:mapped_time"`
}
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
}
func TestCustomizeColumn(t *testing.T) {
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
col = "mapped_id"
if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
}
expected := "foo"
cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, 666)
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err)
}
}
package gorm_test
import (
"testing"
"time"
)
func TestDelete(t *testing.T) {
user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&user1).Error != nil {
t.Errorf("No error should happen when delete a record")
}
if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("Other users that not deleted should be found-able")
}
}
func TestInlineDelete(t *testing.T) {
user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&User{}, user1.Id).Error != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if DB.Delete(&User{}, "name = ?", user2.Name).Error != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
}
func TestSoftDelete(t *testing.T) {
type User struct {
Id int64
Name string
DeletedAt time.Time
}
DB.AutoMigrate(&User{})
user := User{Name: "soft_delete"}
DB.Save(&user)
DB.Delete(&user)
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
if DB.Unscoped().First(&User{}, "name = ?", user.Name).Error != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped")
}
DB.Unscoped().Delete(&user)
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
t.Errorf("Can't find permanently deleted record")
}
}
package gorm
import (
"fmt"
"reflect"
)
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
SelectFromDummyTable() string
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
}
func NewDialect(driver string) Dialect {
var d Dialect
switch driver {
case "foundation":
d = &foundation{}
case "mysql":
d = &mysql{}
case "sqlite3":
d = &sqlite3{}
case "mssql":
d = &mssql{}
default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{}
}
return d
}
package gorm_test
import "testing"
type BasePost struct {
Id int64
Title string
URL string
}
type HNPost struct {
BasePost
Upvotes int32
}
type EngadgetPost struct {
BasePost BasePost `gorm:"embedded"`
ImageUrl string
}
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works")
}
for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct")
}
}
}
package gorm
import "errors"
var (
RecordNotFound = errors.New("record not found")
InvalidSql = errors.New("invalid sql")
NoNewAttrs = errors.New("no new attributes")
NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("can't start transaction")
)
package gorm
import (
"database/sql"
"errors"
"reflect"
)
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
func (field *Field) Set(value interface{}) error {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
if !field.Field.CanAddr() {
return errors.New("unaddressable value")
}
if rvalue, ok := value.(reflect.Value); ok {
value = rvalue.Interface()
}
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
if err := scanner.Scan(v.Interface()); err != nil {
return err
}
} else {
if err := scanner.Scan(value); err != nil {
return err
}
}
} else {
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else {
return errors.New("could not convert argument")
}
}
field.IsBlank = isBlank(field.Field)
return nil
}
// Fields get value's fields
func (scope *Scope) Fields() map[string]*Field {
if scope.fields == nil {
fields := map[string]*Field{}
structFields := scope.GetStructFields()
indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range structFields {
if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
}
scope.fields = fields
}
return scope.fields
}
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
field.IsBlank = isBlank(indirectValue)
return field
}
package gorm
import (
"fmt"
"reflect"
"time"
)
type foundation struct {
commonDialect
}
func (foundation) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (foundation) SupportLastInsertId() bool {
return false
}
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "clob"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
}
func (f foundation) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key)
}
func (foundation) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count)
return count > 0
}
func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (f foundation) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName)))
}
func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
package gorm
import "database/sql"
type sqlCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}
type sqlTx interface {
Commit() error
Rollback() error
}
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey
}
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
}
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
}
}
return values
}
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
assignColumns = append(assignColumns, key)
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
for _, value := range values {
values = append(values, value)
}
quotedTable := handler.Table(db)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(),
quotedTable,
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
}
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var conditions []string
var values []interface{}
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
values = append(values, value)
}
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
quotedTable := handler.Table(db)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
var joinConditions []string
var queryConditions []string
var values []interface{}
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Source.ForeignKeys {
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...)
} else {
db.Error = errors.New("wrong source type for join table handler")
return db
}
}
package gorm_test
import (
"fmt"
"testing"
"time"
"github.com/jinzhu/gorm"
)
type Person struct {
Id int
Name string
Addresses []*Address `gorm:"many2many:person_addresses;"`
}
type PersonAddress struct {
gorm.JoinTableHandler
PersonID int
AddressID int
DeletedAt time.Time
CreatedAt time.Time
}
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
}).Assign(map[string]interface{}{
"person_id": foreignValue,
"address_id": associationValue,
"deleted_at": gorm.Expr("NULL"),
}).FirstOrCreate(&PersonAddress{}).Error
}
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error
}
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
table := pa.Table(db)
return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
}
func TestJoinTable(t *testing.T) {
DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&Person{})
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"}
person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
DB.Save(person)
DB.Model(person).Association("Addresses").Delete(address1)
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
t.Errorf("Should found one address")
}
if DB.Model(person).Association("Addresses").Count() != 1 {
t.Errorf("Should found one address")
}
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
t.Errorf("Found two addresses with Unscoped")
}
if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
t.Errorf("Should deleted all addresses")
}
}
package gorm
import (
"database/sql/driver"
"fmt"
"log"
"os"
"reflect"
"regexp"
"time"
)
type logger interface {
Print(v ...interface{})
}
type Logger struct {
*log.Logger
}
var defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
// Format log
var sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
func (logger Logger) Print(values ...interface{}) {
if len(values) > 1 {
level := values[0]
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
messages := []interface{}{source, currentTime}
if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
// sql
var formatedValues []interface{}
for _, value := range values[4].([]interface{}) {
indirectValue := reflect.Indirect(reflect.ValueOf(value))
if indirectValue.IsValid() {
value = indirectValue.Interface()
if t, ok := value.(time.Time); ok {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
} else if b, ok := value.([]byte); ok {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", string(b)))
} else if r, ok := value.(driver.Valuer); ok {
if value, err := r.Value(); err == nil && value != nil {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
} else {
formatedValues = append(formatedValues, "NULL")
}
} else {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
}
} else {
formatedValues = append(formatedValues, fmt.Sprintf("'%v'", value))
}
}
messages = append(messages, fmt.Sprintf(sqlRegexp.ReplaceAllString(values[3].(string), "%v"), formatedValues...))
} else {
messages = append(messages, "\033[31;1m")
messages = append(messages, values[2:]...)
messages = append(messages, "\033[0m")
}
logger.Println(messages...)
}
}
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
)
// NowFunc returns current time, this function is exported in order to be able
// to give the flexibility to the developer to customize it according to their
// needs
//
// e.g: return time.Now().UTC()
//
var NowFunc = func() time.Time {
return time.Now()
}
type DB struct {
Value interface{}
Error error
RowsAffected int64
callback *callback
db sqlCommon
parent *DB
search *search
logMode int
logger logger
dialect Dialect
singularTable bool
source string
values map[string]interface{}
joinTableHandlers map[string]JoinTableHandler
}
func Open(dialect string, args ...interface{}) (DB, error) {
var db DB
var err error
if len(args) == 0 {
err = errors.New("invalid database source")
} else {
var source string
var dbSql sqlCommon
switch value := args[0].(type) {
case string:
var driver = dialect
if len(args) == 1 {
source = value
} else if len(args) >= 2 {
driver = value
source = args[1].(string)
}
if driver == "foundation" {
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
}
dbSql, err = sql.Open(driver, source)
case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
dbSql = value
}
db = DB{
dialect: NewDialect(dialect),
logger: defaultLogger,
callback: DefaultCallback,
source: source,
values: map[string]interface{}{},
db: dbSql,
}
db.parent = &db
}
return db, err
}
func (s *DB) Close() error {
return s.parent.db.(*sql.DB).Close()
}
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}
func (s *DB) New() *DB {
clone := s.clone()
clone.search = nil
clone.Value = nil
return clone
}
// NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope {
dbClone := db.clone()
dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
}
// CommonDB Return the underlying sql.DB or sql.Tx instance.
// Use of this method is discouraged. It's mainly intended to allow
// coexistence with legacy non-GORM code.
func (s *DB) CommonDB() sqlCommon {
return s.db
}
func (s *DB) Callback() *callback {
s.parent.callback = s.parent.callback.clone()
return s.parent.callback
}
func (s *DB) SetLogger(l logger) {
s.parent.logger = l
}
func (s *DB) LogMode(enable bool) *DB {
if enable {
s.logMode = 2
} else {
s.logMode = 1
}
return s
}
func (s *DB) SingularTable(enable bool) {
modelStructs = map[reflect.Type]*ModelStruct{}
s.parent.singularTable = enable
}
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
return s.clone().search.Where(query, args...).db
}
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
return s.clone().search.Or(query, args...).db
}
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
return s.clone().search.Not(query, args...).db
}
func (s *DB) Limit(value interface{}) *DB {
return s.clone().search.Limit(value).db
}
func (s *DB) Offset(value interface{}) *DB {
return s.clone().search.Offset(value).db
}
func (s *DB) Order(value string, reorder ...bool) *DB {
return s.clone().search.Order(value, reorder...).db
}
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
return s.clone().search.Select(query, args...).db
}
func (s *DB) Omit(columns ...string) *DB {
return s.clone().search.Omit(columns...).db
}
func (s *DB) Group(query string) *DB {
return s.clone().search.Group(query).db
}
func (s *DB) Having(query string, values ...interface{}) *DB {
return s.clone().search.Having(query, values...).db
}
func (s *DB) Joins(query string) *DB {
return s.clone().search.Joins(query).db
}
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
for _, f := range funcs {
s = f(s)
}
return s
}
func (s *DB) Unscoped() *DB {
return s.clone().search.unscoped().db
}
func (s *DB) Attrs(attrs ...interface{}) *DB {
return s.clone().search.Attrs(attrs...).db
}
func (s *DB) Assign(attrs ...interface{}) *DB {
return s.clone().search.Assign(attrs...).db
}
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Scan(dest interface{}) *DB {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row()
}
func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows()
}
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
}
func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db
}
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
c.NewScope(out).inlineCondition(where...).initialize()
} else {
c.NewScope(out).updatedAttrsWithValues(convertInterfaceToMap(s.search.assignAttrs), false)
}
return c
}
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
c := s.clone()
if result := c.First(out, where...); result.Error != nil {
if !result.RecordNotFound() {
return result
}
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.assignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates)
}
return c
}
func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
}
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().NewScope(s.Value).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db
}
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
return s.UpdateColumns(toSearchableMap(attrs...))
}
func (s *DB) UpdateColumns(values interface{}) *DB {
return s.clone().NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db
}
func (s *DB) Save(value interface{}) *DB {
scope := s.clone().NewScope(value)
if scope.PrimaryKeyZero() {
return scope.callCallbacks(s.parent.callback.creates).db
}
return scope.callCallbacks(s.parent.callback.updates).db
}
func (s *DB) Create(value interface{}) *DB {
scope := s.clone().NewScope(value)
return scope.callCallbacks(s.parent.callback.creates).db
}
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callback.deletes).db
}
func (s *DB) Raw(sql string, values ...interface{}) *DB {
return s.clone().search.Raw(true).Where(sql, values...).db
}
func (s *DB) Exec(sql string, values ...interface{}) *DB {
scope := s.clone().NewScope(nil)
generatedSql := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
generatedSql = strings.TrimSuffix(strings.TrimPrefix(generatedSql, "("), ")")
scope.Raw(generatedSql)
return scope.Exec().db
}
func (s *DB) Model(value interface{}) *DB {
c := s.clone()
c.Value = value
return c
}
func (s *DB) Table(name string) *DB {
clone := s.clone()
clone.search.Table(name)
clone.Value = nil
return clone
}
func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}
func (s *DB) Begin() *DB {
c := s.clone()
if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon)
c.err(err)
} else {
c.err(CantStartTransaction)
}
return c
}
func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok {
s.err(db.Commit())
} else {
s.err(NoValidTransaction)
}
return s
}
func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok {
s.err(db.Rollback())
} else {
s.err(NoValidTransaction)
}
return s
}
func (s *DB) NewRecord(value interface{}) bool {
return s.clone().NewScope(value).PrimaryKeyZero()
}
func (s *DB) RecordNotFound() bool {
return s.Error == RecordNotFound
}
// Migrations
func (s *DB) CreateTable(value interface{}) *DB {
return s.clone().NewScope(value).createTable().db
}
func (s *DB) DropTable(value interface{}) *DB {
return s.clone().NewScope(value).dropTable().db
}
func (s *DB) DropTableIfExists(value interface{}) *DB {
return s.clone().NewScope(value).dropTableIfExists().db
}
func (s *DB) HasTable(value interface{}) bool {
scope := s.clone().NewScope(value)
tableName := scope.TableName()
return scope.Dialect().HasTable(scope, tableName)
}
func (s *DB) AutoMigrate(values ...interface{}) *DB {
db := s.clone()
for _, value := range values {
db = db.NewScope(value).NeedPtr().autoMigrate().db
}
return db
}
func (s *DB) ModifyColumn(column string, typ string) *DB {
scope := s.clone().NewScope(s.Value)
scope.modifyColumn(column, typ)
return scope.db
}
func (s *DB) DropColumn(column string) *DB {
scope := s.clone().NewScope(s.Value)
scope.dropColumn(column)
return scope.db
}
func (s *DB) AddIndex(indexName string, column ...string) *DB {
scope := s.clone().NewScope(s.Value)
scope.addIndex(false, indexName, column...)
return scope.db
}
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
scope := s.clone().NewScope(s.Value)
scope.addIndex(true, indexName, column...)
return scope.db
}
func (s *DB) RemoveIndex(indexName string) *DB {
scope := s.clone().NewScope(s.Value)
scope.removeIndex(indexName)
return scope.db
}
/*
Add foreign key to the given scope
Example:
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
*/
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
scope := s.clone().NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
}
func (s *DB) Association(column string) *Association {
var err error
scope := s.clone().NewScope(s.Value)
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
err = errors.New("primary key can't be nil")
} else {
if field, ok := scope.FieldByName(column); ok {
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else {
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
}
} else {
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
}
}
return &Association{Error: err}
}
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
return s.clone().search.Preload(column, conditions...).db
}
// Set set value by name
func (s *DB) Set(name string, value interface{}) *DB {
return s.clone().InstantSet(name, value)
}
func (s *DB) InstantSet(name string, value interface{}) *DB {
s.values[name] = value
return s
}
// Get get value by name
func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values[name]
return
}
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
s.Table(table).AutoMigrate(handler)
}
}
}
}
}
package gorm
import "time"
func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
for key, value := range s.values {
db.values[key] = value
}
if s.search == nil {
db.search = &search{}
} else {
db.search = s.search.clone()
}
db.search.db = &db
return &db
}
func (s *DB) err(err error) error {
if err != nil {
if err != RecordNotFound {
if s.logMode == 0 {
go s.print(fileWithLineNum(), err)
} else {
s.log(err)
}
}
s.Error = err
}
return err
}
func (s *DB) print(v ...interface{}) {
s.parent.logger.(logger).Print(v...)
}
func (s *DB) log(v ...interface{}) {
if s != nil && s.logMode == 2 {
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
}
}
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
}
}
package gorm_test
import (
"database/sql"
"database/sql/driver"
"fmt"
"strconv"
_ "github.com/denisenkom/go-mssqldb"
testdb "github.com/erikstmartin/go-testdb"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
"github.com/jinzhu/now"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"os"
"testing"
"time"
)
var (
DB gorm.DB
t1, t2, t3, t4, t5 time.Time
)
func init() {
var err error
switch os.Getenv("GORM_DIALECT") {
case "mysql":
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
// CREATE DATABASE gorm;
// GRANT ALL ON gorm.* TO 'gorm'@'localhost';
fmt.Println("testing mysql...")
DB, err = gorm.Open("mysql", "gorm:gorm@/gorm?charset=utf8&parseTime=True")
case "postgres":
fmt.Println("testing postgres...")
DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
case "foundation":
fmt.Println("testing foundation...")
DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
case "mssql":
fmt.Println("testing mssql...")
DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
default:
fmt.Println("testing sqlite3...")
DB, err = gorm.Open("sqlite3", "/tmp/gorm.db")
}
// DB.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
// DB.SetLogger(log.New(os.Stdout, "\r\n", 0))
DB.LogMode(true)
DB.LogMode(false)
if err != nil {
panic(fmt.Sprintf("No error should happen when connect database, but got %+v", err))
}
DB.DB().SetMaxIdleConns(10)
runMigration()
}
func TestStringPrimaryKey(t *testing.T) {
type UUIDStruct struct {
ID string `gorm:"primary_key"`
Name string
}
DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated")
}
}
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
var count1, count2 int64
DB.Model(&User{}).Count(&count1)
if count1 <= 0 {
t.Errorf("Should find some users")
}
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
DB.Model(&User{}).Count(&count2)
if count1 != count2 {
t.Errorf("No user should not be deleted by invalid SQL")
}
}
func TestSetTable(t *testing.T) {
DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
t.Errorf("No errors should happen if set table for pluck", err.Error())
}
var users []User
if DB.Table("users").Find(&[]User{}).Error != nil {
t.Errorf("No errors should happen if set table for find")
}
if DB.Table("invalid_table").Find(&users).Error == nil {
t.Errorf("Should got error when table is set to an invalid table")
}
DB.Exec("drop table deleted_users;")
if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
t.Errorf("Create table with specified table")
}
DB.Table("deleted_users").Save(&User{Name: "DeletedUser"})
var deletedUsers []User
DB.Table("deleted_users").Find(&deletedUsers)
if len(deletedUsers) != 1 {
t.Errorf("Query from specified table")
}
DB.Save(getPreparedUser("normal_user", "reset_table"))
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
var user1, user2, user3 User
DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
t.Errorf("unset specified table with blank string")
}
}
type Order struct {
}
type Cart struct {
}
func (c Cart) TableName() string {
return "shopping_cart"
}
func TestHasTable(t *testing.T) {
type Foo struct {
Id int
Stuff string
}
DB.DropTable(&Foo{})
if ok := DB.HasTable(&Foo{}); ok {
t.Errorf("Table should not exist, but does")
}
if err := DB.CreateTable(&Foo{}).Error; err != nil {
t.Errorf("Table should be created")
}
if ok := DB.HasTable(&Foo{}); !ok {
t.Errorf("Table should exist, but HasTable informs it does not")
}
}
func TestTableName(t *testing.T) {
DB := DB.Model("")
if DB.NewScope(Order{}).TableName() != "orders" {
t.Errorf("Order's table name should be orders")
}
if DB.NewScope(&Order{}).TableName() != "orders" {
t.Errorf("&Order's table name should be orders")
}
if DB.NewScope([]Order{}).TableName() != "orders" {
t.Errorf("[]Order's table name should be orders")
}
if DB.NewScope(&[]Order{}).TableName() != "orders" {
t.Errorf("&[]Order's table name should be orders")
}
DB.SingularTable(true)
if DB.NewScope(Order{}).TableName() != "order" {
t.Errorf("Order's singular table name should be order")
}
if DB.NewScope(&Order{}).TableName() != "order" {
t.Errorf("&Order's singular table name should be order")
}
if DB.NewScope([]Order{}).TableName() != "order" {
t.Errorf("[]Order's singular table name should be order")
}
if DB.NewScope(&[]Order{}).TableName() != "order" {
t.Errorf("&[]Order's singular table name should be order")
}
if DB.NewScope(&Cart{}).TableName() != "shopping_cart" {
t.Errorf("&Cart's singular table name should be shopping_cart")
}
if DB.NewScope(Cart{}).TableName() != "shopping_cart" {
t.Errorf("Cart's singular table name should be shopping_cart")
}
if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" {
t.Errorf("&[]Cart's singular table name should be shopping_cart")
}
if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
t.Errorf("[]Cart's singular table name should be shopping_cart")
}
DB.SingularTable(false)
}
func TestSqlNullValue(t *testing.T) {
DB.DropTable(&NullValue{})
DB.AutoMigrate(&NullValue{})
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true},
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: true},
}).Error; err != nil {
t.Errorf("Not error should raise when test null value")
}
var nv NullValue
DB.First(&nv, "name = ?", "hello")
if nv.Name.String != "hello" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
t.Errorf("Should be able to fetch null value")
}
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-2", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: false},
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err != nil {
t.Errorf("Not error should raise when test null value")
}
var nv2 NullValue
DB.First(&nv2, "name = ?", "hello-2")
if nv2.Name.String != "hello-2" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
t.Errorf("Should be able to fetch null value")
}
if err := DB.Save(&NullValue{Name: sql.NullString{String: "hello-3", Valid: false},
Age: sql.NullInt64{Int64: 18, Valid: false},
Male: sql.NullBool{Bool: true, Valid: true},
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
AddedAt: NullTime{Time: time.Now(), Valid: false},
}).Error; err == nil {
t.Errorf("Can't save because of name can't be null")
}
}
func TestTransaction(t *testing.T) {
tx := DB.Begin()
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
t.Errorf("Should return the underlying sql.Tx")
}
tx.Rollback()
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
tx2 := DB.Begin()
u2 := User{Name: "transcation-2"}
if err := tx2.Save(&u2).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record")
}
tx2.Commit()
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
}
func TestRow(t *testing.T) {
user1 := User{Name: "RowUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "RowUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "RowUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
var age int64
row.Scan(&age)
if age != 10 {
t.Errorf("Scan with Row")
}
}
func TestRows(t *testing.T) {
user1 := User{Name: "RowsUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "RowsUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "RowsUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
if err != nil {
t.Errorf("Not error should happen, but got")
}
count := 0
for rows.Next() {
var name string
var age int64
rows.Scan(&name, &age)
count++
}
if count != 2 {
t.Errorf("Should found two records with name 3")
}
}
func TestScan(t *testing.T) {
user1 := User{Name: "ScanUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "ScanUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "ScanUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
type result struct {
Name string
Age int
}
var res result
DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res)
if res.Name != user3.Name {
t.Errorf("Scan into struct should work")
}
var doubleAgeRes result
DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes)
if doubleAgeRes.Age != res.Age*2 {
t.Errorf("Scan double age as age")
}
var ress []result
DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress)
if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
t.Errorf("Scan into struct map")
}
}
func TestRaw(t *testing.T) {
user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
type result struct {
Name string
Email string
}
var ress []result
DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress)
if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
t.Errorf("Raw with scan")
}
rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows()
count := 0
for rows.Next() {
count++
}
if count != 1 {
t.Errorf("Raw with Rows should find one record with name 3")
}
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.RecordNotFound {
t.Error("Raw sql to update records")
}
}
func TestGroup(t *testing.T) {
rows, err := DB.Select("name").Table("users").Group("name").Rows()
if err == nil {
defer rows.Close()
for rows.Next() {
var name string
rows.Scan(&name)
}
} else {
t.Errorf("Should not raise any error")
}
}
func TestJoins(t *testing.T) {
type result struct {
Name string
Email string
}
user := User{
Name: "joins",
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
}
DB.Save(&user)
var results []result
DB.Table("users").Select("name, email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Scan(&results)
if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
t.Errorf("Should find all two emails with Join")
}
}
func TestHaving(t *testing.T) {
rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()
if err == nil {
defer rows.Close()
for rows.Next() {
var name string
var total int64
rows.Scan(&name, &total)
if name == "2" && total != 1 {
t.Errorf("Should have one user having name 2")
}
if name == "3" && total != 2 {
t.Errorf("Should have two users having name 3")
}
}
} else {
t.Errorf("Should not raise any error")
}
}
func DialectHasTzSupport() bool {
// NB: mssql and FoundationDB do not support time zones.
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
return false
}
return true
}
func TestTimeWithZone(t *testing.T) {
var format = "2006-01-02 15:04:05 -0700"
var times []time.Time
GMT8, _ := time.LoadLocation("Asia/Shanghai")
times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
for index, vtime := range times {
name := "time_with_zone_" + strconv.Itoa(index)
user := User{Name: name, Birthday: vtime}
if !DialectHasTzSupport() {
// If our driver dialect doesn't support TZ's, just use UTC for everything here.
user.Birthday = vtime.UTC()
}
DB.Save(&user)
expectedBirthday := "2013-02-18 17:51:49 +0000"
foundBirthday := user.Birthday.UTC().Format(format)
if foundBirthday != expectedBirthday {
t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
}
var findUser, findUser2, findUser3 User
DB.First(&findUser, "name = ?", name)
foundBirthday = findUser.Birthday.UTC().Format(format)
if foundBirthday != expectedBirthday {
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
}
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
t.Errorf("User should be found")
}
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
t.Errorf("User should not be found")
}
}
}
func TestHstore(t *testing.T) {
type Details struct {
Id int64
Bulk gorm.Hstore
}
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
t.Skip()
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
}
DB.Exec("drop table details")
if err := DB.CreateTable(&Details{}).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
bulk := map[string]*string{
"bankAccountId": &bankAccountId,
"phoneNumber": &phoneNumber,
"opinion": &opinion,
}
d := Details{Bulk: bulk}
DB.Save(&d)
var d2 Details
if err := DB.First(&d2).Error; err != nil {
t.Errorf("Got error when tried to fetch details: %+v", err)
}
for k := range bulk {
if r, ok := d2.Bulk[k]; ok {
if res, _ := bulk[k]; *res != *r {
t.Errorf("Details should be equal")
}
} else {
t.Errorf("Details should be existed")
}
}
}
func TestSetAndGet(t *testing.T) {
if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
t.Errorf("Should be able to get setting after set")
} else {
if value.(string) != "world" {
t.Errorf("Setted value should not be changed")
}
}
if _, ok := DB.Get("non_existing"); ok {
t.Errorf("Get non existing key should return error")
}
}
func TestCompatibilityMode(t *testing.T) {
DB, _ := gorm.Open("testdb", "")
testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
columns := []string{"id", "name", "age"}
result := `
1,Tim,20
2,Joe,25
3,Bob,30
`
return testdb.RowsFromCSVString(columns, result), nil
})
var users []User
DB.Find(&users)
if (users[0].Name != "Tim") || len(users) != 3 {
t.Errorf("Unexcepted result returned")
}
}
func TestOpenExistingDB(t *testing.T) {
DB.Save(&User{Name: "jnfeinstein"})
dialect := os.Getenv("GORM_DIALECT")
db, err := gorm.Open(dialect, DB.DB())
if err != nil {
t.Errorf("Should have wrapped the existing DB connection")
}
var user User
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.RecordNotFound {
t.Errorf("Should have found existing record")
}
}
func BenchmarkGorm(b *testing.B) {
b.N = 2000
for x := 0; x < b.N; x++ {
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
DB.Save(&email)
// Query
DB.First(&BigEmail{}, "email = ?", e)
// Update
DB.Model(&email).UpdateColumn("email", "new-"+e)
// Delete
DB.Delete(&email)
}
}
func BenchmarkRawSql(b *testing.B) {
DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
DB.SetMaxIdleConns(10)
insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
deleteSql := "DELETE FROM orders WHERE id = $1"
b.N = 2000
for x := 0; x < b.N; x++ {
var id int64
e := strconv.Itoa(x) + "benchmark@example.org"
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: time.Now()}
// Insert
DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
// Query
rows, _ := DB.Query(querySql, email.Email)
rows.Close()
// Update
DB.Exec(updateSql, "new-"+e, time.Now(), id)
// Delete
DB.Exec(deleteSql, id)
}
}
package gorm_test
import (
"fmt"
"testing"
"time"
)
func runMigration() {
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
fmt.Printf("Got error when try to delete table users, %+v\n", err)
}
for _, table := range []string{"animals", "user_languages"} {
DB.Exec(fmt.Sprintf("drop table %v;", table))
}
values := []interface{}{&Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}}
for _, value := range values {
DB.DropTable(value)
}
if err := DB.AutoMigrate(values...).Error; err != nil {
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
}
}
func TestIndexes(t *testing.T) {
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
scope := DB.NewScope(&Email{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
t.Errorf("Email should have index idx_email_email")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email") {
t.Errorf("Email's index idx_email_email should be deleted")
}
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
t.Errorf("Got error when tried to create index: %+v", err)
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email should have index idx_email_email_and_user_id")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
t.Errorf("Should get to create duplicate record when having unique index")
}
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
t.Errorf("Got error when tried to remove index: %+v", err)
}
if scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_email_and_user_id") {
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
}
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
t.Errorf("Should be able to create duplicated emails after remove unique index")
}
}
type BigEmail struct {
Id int64
UserId int64
Email string `sql:"index:idx_email_agent"`
UserAgent string `sql:"index:idx_email_agent"`
RegisteredAt time.Time `sql:"unique_index"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (b BigEmail) TableName() string {
return "emails"
}
func TestAutoMigration(t *testing.T) {
DB.AutoMigrate(&Address{})
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil {
t.Errorf("Auto Migrate should not raise any error")
}
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: time.Now()})
scope := DB.NewScope(&BigEmail{})
if !scope.Dialect().HasIndex(scope, scope.TableName(), "idx_email_agent") {
t.Errorf("Failed to create index")
}
if !scope.Dialect().HasIndex(scope, scope.TableName(), "uix_emails_registered_at") {
t.Errorf("Failed to create index")
}
var bigemail BigEmail
DB.First(&bigemail, "user_agent = ?", "pc")
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
t.Error("Big Emails should be saved and fetched correctly")
}
}
package gorm
import "time"
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
package gorm
import (
"database/sql"
"fmt"
"go/ast"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
var modelStructs = map[reflect.Type]*ModelStruct{}
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
defaultTableName string
}
func (s ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}
type StructField struct {
DBName string
Name string
Names []string
IsPrimaryKey bool
IsNormal bool
IsIgnored bool
IsScanner bool
HasDefaultValue bool
Tag reflect.StructTag
Struct reflect.StructField
IsForeignKey bool
Relationship *Relationship
}
func (structField *StructField) clone() *StructField {
return &StructField{
DBName: structField.DBName,
Name: structField.Name,
Names: structField.Names,
IsPrimaryKey: structField.IsPrimaryKey,
IsNormal: structField.IsNormal,
IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner,
HasDefaultValue: structField.HasDefaultValue,
Tag: structField.Tag,
Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey,
Relationship: structField.Relationship,
}
}
type Relationship struct {
Kind string
PolymorphicType string
PolymorphicDBName string
ForeignFieldName string
ForeignDBName string
AssociationForeignFieldName string
AssociationForeignDBName string
JoinTableHandler JoinTableHandlerInterface
}
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct
reflectValue := reflect.Indirect(reflect.ValueOf(scope.Value))
if !reflectValue.IsValid() {
return &modelStruct
}
if reflectValue.Kind() == reflect.Slice {
reflectValue = reflect.Indirect(reflect.New(reflectValue.Type().Elem()))
}
scopeType := reflectValue.Type()
if scopeType.Kind() == reflect.Ptr {
scopeType = scopeType.Elem()
}
if value, ok := modelStructs[scopeType]; ok {
return value
}
modelStruct.ModelType = scopeType
if scopeType.Kind() != reflect.Struct {
return &modelStruct
}
// Set tablename
type tabler interface {
TableName() string
}
if tabler, ok := reflect.New(scopeType).Interface().(interface {
TableName() string
}); ok {
modelStruct.defaultTableName = tabler.TableName()
} else {
name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys {
if reg.MatchString(name) {
name = reg.ReplaceAllString(name, pluralMapValues[index])
}
}
}
modelStruct.defaultTableName = name
}
// Get all fields
fields := []*StructField{}
for i := 0; i < scopeType.NumField(); i++ {
if fieldStruct := scopeType.Field(i); ast.IsExported(fieldStruct.Name) {
field := &StructField{
Struct: fieldStruct,
Name: fieldStruct.Name,
Names: []string{fieldStruct.Name},
Tag: fieldStruct.Tag,
}
if fieldStruct.Tag.Get("sql") == "-" {
field.IsIgnored = true
} else {
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
if _, ok := gormSettings["PRIMARY_KEY"]; ok {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if _, ok := sqlSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
}
if value, ok := gormSettings["COLUMN"]; ok {
field.DBName = value
} else {
field.DBName = ToDBName(fieldStruct.Name)
}
}
fields = append(fields, field)
}
}
defer func() {
for _, field := range fields {
if !field.IsIgnored {
fieldStruct := field.Struct
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
if indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
field.IsScanner, field.IsNormal = true, true
}
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
field.IsNormal = true
}
if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) {
return field
}
}
return nil
}
var relationship = &Relationship{}
foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
polymorphicType.IsForeignKey = true
polymorphicField.IsForeignKey = true
}
}
}
switch indirectType.Kind() {
case reflect.Slice:
elemType := indirectType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" {
associationForeignKey = elemType.Name() + "Id"
}
relationship.ForeignFieldName = foreignKey
relationship.ForeignDBName = ToDBName(foreignKey)
relationship.AssociationForeignFieldName = associationForeignKey
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship
} else {
relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
} else {
field.IsNormal = true
}
case reflect.Struct:
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, toField := range toScope.GetStructFields() {
toField = toField.clone()
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
}
}
continue
} else {
belongsToForeignKey := foreignKey
if belongsToForeignKey == "" {
belongsToForeignKey = field.Name + "Id"
}
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
relationship.Kind = "belongs_to"
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else {
if foreignKey == "" {
foreignKey = modelStruct.ModelType.Name() + "Id"
}
relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
}
default:
field.IsNormal = true
}
}
if field.IsNormal {
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
}
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
}()
modelStructs[scopeType] = &modelStruct
return &modelStruct
}
func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}
func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
if value, ok := sqlSettings["TYPE"]; ok {
sqlType = value
}
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
if value, ok := sqlSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}
if sqlType == "" {
var size = 255
if value, ok := sqlSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
_, autoIncrease := sqlSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
autoIncrease = true
}
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
} else {
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
}
func parseTagSetting(str string) map[string]string {
tags := strings.Split(str, ";")
setting := map[string]string{}
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) == 2 {
setting[k] = v[1]
} else {
setting[k] = k
}
}
return setting
}
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type mssql struct {
commonDialect
}
func (mssql) HasTop() bool {
return true
}
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int IDENTITY(1,1)"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint IDENTITY(1,1)"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "float"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("nvarchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime2"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr {
s := strings.Split(value, "=")
if s[0] == "database" {
return s[1]
}
}
return ""
}
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0
}
package gorm_test
import (
"fmt"
"os"
"testing"
)
type Blog struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Subject string
Body string
Tags []Tag `gorm:"many2many:blog_tags;"`
}
type Tag struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Value string
}
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
DB.Exec(fmt.Sprintf("drop table blog_tags;"))
DB.AutoMigrate(&Blog{}, &Tag{})
blog := Blog{
Locale: "ZH",
Subject: "subject",
Body: "body",
Tags: []Tag{
{Locale: "ZH", Value: "tag1"},
{Locale: "ZH", Value: "tag2"},
},
}
DB.Save(&blog)
DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
var tags []Tag
DB.Model(&blog).Related(&tags, "Tags")
if len(tags) != 3 {
t.Errorf("should found 3 tags with blog")
}
}
}
package gorm
import (
"fmt"
"reflect"
"time"
)
type mysql struct {
commonDialect
}
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "longtext"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
package gorm_test
import "testing"
type PointerStruct struct {
ID int64
Name *string
Num *int
}
type NormalStruct struct {
ID int64
Name string
Num int
}
func TestPointerFields(t *testing.T) {
DB.DropTable(&PointerStruct{})
DB.AutoMigrate(&PointerStruct{})
var name = "pointer struct 1"
var num = 100
pointerStruct := PointerStruct{Name: &name, Num: &num}
if DB.Create(&pointerStruct).Error != nil {
t.Errorf("Failed to save pointer struct")
}
var pointerStructResult PointerStruct
if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
t.Errorf("Failed to query saved pointer struct")
}
var tableName = DB.NewScope(&PointerStruct{}).TableName()
var normalStruct NormalStruct
DB.Table(tableName).First(&normalStruct)
if normalStruct.Name != name || normalStruct.Num != num {
t.Errorf("Failed to query saved Normal struct")
}
var nilPointerStruct = PointerStruct{}
if err := DB.Create(&nilPointerStruct).Error; err != nil {
t.Errorf("Failed to save nil pointer struct", err)
}
var pointerStruct2 PointerStruct
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
t.Errorf("Failed to query saved nil pointer struct", err)
}
var normalStruct2 NormalStruct
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
t.Errorf("Failed to query saved nil pointer struct", err)
}
var partialNilPointerStruct1 = PointerStruct{Num: &num}
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
t.Errorf("Failed to save partial nil pointer struct", err)
}
var pointerStruct3 PointerStruct
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
t.Errorf("Failed to query saved partial nil pointer struct", err)
}
var normalStruct3 NormalStruct
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
t.Errorf("Failed to query saved partial pointer struct", err)
}
var partialNilPointerStruct2 = PointerStruct{Name: &name}
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
t.Errorf("Failed to save partial nil pointer struct", err)
}
var pointerStruct4 PointerStruct
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
t.Errorf("Failed to query saved partial nil pointer struct", err)
}
var normalStruct4 NormalStruct
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
t.Errorf("Failed to query saved partial pointer struct", err)
}
}
package gorm_test
import "testing"
type Cat struct {
Id int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}
type Dog struct {
Id int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}
type Toy struct {
Id int
Name string
OwnerId int
OwnerType string
}
func TestPolymorphic(t *testing.T) {
DB.AutoMigrate(&Cat{})
DB.AutoMigrate(&Dog{})
DB.AutoMigrate(&Toy{})
cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat nip"}}
dog := Dog{Name: "Pluto", Toys: []Toy{Toy{Name: "orange ball"}, Toy{Name: "yellow ball"}}}
DB.Save(&cat).Save(&dog)
var catToys []Toy
if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() {
t.Errorf("Did not find any has one polymorphic association")
} else if len(catToys) != 1 {
t.Errorf("Should have found only one polymorphic has one association")
} else if catToys[0].Name != cat.Toy.Name {
t.Errorf("Should have found the proper has one polymorphic association")
}
var dogToys []Toy
if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() {
t.Errorf("Did not find any polymorphic has many associations")
} else if len(dogToys) != len(dog.Toys) {
t.Errorf("Should have found all polymorphic has many associations")
}
if DB.Model(&cat).Association("Toy").Count() != 1 {
t.Errorf("Should return one polymorphic has one association")
}
if DB.Model(&dog).Association("Toys").Count() != 2 {
t.Errorf("Should return two polymorphic has many associations")
}
}
package gorm
import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
)
func getRealValue(value reflect.Value, field string) interface{} {
result := reflect.Indirect(value).FieldByName(field).Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
return result
}
func equalAsString(a interface{}, b interface{}) bool {
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}
func Preload(scope *Scope) {
if scope.Search.preload == nil {
return
}
preloadMap := map[string]bool{}
fields := scope.Fields()
for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".")
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = []interface{}{}
for i, key := range keys {
var found bool
if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if i == len(keys)-1 {
conditions = originalConditions
}
for _, field := range currentFields {
if field.Name != key || field.Relationship == nil {
continue
}
found = true
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, conditions)
case "has_many":
currentScope.handleHasManyPreload(field, conditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, conditions)
case "many_to_many":
fallthrough
default:
currentScope.Err(errors.New("not supported relation"))
}
break
}
if !found {
value := reflect.ValueOf(currentScope.Value)
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
value = value.Index(0).Elem()
}
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
return
}
preloadMap[strings.Join(keys[:i+1], ".")] = true
nextLoop:
if i < len(keys)-1 {
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields()
}
}
}
}
func makeSlice(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
sliceType := reflect.SliceOf(typ)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
if err := scope.SetColumn(field, result); err != nil {
scope.Err(err)
return
}
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
associationPrimaryKey := scope.New(results).PrimaryField().Name
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
for i := 0; i < values.Len(); i++ {
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()}
}
return
}
func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
modelType := values.Type().Elem()
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
columns = reflect.Append(columns, column.Index(i).Addr())
}
} else {
columns = reflect.Append(columns, column.Addr())
}
}
return scope.New(columns.Interface())
case reflect.Struct:
return scope.New(values.FieldByName(column).Addr().Interface())
}
return nil
}
package gorm_test
import (
"encoding/json"
"reflect"
"testing"
)
func getPreloadUser(name string) *User {
return getPreparedUser(name, "Preload")
}
func checkUserHasPreloadData(user User, t *testing.T) {
u := getPreloadUser(user.Name)
if user.BillingAddress.Address1 != u.BillingAddress.Address1 {
t.Error("Failed to preload user's BillingAddress")
}
if user.ShippingAddress.Address1 != u.ShippingAddress.Address1 {
t.Error("Failed to preload user's ShippingAddress")
}
if user.CreditCard.Number != u.CreditCard.Number {
t.Error("Failed to preload user's CreditCard")
}
if user.Company.Name != u.Company.Name {
t.Error("Failed to preload user's Company")
}
if len(user.Emails) != len(u.Emails) {
t.Error("Failed to preload user's Emails")
} else {
var found int
for _, e1 := range u.Emails {
for _, e2 := range user.Emails {
if e1.Email == e2.Email {
found++
break
}
}
}
if found != len(u.Emails) {
t.Error("Failed to preload user's email details")
}
}
}
func TestPreload(t *testing.T) {
user1 := getPreloadUser("user1")
DB.Save(user1)
preloadDB := DB.Where("role = ?", "Preload").Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company")
var user User
preloadDB.Find(&user)
checkUserHasPreloadData(user, t)
user2 := getPreloadUser("user2")
DB.Save(user2)
user3 := getPreloadUser("user3")
DB.Save(user3)
var users []User
preloadDB.Find(&users)
for _, user := range users {
checkUserHasPreloadData(user, t)
}
var users2 []*User
preloadDB.Find(&users2)
for _, user := range users2 {
checkUserHasPreloadData(*user, t)
}
var users3 []*User
preloadDB.Preload("Emails", "email = ?", user3.Emails[0].Email).Find(&users3)
for _, user := range users3 {
if user.Name == user3.Name {
if len(user.Emails) != 1 {
t.Errorf("should only preload one emails for user3 when with condition")
}
} else if len(user.Emails) != 0 {
t.Errorf("should not preload any emails for other users when with condition")
}
}
}
func TestNestedPreload1(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload2(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []*Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2s: []Level2{
{
Level1s: []*Level1{
&Level1{Value: "value1"},
&Level1{Value: "value2"},
},
},
{
Level1s: []*Level1{
&Level1{Value: "value3"},
},
},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload3(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
Name string
ID uint
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value1"}},
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload4(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
// Slice: []Level3
func TestNestedPreload5(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload6(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2s: []Level2{
{
Level1s: []Level1{
{Value: "value1"},
{Value: "value2"},
},
},
{
Level1s: []Level1{
{Value: "value3"},
},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2s: []Level2{
{
Level1s: []Level1{
{Value: "value3"},
{Value: "value4"},
},
},
{
Level1s: []Level1{
{Value: "value5"},
},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload7(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value1"}},
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value3"}},
{Level1: Level1{Value: "value4"}},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload8(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload9(t *testing.T) {
type (
Level0 struct {
ID uint
Value string
Level1ID uint
}
Level1 struct {
ID uint
Value string
Level2ID uint
Level2_1ID uint
Level0s []Level0
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level2_1 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
Level2_1 Level2_1
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level2_1{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level0{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{
Value: "value1-1",
Level0s: []Level0{{Value: "Level0-1"}},
},
Level1{
Value: "value2-2",
Level0s: []Level0{{Value: "Level0-2"}},
},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{Value: "value3-3"},
Level1{Value: "value4-4"},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func toJSONString(v interface{}) []byte {
r, _ := json.MarshalIndent(v, "", " ")
return r
}
package gorm_test
import (
"fmt"
"reflect"
"github.com/jinzhu/now"
"testing"
"time"
)
func TestFirstAndLast(t *testing.T) {
DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}})
DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}})
var user1, user2, user3, user4 User
DB.First(&user1)
DB.Order("id").Limit(1).Find(&user2)
DB.Last(&user3)
DB.Order("id desc").Limit(1).Find(&user4)
if user1.Id != user2.Id || user3.Id != user4.Id {
t.Errorf("First and Last should by order by primary key")
}
var users []User
DB.First(&users)
if len(users) != 1 {
t.Errorf("Find first record as slice")
}
if DB.Joins("left join emails on emails.user_id = users.id").First(&User{}).Error != nil {
t.Errorf("Should not raise any error when order with Join table")
}
}
func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
DB.Save(&Animal{Name: "animal1"})
DB.Save(&Animal{Name: "animal2"})
var animal1, animal2, animal3, animal4 Animal
DB.First(&animal1)
DB.Order("counter").Limit(1).Find(&animal2)
DB.Last(&animal3)
DB.Order("counter desc").Limit(1).Find(&animal4)
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
t.Errorf("First and Last should work correctly")
}
}
func TestUIntPrimaryKey(t *testing.T) {
var animal Animal
DB.First(&animal, uint64(1))
if animal.Counter != 1 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
}
DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
if animal.Counter != 2 {
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
}
}
func TestFindAsSliceOfPointers(t *testing.T) {
DB.Save(&User{Name: "user"})
var users []User
DB.Find(&users)
var userPointers []*User
DB.Find(&userPointers)
if len(users) == 0 || len(users) != len(userPointers) {
t.Errorf("Find slice of pointers")
}
}
func TestSearchWithPlainSQL(t *testing.T) {
user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%")
if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("Search with plain SQL")
}
if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() {
t.Errorf("Search with plan SQL (regexp)")
}
var users []User
DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1)
if len(users) != 2 {
t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
}
DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
if len(users) != 3 {
t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
}
scopedb.Where("age <> ?", 20).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users age != 20, but got %v", len(users))
}
scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
}
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
}
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
}
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users, but got %v", len(users))
}
DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
if len(users) != 3 {
t.Errorf("Should found 3 users, but got %v", len(users))
}
DB.Where("id in (?)", user1.Id).Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users, but got %v", len(users))
}
if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
t.Errorf("Should not get RecordNotFound error when looking for none existing records")
}
}
func TestSearchWithStruct(t *testing.T) {
user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
if DB.Where(user1.Id).First(&User{}).RecordNotFound() {
t.Errorf("Search with primary key")
}
if DB.First(&User{}, user1.Id).RecordNotFound() {
t.Errorf("Search with primary key as inline condition")
}
if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() {
t.Errorf("Search with primary key as inline condition")
}
var users []User
DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users)
if len(users) != 3 {
t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users))
}
var user User
DB.First(&user, &User{Name: user1.Name})
if user.Id == 0 || user.Name != user1.Name {
t.Errorf("Search first record with inline pointer of struct")
}
DB.First(&user, User{Name: user1.Name})
if user.Id == 0 || user.Name != user.Name {
t.Errorf("Search first record with inline struct")
}
DB.Where(&User{Name: user1.Name}).First(&user)
if user.Id == 0 || user.Name != user1.Name {
t.Errorf("Search first record with where struct")
}
DB.Find(&users, &User{Name: user2.Name})
if len(users) != 1 {
t.Errorf("Search all records with inline struct")
}
}
func TestSearchWithMap(t *testing.T) {
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
var user User
DB.First(&user, map[string]interface{}{"name": user1.Name})
if user.Id == 0 || user.Name != user1.Name {
t.Errorf("Search first record with inline map")
}
user = User{}
DB.Where(map[string]interface{}{"name": user2.Name}).First(&user)
if user.Id == 0 || user.Name != user2.Name {
t.Errorf("Search first record with where map")
}
var users []User
DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users)
if len(users) != 1 {
t.Errorf("Search all records with inline map")
}
DB.Find(&users, map[string]interface{}{"name": user3.Name})
if len(users) != 1 {
t.Errorf("Search all records with inline map")
}
}
func TestSearchWithEmptyChain(t *testing.T) {
user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: now.MustParse("2000-1-1")}
user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: now.MustParse("2010-1-1")}
user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: now.MustParse("2020-1-1")}
DB.Save(&user1).Save(&user2).Save(&user3)
if DB.Where("").Where("").First(&User{}).Error != nil {
t.Errorf("Should not raise any error if searching with empty strings")
}
if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
t.Errorf("Should not raise any error if searching with empty struct")
}
if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
t.Errorf("Should not raise any error if searching with empty map")
}
}
func TestSelect(t *testing.T) {
user1 := User{Name: "SelectUser1"}
DB.Save(&user1)
var user User
DB.Where("name = ?", user1.Name).Select("name").Find(&user)
if user.Id != 0 {
t.Errorf("Should not have ID because only selected name, %+v", user.Id)
}
if user.Name != user1.Name {
t.Errorf("Should have user Name when selected it")
}
}
func TestOrderAndPluck(t *testing.T) {
user1 := User{Name: "OrderPluckUser1", Age: 1}
user2 := User{Name: "OrderPluckUser2", Age: 10}
user3 := User{Name: "OrderPluckUser3", Age: 20}
DB.Save(&user1).Save(&user2).Save(&user3)
scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
var ages []int64
scopedb.Order("age desc").Pluck("age", &ages)
if ages[0] != 20 {
t.Errorf("The first age should be 20 when order with age desc")
}
var ages1, ages2 []int64
scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2)
if !reflect.DeepEqual(ages1, ages2) {
t.Errorf("The first order is the primary order")
}
var ages3, ages4 []int64
scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
if reflect.DeepEqual(ages3, ages4) {
t.Errorf("Reorder should work")
}
var names []string
var ages5 []int64
scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names)
if names != nil && ages5 != nil {
if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) {
t.Errorf("Order with multiple orders")
}
} else {
t.Errorf("Order with multiple orders")
}
DB.Model(User{}).Select("name, age").Find(&[]User{})
}
func TestLimit(t *testing.T) {
user1 := User{Name: "LimitUser1", Age: 1}
user2 := User{Name: "LimitUser2", Age: 10}
user3 := User{Name: "LimitUser3", Age: 20}
user4 := User{Name: "LimitUser4", Age: 10}
user5 := User{Name: "LimitUser5", Age: 20}
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5)
var users1, users2, users3 []User
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
t.Errorf("Limit should works")
}
}
func TestOffset(t *testing.T) {
for i := 0; i < 20; i++ {
DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
}
var users1, users2, users3, users4 []User
DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
t.Errorf("Offset should work")
}
}
func TestOr(t *testing.T) {
user1 := User{Name: "OrUser1", Age: 1}
user2 := User{Name: "OrUser2", Age: 10}
user3 := User{Name: "OrUser3", Age: 20}
DB.Save(&user1).Save(&user2).Save(&user3)
var users []User
DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users)
if len(users) != 2 {
t.Errorf("Find users with or")
}
}
func TestCount(t *testing.T) {
user1 := User{Name: "CountUser1", Age: 1}
user2 := User{Name: "CountUser2", Age: 10}
user3 := User{Name: "CountUser3", Age: 20}
DB.Save(&user1).Save(&user2).Save(&user3)
var count, count1, count2 int64
var users []User
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}
if count != int64(len(users)) {
t.Errorf("Count() method should get correct value")
}
DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2)
if count1 != 1 || count2 != 3 {
t.Errorf("Multiple count in chain")
}
}
func TestNot(t *testing.T) {
DB.Create(getPreparedUser("user1", "not"))
DB.Create(getPreparedUser("user2", "not"))
DB.Create(getPreparedUser("user3", "not"))
DB.Create(getPreparedUser("user4", "not"))
DB := DB.Where("role = ?", "not")
var users1, users2, users3, users4, users5, users6, users7, users8 []User
if DB.Find(&users1).RowsAffected != 4 {
t.Errorf("should find 4 not users")
}
DB.Not(users1[0].Id).Find(&users2)
if len(users1)-len(users2) != 1 {
t.Errorf("Should ignore the first users with Not")
}
DB.Not([]int{}).Find(&users3)
if len(users1)-len(users3) != 0 {
t.Errorf("Should find all users with a blank condition")
}
var name3Count int64
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
DB.Not("name", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
DB.Not(User{Name: "user3"}).Find(&users5)
if len(users1)-len(users5) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
if len(users1)-len(users6) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
DB.Not("name", []string{"user3"}).Find(&users7)
if len(users1)-len(users7) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users8)
if len(users1)-len(users8) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3")
}
}
func TestFillSmallerStruct(t *testing.T) {
user1 := User{Name: "SmallerUser", Age: 100}
DB.Save(&user1)
type SimpleUser struct {
Name string
Id int64
UpdatedAt time.Time
CreatedAt time.Time
}
var simpleUser SimpleUser
DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser)
if simpleUser.Id == 0 || simpleUser.Name == "" {
t.Errorf("Should fill data correctly into smaller struct")
}
}
func TestFindOrInitialize(t *testing.T) {
var user1, user2, user3, user4, user5, user6 User
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1)
if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 {
t.Errorf("user should be initialized with search value")
}
DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 {
t.Errorf("user should be initialized with search value")
}
DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
if user3.Name != "find or init 2" || user3.Id != 0 {
t.Errorf("user should be initialized with inline search value")
}
DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
t.Errorf("user should be initialized with search value and attrs")
}
DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
t.Errorf("user should be initialized with search value and assign attrs")
}
DB.Save(&User{Name: "find or init", Age: 33})
DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 {
t.Errorf("user should be found with FirstOrInit")
}
DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
}
func TestFindOrCreate(t *testing.T) {
var user1, user2, user3, user4, user5, user6, user7, user8 User
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 {
t.Errorf("user should be created with search value")
}
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2)
if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 {
t.Errorf("user should be created with search value")
}
DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"})
if user3.Name != "find or create 2" || user3.Id == 0 {
t.Errorf("user should be created with inline search value")
}
DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 {
t.Errorf("user should be created with search value and attrs")
}
updatedAt1 := user4.UpdatedAt
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateAt should be changed when update values with assign")
}
DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4)
if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 {
t.Errorf("user should be created with search value and assigned attrs")
}
DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 {
t.Errorf("user should be found and not initialized by Attrs")
}
DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6)
if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
DB.Where(&User{Name: "find or create"}).Find(&user7)
if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 {
t.Errorf("user should be found and updated with assigned attrs")
}
DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8)
if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() {
t.Errorf("embedded struct email should be saved")
}
if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() {
t.Errorf("embedded struct credit card should be saved")
}
}
func TestSelectWithEscapedFieldName(t *testing.T) {
user1 := User{Name: "EscapedFieldNameUser", Age: 1}
user2 := User{Name: "EscapedFieldNameUser", Age: 10}
user3 := User{Name: "EscapedFieldNameUser", Age: 20}
DB.Save(&user1).Save(&user2).Save(&user3)
var names []string
DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names)
if len(names) != 3 {
t.Errorf("Expected 3 name, but got: %d", len(names))
}
}
func TestSelectWithVariables(t *testing.T) {
DB.Save(&User{Name: "jinzhu"})
rows, _ := DB.Table("users").Select("? as fake", "name").Rows()
if !rows.Next() {
t.Errorf("Should have returned at least one row")
} else {
columns, _ := rows.Columns()
if !reflect.DeepEqual(columns, []string{"fake"}) {
t.Errorf("Should only contains one column")
}
}
}
func TestSelectWithArrayInput(t *testing.T) {
DB.Save(&User{Name: "jinzhu", Age: 42})
var user User
DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
if user.Name != "jinzhu" || user.Age != 42 {
t.Errorf("Should have selected both age and name")
}
}
package gorm
import (
"errors"
"fmt"
"strings"
"time"
"reflect"
)
type Scope struct {
Search *search
Value interface{}
Sql string
SqlVars []interface{}
db *DB
indirectValue *reflect.Value
instanceId string
primaryKeyField *Field
skipLeft bool
fields map[string]*Field
selectAttrs *[]string
}
func (scope *Scope) IndirectValue() reflect.Value {
if scope.indirectValue == nil {
value := reflect.Indirect(reflect.ValueOf(scope.Value))
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
scope.indirectValue = &value
}
return *scope.indirectValue
}
func (scope *Scope) NeedPtr() *Scope {
reflectKind := reflect.ValueOf(scope.Value).Kind()
if !((reflectKind == reflect.Invalid) || (reflectKind == reflect.Ptr)) {
err := fmt.Errorf("%v %v\n", fileWithLineNum(), "using unaddressable value")
scope.Err(err)
fmt.Printf(err.Error())
}
return scope
}
// New create a new Scope without search information
func (scope *Scope) New(value interface{}) *Scope {
return &Scope{db: scope.NewDB(), Search: &search{}, Value: value}
}
// NewDB create a new DB without search information
func (scope *Scope) NewDB() *DB {
if scope.db != nil {
db := scope.db.clone()
db.search = nil
db.Value = nil
return db
}
return nil
}
func (scope *Scope) DB() *DB {
return scope.db
}
// SqlDB return *sql.DB
func (scope *Scope) SqlDB() sqlCommon {
return scope.db.db
}
// SkipLeft skip remaining callbacks
func (scope *Scope) SkipLeft() {
scope.skipLeft = true
}
// Quote used to quote database column name according to database dialect
func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 {
newStrs := []string{}
for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str))
}
return strings.Join(newStrs, ".")
} else {
return scope.Dialect().Quote(str)
}
}
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
}
// Err write error
func (scope *Scope) Err(err error) error {
if err != nil {
scope.db.err(err)
}
return err
}
// Log print log message
func (scope *Scope) Log(v ...interface{}) {
scope.db.log(v...)
}
// HasError check if there are any error
func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}
func (scope *Scope) PrimaryFields() []*Field {
var fields = []*Field{}
for _, field := range scope.GetModelStruct().PrimaryFields {
fields = append(fields, scope.Fields()[field.DBName])
}
return fields
}
func (scope *Scope) PrimaryField() *Field {
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
if len(primaryFields) > 1 {
if field, ok := scope.Fields()["id"]; ok {
return field
}
}
return scope.Fields()[primaryFields[0].DBName]
}
return nil
}
// PrimaryKey get the primary key's column name
func (scope *Scope) PrimaryKey() string {
if field := scope.PrimaryField(); field != nil {
return field.DBName
}
return ""
}
// PrimaryKeyZero check the primary key is blank or not
func (scope *Scope) PrimaryKeyZero() bool {
field := scope.PrimaryField()
return field == nil || field.IsBlank
}
// PrimaryKeyValue get the primary key's value
func (scope *Scope) PrimaryKeyValue() interface{} {
if field := scope.PrimaryField(); field != nil && field.Field.IsValid() {
return field.Field.Interface()
}
return 0
}
// HasColumn to check if has column
func (scope *Scope) HasColumn(column string) bool {
for _, field := range scope.GetStructFields() {
if field.IsNormal && (field.Name == column || field.DBName == column) {
return true
}
}
return false
}
// SetColumn to set the column's value
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
if field, ok := column.(*Field); ok {
return field.Set(value)
} else if name, ok := column.(string); ok {
if field, ok := scope.Fields()[name]; ok {
return field.Set(value)
}
dbName := ToDBName(name)
if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value)
}
if field, ok := scope.FieldByName(name); ok {
return field.Set(value)
}
}
return errors.New("could not convert column to field")
}
func (scope *Scope) CallMethod(name string, checkError bool) {
if scope.Value == nil || (checkError && scope.HasError()) {
return
}
call := func(value interface{}) {
if fm := reflect.ValueOf(value).MethodByName(name); fm.IsValid() {
switch f := fm.Interface().(type) {
case func():
f()
case func(s *Scope):
f(scope)
case func(s *DB):
f(scope.NewDB())
case func() error:
scope.Err(f())
case func(s *Scope) error:
scope.Err(f(scope))
case func(s *DB) error:
scope.Err(f(scope.NewDB()))
default:
scope.Err(fmt.Errorf("unsupported function %v", name))
}
}
}
if values := scope.IndirectValue(); values.Kind() == reflect.Slice {
for i := 0; i < values.Len(); i++ {
call(values.Index(i).Addr().Interface())
}
} else {
call(scope.Value)
}
}
func (scope *Scope) CallMethodWithErrorCheck(name string) {
scope.CallMethod(name, true)
}
// AddToVars add value as sql's vars, gorm will escape them
func (scope *Scope) AddToVars(value interface{}) string {
if expr, ok := value.(*expr); ok {
exp := expr.expr
for _, arg := range expr.args {
exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1)
}
return exp
} else {
scope.SqlVars = append(scope.SqlVars, value)
return scope.Dialect().BinVar(len(scope.SqlVars))
}
}
type tabler interface {
TableName() string
}
type dbTabler interface {
TableName(*DB) string
}
// TableName get table name
func (scope *Scope) TableName() string {
if scope.Search != nil && len(scope.Search.tableName) > 0 {
return scope.Search.tableName
}
if tabler, ok := scope.Value.(tabler); ok {
return tabler.TableName()
}
if tabler, ok := scope.Value.(dbTabler); ok {
return tabler.TableName(scope.db)
}
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
}
func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.tableName) > 0 {
if strings.Index(scope.Search.tableName, " ") != -1 {
return scope.Search.tableName
}
return scope.Quote(scope.Search.tableName)
} else {
return scope.Quote(scope.TableName())
}
}
// CombinedConditionSql get combined condition sql
func (scope *Scope) CombinedConditionSql() string {
return scope.joinsSql() + scope.whereSql() + scope.groupSql() +
scope.havingSql() + scope.orderSql() + scope.limitSql() + scope.offsetSql()
}
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
for _, field := range scope.Fields() {
if field.Name == name || field.DBName == name {
return field, true
}
}
return nil, false
}
// Raw set sql
func (scope *Scope) Raw(sql string) *Scope {
scope.Sql = strings.Replace(sql, "$$", "?", -1)
return scope
}
// Exec invoke sql
func (scope *Scope) Exec() *Scope {
defer scope.Trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
}
}
}
return scope
}
// Set set value by name
func (scope *Scope) Set(name string, value interface{}) *Scope {
scope.db.InstantSet(name, value)
return scope
}
// Get get value by name
func (scope *Scope) Get(name string) (interface{}, bool) {
return scope.db.Get(name)
}
// InstanceId get InstanceId for scope
func (scope *Scope) InstanceId() string {
if scope.instanceId == "" {
scope.instanceId = fmt.Sprintf("%v%v", &scope, &scope.db)
}
return scope.instanceId
}
func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
return scope.Set(name+scope.InstanceId(), value)
}
func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
return scope.Get(name + scope.InstanceId())
}
// Trace print sql log
func (scope *Scope) Trace(t time.Time) {
if len(scope.Sql) > 0 {
scope.db.slog(scope.Sql, t, scope.SqlVars...)
}
}
// Begin start a transaction
func (scope *Scope) Begin() *Scope {
if db, ok := scope.SqlDB().(sqlDb); ok {
if tx, err := db.Begin(); err == nil {
scope.db.db = interface{}(tx).(sqlCommon)
scope.InstanceSet("gorm:started_transaction", true)
}
}
return scope
}
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
func (scope *Scope) CommitOrRollback() *Scope {
if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
if db, ok := scope.db.db.(sqlTx); ok {
if scope.HasError() {
db.Rollback()
} else {
db.Commit()
}
scope.db.db = scope.db.parent.db
}
}
return scope
}
func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil {
attrs := []string{}
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]string); ok {
attrs = append(attrs, strs...)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))
}
}
}
scope.selectAttrs = &attrs
}
return *scope.selectAttrs
}
func (scope *Scope) OmitAttrs() []string {
return scope.Search.omits
}
func (scope *Scope) changeableDBColumn(column string) bool {
selectAttrs := scope.SelectAttrs()
omitAttrs := scope.OmitAttrs()
if len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if column == ToDBName(attr) {
return true
}
}
return false
}
for _, attr := range omitAttrs {
if column == ToDBName(attr) {
return false
}
}
return true
}
func (scope *Scope) changeableField(field *Field) bool {
selectAttrs := scope.SelectAttrs()
omitAttrs := scope.OmitAttrs()
if len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
if field.Name == attr || field.DBName == attr {
return true
}
}
return false
}
for _, attr := range omitAttrs {
if field.Name == attr || field.DBName == attr {
return false
}
}
return !field.IsIgnored
}
func (scope *Scope) shouldSaveAssociations() bool {
saveAssociations, ok := scope.Get("gorm:save_associations")
if ok && !saveAssociations.(bool) {
return false
}
return true
}
package gorm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
)
func (scope *Scope) primaryCondition(value interface{}) string {
return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value)
}
func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
// if string is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return scope.primaryCondition(scope.AddToVars(id))
} else if value != "" {
str = fmt.Sprintf("(%v)", value)
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return scope.primaryCondition(scope.AddToVars(value))
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}:
str = fmt.Sprintf("(%v in (?))", scope.Quote(scope.PrimaryKey()))
clause["args"] = []interface{}{value}
case map[string]interface{}:
var sqls []string
for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value)))
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) {
var notEqualSql string
var primaryKey = scope.PrimaryKey()
switch value := clause["query"].(type) {
case string:
// is number
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value)
notEqualSql = fmt.Sprintf("NOT (%v)", value)
} else {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value))
notEqualSql = fmt.Sprintf("(%v <> ?)", scope.Quote(value))
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64:
return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value)
case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string:
if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey))
clause["args"] = []interface{}{value}
}
return ""
case map[string]interface{}:
var sqls []string
for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value)))
}
return strings.Join(sqls, " AND ")
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice: // For where("id in (?)", []int64{1,2})
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = scanner.Value()
}
str = strings.Replace(notEqualSql, "?", scope.AddToVars(arg), 1)
}
}
return
}
func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) {
switch value := clause["query"].(type) {
case string:
str = value
case []string:
str = strings.Join(value, ", ")
}
args := clause["args"].([]interface{})
for _, arg := range args {
switch reflect.ValueOf(arg).Kind() {
case reflect.Slice:
values := reflect.ValueOf(arg)
var tempMarks []string
for i := 0; i < values.Len(); i++ {
tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
}
str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
default:
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value()
}
str = strings.Replace(str, "?", scope.Dialect().Quote(fmt.Sprintf("%v", arg)), 1)
}
}
return
}
func (scope *Scope) whereSql() (sql string) {
var primaryConditions, andConditions, orConditions []string
if !scope.Search.Unscoped && scope.Fields()["deleted_at"] != nil {
sql := fmt.Sprintf("(%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02')", scope.QuotedTableName(), scope.QuotedTableName())
primaryConditions = append(primaryConditions, sql)
}
if !scope.PrimaryKeyZero() {
primaryConditions = append(primaryConditions, scope.primaryCondition(scope.AddToVars(scope.PrimaryKeyValue())))
}
for _, clause := range scope.Search.whereConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
for _, clause := range scope.Search.orConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
orConditions = append(orConditions, sql)
}
}
for _, clause := range scope.Search.notConditions {
if sql := scope.buildNotCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
orSql := strings.Join(orConditions, " OR ")
combinedSql := strings.Join(andConditions, " AND ")
if len(combinedSql) > 0 {
if len(orSql) > 0 {
combinedSql = combinedSql + " OR " + orSql
}
} else {
combinedSql = orSql
}
if len(primaryConditions) > 0 {
sql = "WHERE " + strings.Join(primaryConditions, " AND ")
if len(combinedSql) > 0 {
sql = sql + " AND (" + combinedSql + ")"
}
} else if len(combinedSql) > 0 {
sql = "WHERE " + combinedSql
}
return
}
func (scope *Scope) selectSql() string {
if len(scope.Search.selects) == 0 {
return "*"
}
return scope.buildSelectQuery(scope.Search.selects)
}
func (scope *Scope) orderSql() string {
if len(scope.Search.orders) == 0 {
return ""
}
return " ORDER BY " + strings.Join(scope.Search.orders, ",")
}
func (scope *Scope) limitSql() string {
if !scope.Dialect().HasTop() {
if len(scope.Search.limit) == 0 {
return ""
}
return " LIMIT " + scope.Search.limit
}
return ""
}
func (scope *Scope) topSql() string {
if scope.Dialect().HasTop() && len(scope.Search.offset) == 0 {
if len(scope.Search.limit) == 0 {
return ""
}
return " TOP(" + scope.Search.limit + ")"
}
return ""
}
func (scope *Scope) offsetSql() string {
if len(scope.Search.offset) == 0 {
return ""
}
if scope.Dialect().HasTop() {
sql := " OFFSET " + scope.Search.offset + " ROW "
if len(scope.Search.limit) > 0 {
sql += "FETCH NEXT " + scope.Search.limit + " ROWS ONLY"
}
return sql
}
return " OFFSET " + scope.Search.offset
}
func (scope *Scope) groupSql() string {
if len(scope.Search.group) == 0 {
return ""
}
return " GROUP BY " + scope.Search.group
}
func (scope *Scope) havingSql() string {
if scope.Search.havingCondition == nil {
return ""
}
return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition)
}
func (scope *Scope) joinsSql() string {
return scope.Search.joins + " "
}
func (scope *Scope) prepareQuerySql() {
if scope.Search.raw {
scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")"))
} else {
scope.Raw(fmt.Sprintf("SELECT %v %v FROM %v %v", scope.topSql(), scope.selectSql(), scope.QuotedTableName(), scope.CombinedConditionSql()))
}
return
}
func (scope *Scope) inlineCondition(values ...interface{}) *Scope {
if len(values) > 0 {
scope.Search.Where(values[0], values[1:]...)
}
return scope
}
func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope {
for _, f := range funcs {
(*f)(scope)
if scope.skipLeft {
break
}
}
return scope
}
func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignoreProtectedAttrs bool) (results map[string]interface{}, hasUpdate bool) {
if !scope.IndirectValue().CanAddr() {
return values, true
}
var hasExpr bool
fields := scope.Fields()
for key, value := range values {
if field, ok := fields[ToDBName(key)]; ok && field.Field.IsValid() {
if !reflect.DeepEqual(field.Field, reflect.ValueOf(value)) {
if _, ok := value.(*expr); ok {
hasExpr = true
} else if !equalAsString(field.Field.Interface(), value) {
hasUpdate = true
field.Set(value)
}
}
}
}
if hasExpr {
var updateMap = map[string]interface{}{}
for key, value := range fields {
if v, ok := values[key]; ok {
updateMap[key] = v
} else {
updateMap[key] = value.Field.Interface()
}
}
return updateMap, true
}
return
}
func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) initialize() *Scope {
for _, clause := range scope.Search.whereConditions {
scope.updatedAttrsWithValues(convertInterfaceToMap(clause["query"]), false)
}
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.initAttrs), false)
scope.updatedAttrsWithValues(convertInterfaceToMap(scope.Search.assignAttrs), false)
return scope
}
func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search.Select(column)
if dest.Kind() != reflect.Slice {
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
return scope
}
rows, err := scope.rows()
if scope.Err(err) == nil {
defer rows.Close()
for rows.Next() {
elem := reflect.New(dest.Type().Elem()).Interface()
scope.Err(rows.Scan(elem))
dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem()))
}
}
return scope
}
func (scope *Scope) count(value interface{}) *Scope {
scope.Search.Select("count(*)")
scope.Err(scope.row().Scan(value))
return scope
}
func (scope *Scope) typeName() string {
value := scope.IndirectValue()
if value.Kind() == reflect.Slice {
return value.Type().Elem().Name()
}
return value.Type().Name()
}
func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
toScope := scope.db.NewScope(value)
fromFields := scope.Fields()
toFields := toScope.Fields()
for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") {
var fromField, toField *Field
if field, ok := scope.FieldByName(foreignKey); ok {
fromField = field
} else {
fromField = fromFields[ToDBName(foreignKey)]
}
if field, ok := toScope.FieldByName(foreignKey); ok {
toField = field
} else {
toField = toFields[ToDBName(foreignKey)]
}
if fromField != nil {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := toScope.db.Where(sql, scope.PrimaryKeyValue())
if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
scope.Err(query.Find(value).Error)
}
} else {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error)
}
return scope
} else if toField != nil {
sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName))
scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error)
return scope
}
}
scope.Err(fmt.Errorf("invalid association %v", foreignKeys))
return scope
}
func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes []string
for _, s := range []*Scope{scope, toScope} {
for _, primaryField := range s.GetModelStruct().PrimaryFields {
value := reflect.Indirect(reflect.New(primaryField.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name)
sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType)
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}
}
func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
for _, field := range scope.GetStructFields() {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
return scope
}
func (scope *Scope) dropTable() *Scope {
scope.Raw(fmt.Sprintf("DROP TABLE %v", scope.QuotedTableName())).Exec()
return scope
}
func (scope *Scope) dropTableIfExists() *Scope {
if scope.Dialect().HasTable(scope, scope.TableName()) {
scope.dropTable()
}
return scope
}
func (scope *Scope) modifyColumn(column string, typ string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec()
}
func (scope *Scope) dropColumn(column string) {
scope.Raw(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", scope.QuotedTableName(), scope.Quote(column))).Exec()
}
func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
if scope.Dialect().HasIndex(scope, scope.TableName(), indexName) {
return
}
var columns []string
for _, name := range column {
if regexp.MustCompile("^[a-zA-Z]+$").MatchString(name) {
columns = append(columns, scope.Quote(name))
} else {
columns = append(columns, name)
}
}
sqlCreate := "CREATE INDEX"
if unique {
sqlCreate = "CREATE UNIQUE INDEX"
}
scope.Raw(fmt.Sprintf("%s %v ON %v(%v);", sqlCreate, indexName, scope.QuotedTableName(), strings.Join(columns, ", "))).Exec()
}
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
scope.Dialect().RemoveIndex(scope, indexName)
}
func (scope *Scope) autoMigrate() *Scope {
tableName := scope.TableName()
quotedTableName := scope.QuotedTableName()
if !scope.Dialect().HasTable(scope, tableName) {
scope.createTable()
} else {
for _, field := range scope.GetStructFields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
scope.createJoinTable(field)
}
}
scope.autoIndex()
return scope
}
func (scope *Scope) autoIndex() *Scope {
var indexes = map[string][]string{}
var uniqueIndexes = map[string][]string{}
for _, field := range scope.GetStructFields() {
sqlSettings := parseTagSetting(field.Tag.Get("sql"))
if name, ok := sqlSettings["INDEX"]; ok {
if name == "INDEX" {
name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName)
}
indexes[name] = append(indexes[name], field.DBName)
}
if name, ok := sqlSettings["UNIQUE_INDEX"]; ok {
if name == "UNIQUE_INDEX" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
}
for name, columns := range indexes {
scope.addIndex(false, name, columns...)
}
for name, columns := range uniqueIndexes {
scope.addIndex(true, name, columns...)
}
return scope
}
package gorm_test
import (
"github.com/jinzhu/gorm"
"testing"
)
func NameIn1And2(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"})
}
func NameIn2And3(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"})
}
func NameIn(names []string) func(d *gorm.DB) *gorm.DB {
return func(d *gorm.DB) *gorm.DB {
return d.Where("name in (?)", names)
}
}
func TestScopes(t *testing.T) {
user1 := User{Name: "ScopeUser1", Age: 1}
user2 := User{Name: "ScopeUser2", Age: 1}
user3 := User{Name: "ScopeUser3", Age: 2}
DB.Save(&user1).Save(&user2).Save(&user3)
var users1, users2, users3 []User
DB.Scopes(NameIn1And2).Find(&users1)
if len(users1) != 2 {
t.Errorf("Should found two users's name in 1, 2")
}
DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2)
if len(users2) != 1 {
t.Errorf("Should found one user's name is 2")
}
DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3)
if len(users3) != 2 {
t.Errorf("Should found two users's name in 1, 3")
}
}
package gorm
import "fmt"
type search struct {
db *DB
whereConditions []map[string]interface{}
orConditions []map[string]interface{}
notConditions []map[string]interface{}
havingCondition map[string]interface{}
initAttrs []interface{}
assignAttrs []interface{}
selects map[string]interface{}
omits []string
orders []string
joins string
preload []searchPreload
offset string
limit string
group string
tableName string
raw bool
Unscoped bool
}
type searchPreload struct {
schema string
conditions []interface{}
}
func (s *search) clone() *search {
clone := *s
return &clone
}
func (s *search) Where(query interface{}, values ...interface{}) *search {
s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Not(query interface{}, values ...interface{}) *search {
s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Or(query interface{}, values ...interface{}) *search {
s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
return s
}
func (s *search) Attrs(attrs ...interface{}) *search {
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) Assign(attrs ...interface{}) *search {
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
return s
}
func (s *search) Order(value string, reorder ...bool) *search {
if len(reorder) > 0 && reorder[0] {
s.orders = []string{value}
} else {
s.orders = append(s.orders, value)
}
return s
}
func (s *search) Select(query interface{}, args ...interface{}) *search {
s.selects = map[string]interface{}{"query": query, "args": args}
return s
}
func (s *search) Omit(columns ...string) *search {
s.omits = columns
return s
}
func (s *search) Limit(value interface{}) *search {
s.limit = s.getInterfaceAsSql(value)
return s
}
func (s *search) Offset(value interface{}) *search {
s.offset = s.getInterfaceAsSql(value)
return s
}
func (s *search) Group(query string) *search {
s.group = s.getInterfaceAsSql(query)
return s
}
func (s *search) Having(query string, values ...interface{}) *search {
s.havingCondition = map[string]interface{}{"query": query, "args": values}
return s
}
func (s *search) Joins(query string) *search {
s.joins = query
return s
}
func (s *search) Preload(schema string, values ...interface{}) *search {
var preloads []searchPreload
for _, preload := range s.preload {
if preload.schema != schema {
preloads = append(preloads, preload)
}
}
preloads = append(preloads, searchPreload{schema, values})
s.preload = preloads
return s
}
func (s *search) Raw(b bool) *search {
s.raw = b
return s
}
func (s *search) unscoped() *search {
s.Unscoped = true
return s
}
func (s *search) Table(name string) *search {
s.tableName = name
return s
}
func (s *search) getInterfaceAsSql(value interface{}) (str string) {
switch value.(type) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
str = fmt.Sprintf("%v", value)
default:
s.db.err(InvalidSql)
}
if str == "-1" {
return ""
}
return
}
package gorm
import (
"reflect"
"testing"
)
func TestCloneSearch(t *testing.T) {
s := new(search)
s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age")
s1 := s.clone()
s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email")
if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
t.Errorf("Where should be copied")
}
if reflect.DeepEqual(s.orders, s1.orders) {
t.Errorf("Order should be copied")
}
if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
t.Errorf("InitAttrs should be copied")
}
if reflect.DeepEqual(s.Select, s1.Select) {
t.Errorf("selectStr should be copied")
}
}
package gorm_test
import (
"database/sql/driver"
"encoding/json"
"testing"
)
func TestScannableSlices(t *testing.T) {
if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
t.Errorf("Should create table with slice values correctly: %s", err)
}
r1 := RecordWithSlice{
Strings: ExampleStringSlice{"a", "b", "c"},
Structs: ExampleStructSlice{
{"name1", "value1"},
{"name2", "value2"},
},
}
if err := DB.Save(&r1).Error; err != nil {
t.Errorf("Should save record with slice values")
}
var r2 RecordWithSlice
if err := DB.Find(&r2).Error; err != nil {
t.Errorf("Should fetch record with slice values")
}
if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
t.Errorf("Should have serialised and deserialised a string array")
}
if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
t.Errorf("Should have serialised and deserialised a struct array")
}
}
type RecordWithSlice struct {
ID uint64
Strings ExampleStringSlice `sql:"type:text"`
Structs ExampleStructSlice `sql:"type:text"`
}
type ExampleStringSlice []string
func (l ExampleStringSlice) Value() (driver.Value, error) {
return json.Marshal(l)
}
func (l *ExampleStringSlice) Scan(input interface{}) error {
return json.Unmarshal(input.([]byte), l)
}
type ExampleStruct struct {
Name string
Value string
}
type ExampleStructSlice []ExampleStruct
func (l ExampleStructSlice) Value() (driver.Value, error) {
return json.Marshal(l)
}
func (l *ExampleStructSlice) Scan(input interface{}) error {
return json.Unmarshal(input.([]byte), l)
}
package gorm
import (
"fmt"
"reflect"
"time"
)
type sqlite3 struct {
commonDialect
}
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "integer"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
return count > 0
}
func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
return count > 0
}
func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
return count > 0
}
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}
package gorm_test
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"time"
)
type User struct {
Id int64
Age int64
UserNum Num
Name string `sql:"size:255"`
Birthday time.Time // Time
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
Emails []Email // Embedded structs
BillingAddress Address // Embedded struct
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key
CreditCard CreditCard
Latitude float64
Languages []Language `gorm:"many2many:user_languages;"`
CompanyID int64
Company Company
Role
PasswordHash []byte
IgnoreMe int64 `sql:"-"`
IgnoreStringSlice []string `sql:"-"`
Ignored struct{ Name string } `sql:"-"`
IgnoredPointer *User `sql:"-"`
}
type CreditCard struct {
ID int8
Number string
UserId sql.NullInt64
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}
type Email struct {
Id int16
UserId int
Email string `sql:"type:varchar(100);"`
CreatedAt time.Time
UpdatedAt time.Time
}
type Address struct {
ID int
Address1 string
Address2 string
Post string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}
type Language struct {
Id int
Name string
Users []User `gorm:"many2many:user_languages;"`
}
type Product struct {
Id int64
Code string
Price int64
CreatedAt time.Time
UpdatedAt time.Time
AfterFindCallTimes int64
BeforeCreateCallTimes int64
AfterCreateCallTimes int64
BeforeUpdateCallTimes int64
AfterUpdateCallTimes int64
BeforeSaveCallTimes int64
AfterSaveCallTimes int64
BeforeDeleteCallTimes int64
AfterDeleteCallTimes int64
}
type Company struct {
Id int64
Name string
Owner *User `sql:"-"`
}
type Role struct {
Name string
}
func (role *Role) Scan(value interface{}) error {
if b, ok := value.([]uint8); ok {
role.Name = string(b)
} else {
role.Name = value.(string)
}
return nil
}
func (role Role) Value() (driver.Value, error) {
return role.Name, nil
}
func (role Role) IsAdmin() bool {
return role.Name == "admin"
}
type Num int64
func (i *Num) Scan(src interface{}) error {
switch s := src.(type) {
case []byte:
case int64:
*i = Num(s)
default:
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
}
return nil
}
type Animal struct {
Counter uint64 `gorm:"primary_key:yes"`
Name string `sql:"DEFAULT:'galeone'"`
From string //test reserved sql keyword as field name
Age time.Time `sql:"DEFAULT:current_timestamp"`
unexported string // unexported value
CreatedAt time.Time
UpdatedAt time.Time
}
type JoinTable struct {
From uint64
To uint64
Time time.Time `sql:"default: null"`
}
type Post struct {
Id int64
CategoryId sql.NullInt64
MainCategoryId int64
Title string
Body string
Comments []*Comment
Category Category
MainCategory Category
}
type Category struct {
Id int64
Name string
}
type Comment struct {
Id int64
PostId int64
Content string
Post Post
}
// Scanner
type NullValue struct {
Id int64
Name sql.NullString `sql:"not null"`
Age sql.NullInt64
Male sql.NullBool
Height sql.NullFloat64
AddedAt NullTime
}
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
if value == nil {
nt.Valid = false
return nil
}
nt.Time, nt.Valid = value.(time.Time), true
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func getPreparedUser(name string, role string) *User {
var company Company
DB.Where(Company{Name: role}).FirstOrCreate(&company)
return &User{
Name: name,
Age: 20,
Role: Role{role},
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
Emails: []Email{
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
},
Company: company,
Languages: []Language{
{Name: fmt.Sprintf("lang_1_%v", name)},
{Name: fmt.Sprintf("lang_2_%v", name)},
},
}
}
dialects=("postgres" "mysql" "sqlite")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test
done
package gorm_test
import (
"testing"
"time"
"github.com/jinzhu/gorm"
)
func TestUpdate(t *testing.T) {
product1 := Product{Code: "product1code"}
product2 := Product{Code: "product2code"}
DB.Save(&product1).Save(&product2).Update("code", "product2newcode")
if product2.Code != "product2newcode" {
t.Errorf("Record should be updated")
}
DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt
var product3 Product
DB.First(&product3, product2.Id).Update("code", "product2newcode")
if updatedAt2.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
t.Errorf("Product1 should not be updated")
}
if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() {
t.Errorf("Product2's code should be updated")
}
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
t.Errorf("Product2's code should be updated")
}
DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode")
var product4 Product
DB.First(&product4, product1.Id)
if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should be updated if something changed")
}
if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() {
t.Errorf("Product1's code should be updated")
}
if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() {
t.Errorf("Product should not be changed to 789")
}
if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update with CamelCase")
}
if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
t.Error("No error should raise when update_column with CamelCase")
}
var products []Product
DB.Find(&products)
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
t.Error("RowsAffected should be correct when do batch update")
}
DB.First(&product4, product4.Id)
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 {
t.Errorf("Update with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("Update with expression should update UpdatedAt")
}
}
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"}
DB.Save(&animal)
updatedAt1 := animal.UpdatedAt
DB.Save(&animal).Update("name", "Francis")
if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}
var animals []Animal
DB.Find(&animals)
if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) {
t.Error("RowsAffected should be correct when do batch update")
}
animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone)
DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
DB.First(&animal, animal.Counter)
if animal.Name != "galeone" {
t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
}
// When changing a field with a default value, the change must occur
animal.Name = "amazing horse"
DB.Save(&animal)
DB.First(&animal, animal.Counter)
if animal.Name != "amazing horse" {
t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
}
}
func TestUpdates(t *testing.T) {
product1 := Product{Code: "product1code", Price: 10}
product2 := Product{Code: "product2code", Price: 10}
DB.Save(&product1).Save(&product2)
DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100})
if product1.Code != "product1newcode" || product1.Price != 100 {
t.Errorf("Record should be updated also with map")
}
DB.First(&product1, product1.Id)
DB.First(&product2, product2.Id)
updatedAt1 := product1.UpdatedAt
updatedAt2 := product2.UpdatedAt
var product3 Product
DB.First(&product3, product1.Id).Updates(Product{Code: "product1newcode", Price: 100})
if product3.Code != "product1newcode" || product3.Price != 100 {
t.Errorf("Record should be updated with struct")
}
if updatedAt1.Format(time.RFC3339Nano) != product3.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated if nothing changed")
}
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
t.Errorf("Product2 should not be updated")
}
if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() {
t.Errorf("Product1 should be updated")
}
DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"})
if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() {
t.Errorf("Product2's code should be updated")
}
var product4 Product
DB.First(&product4, product2.Id)
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should be updated if something changed")
}
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
t.Errorf("product2's code should be updated")
}
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100 {
t.Errorf("Updates with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) == product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("Updates with expression should update UpdatedAt")
}
}
func TestUpdateColumn(t *testing.T) {
product1 := Product{Code: "product1code", Price: 10}
product2 := Product{Code: "product2code", Price: 20}
DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100})
if product2.Code != "product2newcode" || product2.Price != 100 {
t.Errorf("product 2 should be updated with update column")
}
var product3 Product
DB.First(&product3, product1.Id)
if product3.Code != "product1code" || product3.Price != 10 {
t.Errorf("product 1 should not be updated")
}
DB.First(&product2, product2.Id)
updatedAt2 := product2.UpdatedAt
DB.Model(product2).UpdateColumn("code", "update_column_new")
var product4 Product
DB.First(&product4, product2.Id)
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("updatedAt should not be updated with update column")
}
DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
var product5 Product
DB.First(&product5, product4.Id)
if product5.Price != product4.Price+100-50 {
t.Errorf("UpdateColumn with expression")
}
if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateColumn with expression should not update UpdatedAt")
}
}
func TestSelectWithUpdate(t *testing.T) {
user := getPreparedUser("select_user", "select_with_update")
DB.Create(user)
var reloadUser User
DB.First(&reloadUser, user.Id)
reloadUser.Name = "new_name"
reloadUser.Age = 50
reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
reloadUser.CreditCard = CreditCard{Number: "987654321"}
reloadUser.Emails = []Email{
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
}
reloadUser.Company = Company{Name: "new company"}
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
var queryUser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
if queryUser.Name == user.Name || queryUser.Age != user.Age {
t.Errorf("Should only update users with name column")
}
if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
queryUser.ShippingAddressId != user.ShippingAddressId ||
queryUser.CreditCard.ID == user.CreditCard.ID ||
len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
t.Errorf("Should only update selected relationships")
}
}
func TestSelectWithUpdateWithMap(t *testing.T) {
user := getPreparedUser("select_user", "select_with_update_map")
DB.Create(user)
updateValues := map[string]interface{}{
"Name": "new_name",
"Age": 50,
"BillingAddress": Address{Address1: "New Billing Address"},
"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
"CreditCard": CreditCard{Number: "987654321"},
"Emails": []Email{
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
},
"Company": Company{Name: "new company"},
}
var reloadUser User
DB.First(&reloadUser, user.Id)
DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
var queryUser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
if queryUser.Name == user.Name || queryUser.Age != user.Age {
t.Errorf("Should only update users with name column")
}
if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
queryUser.ShippingAddressId != user.ShippingAddressId ||
queryUser.CreditCard.ID == user.CreditCard.ID ||
len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
t.Errorf("Should only update selected relationships")
}
}
func TestOmitWithUpdate(t *testing.T) {
user := getPreparedUser("omit_user", "omit_with_update")
DB.Create(user)
var reloadUser User
DB.First(&reloadUser, user.Id)
reloadUser.Name = "new_name"
reloadUser.Age = 50
reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
reloadUser.CreditCard = CreditCard{Number: "987654321"}
reloadUser.Emails = []Email{
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
}
reloadUser.Company = Company{Name: "new company"}
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
var queryUser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
if queryUser.Name != user.Name || queryUser.Age == user.Age {
t.Errorf("Should only update users with name column")
}
if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
queryUser.ShippingAddressId == user.ShippingAddressId ||
queryUser.CreditCard.ID != user.CreditCard.ID ||
len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
t.Errorf("Should only update relationships that not omited")
}
}
func TestOmitWithUpdateWithMap(t *testing.T) {
user := getPreparedUser("select_user", "select_with_update_map")
DB.Create(user)
updateValues := map[string]interface{}{
"Name": "new_name",
"Age": 50,
"BillingAddress": Address{Address1: "New Billing Address"},
"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
"CreditCard": CreditCard{Number: "987654321"},
"Emails": []Email{
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
},
"Company": Company{Name: "new company"},
}
var reloadUser User
DB.First(&reloadUser, user.Id)
DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
var queryUser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
if queryUser.Name != user.Name || queryUser.Age == user.Age {
t.Errorf("Should only update users with name column")
}
if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
queryUser.ShippingAddressId == user.ShippingAddressId ||
queryUser.CreditCard.ID != user.CreditCard.ID ||
len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
t.Errorf("Should only update relationships not omited")
}
}
func TestSelectWithUpdateColumn(t *testing.T) {
user := getPreparedUser("select_user", "select_with_update_map")
DB.Create(user)
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
var reloadUser User
DB.First(&reloadUser, user.Id)
DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues)
var queryUser User
DB.First(&queryUser, user.Id)
if queryUser.Name == user.Name || queryUser.Age != user.Age {
t.Errorf("Should only update users with name column")
}
}
func TestOmitWithUpdateColumn(t *testing.T) {
user := getPreparedUser("select_user", "select_with_update_map")
DB.Create(user)
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
var reloadUser User
DB.First(&reloadUser, user.Id)
DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues)
var queryUser User
DB.First(&queryUser, user.Id)
if queryUser.Name != user.Name || queryUser.Age == user.Age {
t.Errorf("Should omit name column when update user")
}
}
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
user := getPreparedUser("update_columns_user", "special_role")
user.Age = 99
address1 := "first street"
user.BillingAddress = Address{Address1: address1}
DB.Save(user)
// Update a single field of the user and verify that the changed address is not stored.
newAge := int64(100)
user.BillingAddress.Address1 = "second street"
db := DB.Model(user).UpdateColumns(User{Age: newAge})
if db.RowsAffected != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
}
// Verify that Age now=`newAge`.
freshUser := &User{Id: user.Id}
DB.First(freshUser)
if freshUser.Age != newAge {
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
}
// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
if freshUser.BillingAddress.Address1 != address1 {
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
}
}
package gorm
import (
"bytes"
"strings"
)
// Copied from golint
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
var commonInitialismsReplacer *strings.Replacer
func init() {
var commonInitialismsForReplacer []string
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
var smap = map[string]string{}
func ToDBName(name string) string {
if v, ok := smap[name]; ok {
return v
}
value := commonInitialismsReplacer.Replace(name)
buf := bytes.NewBufferString("")
for i, v := range value {
if i > 0 && v >= 'A' && v <= 'Z' {
buf.WriteRune('_')
}
buf.WriteRune(v)
}
s := strings.ToLower(buf.String())
smap[name] = s
return s
}
type expr struct {
expr string
args []interface{}
}
func Expr(expression string, args ...interface{}) *expr {
return &expr{expr: expression, args: args}
}
package gorm
import (
"fmt"
"reflect"
"regexp"
"runtime"
)
func fileWithLineNum() string {
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", file, line)
}
}
return ""
}
func isBlank(value reflect.Value) bool {
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
}
func toSearchableMap(attrs ...interface{}) (result interface{}) {
if len(attrs) > 1 {
if str, ok := attrs[0].(string); ok {
result = map[string]interface{}{str: attrs[1]}
}
} else if len(attrs) == 1 {
if attr, ok := attrs[0].(map[string]interface{}); ok {
result = attr
}
if attr, ok := attrs[0].(interface{}); ok {
result = attr
}
}
return
}
func convertInterfaceToMap(values interface{}) map[string]interface{} {
attrs := map[string]interface{}{}
switch value := values.(type) {
case map[string]interface{}:
for k, v := range value {
attrs[ToDBName(k)] = v
}
case []interface{}:
for _, v := range value {
for key, value := range convertInterfaceToMap(v) {
attrs[key] = value
}
}
case interface{}:
reflectValue := reflect.ValueOf(values)
switch reflectValue.Kind() {
case reflect.Map:
for _, key := range reflectValue.MapKeys() {
attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface()
}
default:
scope := Scope{Value: values}
for _, field := range scope.Fields() {
if !field.IsBlank && !field.IsIgnored {
attrs[field.DBName] = field.Field.Interface()
}
}
}
}
return attrs
}
package httpd
import (
"bufio"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"strconv"
"time"
)
type (
closeError struct {
code int
text string
}
netError struct {
msg string
temporary bool
timeout bool
}
messageReader struct {
c *Conn
seq int
}
messageWriter struct {
c *Conn
seq int
}
)
const (
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7
writeWait = time.Second
defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096
continuationFrame = 0
noFrame = -1
)
// Close codes defined in RFC 6455, section 11.7.
const (
TextMessage = 1
BinaryMessage = 2
CloseMessage = 8
PingMessage = 9
PongMessage = 10
CloseNormalClosure = 1000
CloseGoingAway = 1001
CloseProtocolError = 1002
CloseUnsupportedData = 1003
CloseNoStatusReceived = 1005
CloseAbnormalClosure = 1006
CloseInvalidFramePayloadData = 1007
ClosePolicyViolation = 1008
CloseMessageTooBig = 1009
CloseMandatoryExtension = 1010
CloseInternalServerErr = 1011
CloseTLSHandshake = 1015
)
var (
errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true}
errUnexpectedEOF = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()}
errBadWriteOpCode = errors.New("websocket: bad write message type")
errWriteClosed = errors.New("websocket: write closed")
errInvalidControlFrame = errors.New("websocket: invalid control frame")
ErrCloseSent = errors.New("websocket: close sent")
ErrReadLimit = errors.New("websocket: read limit exceeded")
keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
)
func (e *closeError) Error() string {
return "websocket: close " + strconv.Itoa(e.code) + " " + e.text
}
func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
mu := make(chan bool, 1)
mu <- true
if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize
}
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn,
mu: mu,
readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
}
c.SetPingHandler(nil)
c.SetPongHandler(nil)
return c
}
// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
// Close closes the underlying network connection without sending or waiting for a close frame.
func (c *Conn) Close() error {
return c.conn.Close()
}
// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will
// not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t
return nil
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
wr, err := c.NextWriter(messageType)
if err != nil {
return err
}
w := wr.(messageWriter)
if _, err := w.write(true, data); err != nil {
return err
}
if c.writeSeq == w.seq {
if err := c.flushFrame(true, nil); err != nil {
return err
}
}
return nil
}
// Write methods
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if frameType == CloseMessage {
c.closeSent = true
}
c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs {
if len(buf) > 0 {
n, err := c.conn.Write(buf)
if n != len(buf) {
// Close on partial write.
c.conn.Close()
}
if err != nil {
return err
}
}
}
return nil
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
if !isControl(messageType) {
return errBadWriteOpCode
}
if len(data) > maxControlFramePayloadSize {
return errInvalidControlFrame
}
b0 := byte(messageType) | finalBit
b1 := byte(len(data))
if !c.isServer {
b1 |= maskBit
}
buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
buf = append(buf, b0, b1)
if c.isServer {
buf = append(buf, data...)
} else {
key := newMaskKey()
buf = append(buf, key[:]...)
buf = append(buf, data...)
maskBytes(key, 0, buf[6:])
}
d := time.Hour * 1000
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if d < 0 {
return errWriteTimeout
}
}
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- true }()
if c.closeSent {
return ErrCloseSent
} else if messageType == CloseMessage {
c.closeSent = true
}
c.conn.SetWriteDeadline(deadline)
n, err := c.conn.Write(buf)
if n != 0 && n != len(buf) {
c.conn.Close()
}
return err
}
// NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage.
//
// There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it.
//
// The NextReader method and the readers returned from the method cannot be
// accessed by more than one goroutine at a time.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readSeq++
c.readLength = 0
for c.readErr == nil {
frameType, err := c.advanceFrame()
if err != nil {
c.readErr = hideTempErr(err)
break
}
if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil
}
}
return noFrame, nil, c.readErr
}
// NextWriter returns a writer for the next message to send. The writer's
// Close method flushes the complete message to the network.
//
// There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so.
//
// The NextWriter method and the writers returned from the method cannot be
// accessed by more than one goroutine at a time.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if c.writeErr != nil {
return nil, c.writeErr
}
if c.writeFrameType != noFrame {
if err := c.flushFrame(true, nil); err != nil {
return nil, err
}
}
if !isControl(messageType) && !isData(messageType) {
return nil, errBadWriteOpCode
}
c.writeFrameType = messageType
return messageWriter{c, c.writeSeq}, nil
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
// readFull is like io.ReadFull except that io.EOF is never returned.
func (c *Conn) readFull(p []byte) (err error) {
var n int
for n < len(p) && err == nil {
var nn int
nn, err = c.br.Read(p[n:])
n += nn
}
if n == len(p) {
err = nil
} else if err == io.EOF {
err = errUnexpectedEOF
}
return
}
func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame.
if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err
}
}
// 2. Read and parse first two bytes of frame header.
var b [8]byte
if err := c.readFull(b[:2]); err != nil {
return noFrame, err
}
final := b[0]&finalBit != 0
frameType := int(b[0] & 0xf)
reserved := int((b[0] >> 4) & 0x7)
mask := b[1]&maskBit != 0
c.readRemaining = int64(b[1] & 0x7f)
if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}
// 3. Read and parse frame length.
switch c.readRemaining {
case 126:
if err := c.readFull(b[:2]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
case 127:
if err := c.readFull(b[:8]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
}
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
if err := c.readFull(c.readMaskKey[:]); err != nil {
return noFrame, err
}
}
// 5. For text and binary messages, enforce read limit and return.
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c.readLength += c.readRemaining
if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}
return frameType, nil
}
// 6. Read control frame payload.
var payload []byte
if c.readRemaining > 0 {
payload = make([]byte, c.readRemaining)
c.readRemaining = 0
if err := c.readFull(payload); err != nil {
return noFrame, err
}
if c.isServer {
maskBytes(c.readMaskKey, 0, payload)
}
}
// 7. Process control frame payload.
switch frameType {
case PongMessage:
if err := c.handlePong(string(payload)); err != nil {
return noFrame, err
}
case PingMessage:
if err := c.handlePing(string(payload)); err != nil {
return noFrame, err
}
case CloseMessage:
c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
closeCode := CloseNoStatusReceived
closeText := ""
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
closeText = string(payload[2:])
}
switch closeCode {
case CloseNormalClosure, CloseGoingAway:
return noFrame, io.EOF
default:
return noFrame, &closeError{code: closeCode, text: closeText}
}
}
return frameType, nil
}
func (c *Conn) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames.
if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) {
c.writeSeq++
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
}
b0 := byte(c.writeFrameType)
if final {
b0 |= finalBit
}
b1 := byte(0)
if !c.isServer {
b1 |= maskBit
}
// Assume that the frame starts at beginning of c.writeBuf.
framePos := 0
if c.isServer {
// Adjust up if mask not included in the header.
framePos = 4
}
switch {
case length >= 65536:
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 127
binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
case length > 125:
framePos += 6
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | 126
binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
default:
framePos += 8
c.writeBuf[framePos] = b0
c.writeBuf[framePos+1] = b1 | byte(length)
}
if !c.isServer {
key := newMaskKey()
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
if len(extra) > 0 {
c.writeErr = errors.New("websocket: internal error, extra used in client mode")
return c.writeErr
}
}
// Write the buffers to the connection.
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
// Setup for next frame.
c.writePos = maxFrameHeaderSize
c.writeFrameType = continuationFrame
if final {
c.writeSeq++
c.writeFrameType = noFrame
}
return c.writeErr
}
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
var r io.Reader
messageType, r, err = c.NextReader()
if err != nil {
return messageType, nil, err
}
p, err = ioutil.ReadAll(r)
return messageType, p, err
}
// SetReadDeadline sets the read deadline on the underlying network connection.
// After a read has timed out, the websocket connection state is corrupt and
// all future reads will return an error. A zero value for t means reads will
// not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit
}
// SetPingHandler sets the handler for ping messages received from the peer.
// The default ping handler sends a pong to the peer.
func (c *Conn) SetPingHandler(h func(string) error) {
if h == nil {
h = func(message string) error {
c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
return nil
}
}
c.handlePing = h
}
// SetPongHandler sets the handler for pong messages received from the peer.
// The default pong handler does nothing.
func (c *Conn) SetPongHandler(h func(string) error) {
if h == nil {
h = func(string) error { return nil }
}
c.handlePong = h
}
// UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags.
func (c *Conn) UnderlyingConn() net.Conn {
return c.conn
}
func (r messageReader) Read(b []byte) (int, error) {
if r.seq != r.c.readSeq {
return 0, io.EOF
}
for r.c.readErr == nil {
if r.c.readRemaining > 0 {
if int64(len(b)) > r.c.readRemaining {
b = b[:r.c.readRemaining]
}
n, err := r.c.br.Read(b)
r.c.readErr = hideTempErr(err)
if r.c.isServer {
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
}
r.c.readRemaining -= int64(n)
return n, r.c.readErr
}
if r.c.readFinal {
r.c.readSeq++
return 0, io.EOF
}
frameType, err := r.c.advanceFrame()
switch {
case err != nil:
r.c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage:
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}
err := r.c.readErr
if err == io.EOF && r.seq == r.c.readSeq {
err = errUnexpectedEOF
}
return 0, err
}
func (w messageWriter) err() error {
c := w.c
if c.writeSeq != w.seq {
return errWriteClosed
}
if c.writeErr != nil {
return c.writeErr
}
return nil
}
func (w messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos
if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil {
return 0, err
}
n = len(w.c.writeBuf) - w.c.writePos
}
if n > max {
n = max
}
return n, nil
}
func (w messageWriter) write(final bool, p []byte) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages.
err := w.c.flushFrame(final, p)
if err != nil {
return 0, err
}
return len(p), nil
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
p = p[n:]
}
return nn, nil
}
func (w messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}
func (w messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil {
return 0, err
}
nn := len(p)
for len(p) > 0 {
n, err := w.ncopy(len(p))
if err != nil {
return 0, err
}
copy(w.c.writeBuf[w.c.writePos:], p[:n])
w.c.writePos += n
p = p[n:]
}
return nn, nil
}
func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil {
return 0, err
}
for {
if w.c.writePos == len(w.c.writeBuf) {
err = w.c.flushFrame(false, nil)
if err != nil {
break
}
}
var n int
n, err = r.Read(w.c.writeBuf[w.c.writePos:])
w.c.writePos += n
nn += int64(n)
if err != nil {
if err == io.EOF {
err = nil
}
break
}
}
return nn, err
}
func (w messageWriter) Close() error {
if err := w.err(); err != nil {
return err
}
return w.c.flushFrame(true, nil)
}
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
func newMaskKey() [4]byte {
n := rand.Uint32()
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
}
func isControl(frameType int) bool {
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
}
func isData(frameType int) bool {
return frameType == TextMessage || frameType == BinaryMessage
}
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok && e.Temporary() {
err = &netError{msg: e.Error(), timeout: e.Timeout()}
}
return err
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(buf, uint16(closeCode))
copy(buf[2:], text)
return buf
}
package httpd
import (
"bufio"
"compress/gzip"
"encoding/json"
"net"
"net/http"
"net/url"
"strings"
"time"
)
const (
//---------
// Gzip Compersion Level
//---------
GzipBest = gzip.BestCompression
GzipBestSpeed = gzip.BestSpeed
GzipDefault = gzip.DefaultCompression
GzipNoCompression = gzip.NoCompression
)
//
// Response
//
func (r response) Read() []byte {
return r.b.Bytes()
}
func (r response) Write(bytes []byte) (int, error) {
return r.b.Write(bytes)
}
func (r response) Header() http.Header {
return r.r.Header()
}
func (r response) WriteHeader(status int) {
if r.hijacked {
return
}
r.r.WriteHeader(status)
}
func (r *response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
var netConn, rw, err = r.r.(http.Hijacker).Hijack()
if err != nil {
return nil, nil, err
}
r.hijacked = true
return netConn, rw, err
}
//
// Context
//
func (c *Context) Param(name string) (value string) {
for i := range c.params {
if c.params[i].Key == name {
return c.params[i].Value
}
}
return ""
}
func (c *Context) UrlParse(u string) (*url.URL, error) {
return url.Parse(u)
}
// Request Header Contains Token Value
func (c *Context) RequestContainsToken(name, value string) bool {
for _, v := range c.Request.Header[name] {
for _, s := range strings.Split(v, ",") {
if strings.EqualFold(value, strings.TrimSpace(s)) {
return true
}
}
}
return false
}
func (c *Context) RequestURL() *url.URL {
return c.Request.URL
}
// Return RequestURLString
func (c *Context) RequestURLString() string {
return c.Request.URL.String()
}
func (c *Context) SetCookie(name, value string, expire time.Duration) {
expiration := time.Now().Add(expire)
cookie := http.Cookie{
Name: name,
Value: value,
Expires: expiration,
}
http.SetCookie(c.Response, &cookie)
}
// JSON sends an application/json response with status code.
func (c *Context) JSON(code int, i interface{}) error {
c.Response.Header().Set(ContentType, ApplicationJSON)
c.Response.WriteHeader(code)
return json.NewEncoder(c.Response.b).Encode(i)
}
// String sends a text/plain response with status code.
func (c *Context) String(code int, s string) error {
c.Response.Header().Set(ContentType, TextPlain)
c.Response.WriteHeader(code)
_, err := c.Response.Write([]byte(s))
return err
}
// HTML sends a text/html response with status code.
func (c *Context) HTML(code int, html string) error {
c.Response.Header().Set(ContentType, TextHTML)
c.Response.WriteHeader(code)
_, err := c.Response.Write([]byte(html))
return err
}
// NoContent sends a response with no body and a status code.
func (c *Context) NoContent(code int) error {
c.Response.WriteHeader(code)
return nil
}
// Error Method
func (c *Context) Error(error string, code int) {
c.Response.Header().Set(ContentType, "text/plain; charset=utf-8")
c.Response.WriteHeader(code)
c.Response.Write([]byte(error))
}
// Redirect Method
func (c *Context) Redirect(url string, code int) {
c.headerSent = true
http.Redirect(c.Response, c.Request, url, code)
}
// Send Response
func (c *Context) Send() {
if c.headerSent || c.Response.hijacked {
return
}
c.headerSent = true
// Make sure a handler have enabled gzip compression
if c.Response.Header().Get(ContentEncoding) == "gzip" {
// Make sure the browser accept Gzip encoding
if c.RequestContainsToken(AcceptEncoding, "gzip") {
content := c.Response.Read()
// Remove previews Content-Length header
c.Response.Header().Del(ContentLength)
// If Content-Type header is not set
if c.Response.Header().Get(ContentType) == "" {
// Force detection of Content-Type
c.Response.Header().Set(ContentType, http.DetectContentType(content))
}
// Gzip Response
writer, err := gzip.NewWriterLevel(c.Response.r, c.GzipLevel)
if err != nil {
writer, _ = gzip.NewWriterLevel(c.Response.r, GzipBestSpeed)
}
writer.Write(content)
writer.Close()
return
}
// Remove encoding
c.Response.Header().Del(ContentEncoding)
}
c.Response.r.Write(c.Response.Read())
}
func (c *Context) getSubProtocols() []string {
h := strings.TrimSpace(c.Request.Header.Get(WebsocketProtocol))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func (c *Context) sameOrigin(origin string, host string) bool {
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin)
if err != nil {
return false
}
return u.Host == host
}
package httpd
import (
"bufio"
"bytes"
"errors"
"log"
"net"
"net/http"
"os"
"time"
)
const (
//-------------
// Media types
//-------------
ApplicationJSON = "application/json"
ApplicationProtobuf = "application/protobuf"
ApplicationMsgpack = "application/msgpack"
TextPlain = "text/plain"
TextHTML = "text/html"
ApplicationForm = "application/x-www-form-urlencoded"
MultipartForm = "multipart/form-data"
//---------
// Headers
//---------
Accept = "Accept"
AcceptEncoding = "Accept-Encoding"
ContentDisposition = "Content-Disposition"
ContentEncoding = "Content-Encoding"
ContentLength = "Content-Length"
ContentType = "Content-Type"
Authorization = "Authorization"
// Websocket related
Connection = "Connection"
Upgrade = "Upgrade"
Origin = "Origin"
WebsocketKey = "Sec-Websocket-Key"
WebsocketProtocol = "Sec-Websocket-Protocol"
WebsocketVersion = "Sec-Websocket-Version"
WebSocketExtensions = "Sec-WebSocket-Extensions"
//---------
// Status Codes
//---------
StatusAccepted = http.StatusAccepted
StatusBadGateway = http.StatusBadGateway
StatusBadRequest = http.StatusBadRequest
StatusConflict = http.StatusConflict
StatusContinue = http.StatusContinue
StatusCreated = http.StatusCreated
StatusExpectationFailed = http.StatusExpectationFailed
StatusForbidden = http.StatusForbidden
StatusFound = http.StatusFound
StatusGatewayTimeout = http.StatusGatewayTimeout
StatusGone = http.StatusGone
StatusHTTPVersionNotSupported = http.StatusHTTPVersionNotSupported
StatusInternalServerError = http.StatusInternalServerError
StatusLengthRequired = http.StatusLengthRequired
StatusMethodNotAllowed = http.StatusMethodNotAllowed
StatusMovedPermanently = http.StatusMovedPermanently
StatusMultipleChoices = http.StatusMultipleChoices
StatusNoContent = http.StatusNoContent
StatusNonAuthoritativeInfo = http.StatusNonAuthoritativeInfo
StatusNotAcceptable = http.StatusNotAcceptable
StatusNotFound = http.StatusNotFound
StatusNotImplemented = http.StatusNotImplemented
StatusNotModified = http.StatusNotModified
StatusOK = http.StatusOK
StatusPartialContent = http.StatusPartialContent
StatusPaymentRequired = http.StatusPaymentRequired
StatusPreconditionFailed = http.StatusPreconditionFailed
StatusProxyAuthRequired = http.StatusProxyAuthRequired
StatusRequestEntityTooLarge = http.StatusRequestEntityTooLarge
StatusRequestTimeout = http.StatusRequestTimeout
StatusRequestURITooLong = http.StatusRequestURITooLong
StatusRequestedRangeNotSatisfiable = http.StatusRequestedRangeNotSatisfiable
StatusResetContent = http.StatusResetContent
StatusSeeOther = http.StatusSeeOther
StatusServiceUnavailable = http.StatusServiceUnavailable
StatusSwitchingProtocols = http.StatusSwitchingProtocols
StatusTeapot = http.StatusTeapot
StatusTemporaryRedirect = http.StatusTemporaryRedirect
StatusUnauthorized = http.StatusUnauthorized
StatusUnsupportedMediaType = http.StatusUnsupportedMediaType
StatusUseProxy = http.StatusUseProxy
)
type (
// Handler context
Context struct {
headerSent bool
challenge string
Request *http.Request
Response response
GzipLevel int
RemoteAddr string
RequestURI string
RequestMethod string
Session interface{}
params []routeParam
}
Handle func(*Context)
HandleWs func(*Context, *Conn)
// Handler
Handler interface {
Handle(*Context)
}
HandlerFunc func(*Context)
Router struct {
prefix string
parent *Router
children []*Router
trees map[string]*route
ErrorLogger *log.Logger
MethodNotFound HandlerFunc
HandleNotAllowed bool
MethodNotAllowed HandlerFunc
PanicHandler func(Context, interface{})
beforeHandle []func(Handler) Handler
afterHandle []func(Handler) Handler
}
Upgrader struct {
HandshakeTimeout time.Duration
ReadBufferSize int
WriteBufferSize int
Subprotocols []string
CheckOrigin func(string, string) bool
}
Conn struct {
conn net.Conn
isServer bool
subprotocol string
// Write fields
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent
// Message writer fields.
writeErr error
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeSeq int // incremented to invalidate message writers.
writeDeadline time.Time
// Read fields
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames.
readSeq int // incremented to invalidate message readers.
readLength int64 // Message size.
readLimit int64 // Maximum message size.
readMaskPos int
readMaskKey [4]byte
handlePong func(string) error
handlePing func(string) error
}
// route
route struct {
path string
wildChild bool
nType nodeType
maxParams uint8
indices string
children []*route
upgrader Upgrader
handle interface{}
priority uint32
}
// Private structs
routeParam struct {
Key string
Value string
}
response struct {
hijacked bool
r http.ResponseWriter
b *bytes.Buffer
}
finalRouter struct{}
)
// Convert ServeHTTP to Handle
func (f HandlerFunc) Handle(c *Context) {
f(c)
}
//
func NewRouter() *Router {
return &Router{
prefix: "/",
ErrorLogger: log.New(os.Stderr, "", log.LstdFlags),
HandleNotAllowed: false,
}
}
func ListenAndServe(addr string, handler http.Handler) error {
return http.ListenAndServe(addr, handler)
}
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
return http.ListenAndServeTLS(addr, certFile, keyFile, handler)
}
func StatusText(code int) string {
return http.StatusText(code)
}
//
// Private functions
//
// Create new Context
func newContext(request *http.Request, responsew http.ResponseWriter) *Context {
return &Context{
Request: request,
Response: response{
hijacked: false,
r: responsew,
b: &bytes.Buffer{},
},
RemoteAddr: request.RemoteAddr,
RequestURI: request.RequestURI,
RequestMethod: request.Method,
GzipLevel: GzipBestSpeed,
params: nil,
}
}
// Create new Final Router
func newRouter() Handler {
return Handler(&finalRouter{})
}
// Final router Handle function
func (r *finalRouter) Handle(c *Context) {
c.Send()
}
func (u *Upgrader) upgrade(challenge string, subprotocol string, netConn net.Conn, rw *bufio.ReadWriter) (*Conn, error) {
var br *bufio.Reader
br = rw.Reader
if br.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
con := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
con.subprotocol = subprotocol
p := con.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challenge)...)
p = append(p, "\r\n"...)
if con.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...)
p = append(p, con.subprotocol...)
p = append(p, "\r\n"...)
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err := netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return con, nil
}
package httpd
import (
"fmt"
"net/http"
"strings"
)
// Make sure the Router conforms with the http.Handler interface
var _ http.Handler = NewRouter()
// GET is a shortcut for router.Handle("GET", path, handle)
func (r *Router) Get(path string, handle Handle) error {
return r.Add("GET", path, handle, nil)
}
// HEAD is a shortcut for router.Handle("HEAD", path, handle)
func (r *Router) Head(path string, handle Handle) error {
return r.Add("HEAD", path, handle, nil)
}
// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle)
func (r *Router) Options(path string, handle Handle) error {
return r.Add("OPTIONS", path, handle, nil)
}
// POST is a shortcut for router.Handle("POST", path, handle)
func (r *Router) Post(path string, handle Handle) error {
return r.Add("POST", path, handle, nil)
}
// PUT is a shortcut for router.Handle("PUT", path, handle)
func (r *Router) Put(path string, handle Handle) error {
return r.Add("PUT", path, handle, nil)
}
// PATCH is a shortcut for router.Handle("PATCH", path, handle)
func (r *Router) Patch(path string, handle Handle) error {
return r.Add("PATCH", path, handle, nil)
}
// DELETE is a shortcut for router.Handle("DELETE", path, handle)
func (r *Router) Delete(path string, handle Handle) error {
return r.Add("DELETE", path, handle, nil)
}
// Websocket is a shortcut for router.Handle("GET", path, handle)
func (r *Router) Websocket(path string, up Upgrader, handle HandleWs) error {
return r.Add("GET", path, handle, &up)
}
// Create new SubRouter
func (r *Router) Subrouter(path string) *Router {
sr := &Router{
parent: r,
trees: r.trees,
ErrorLogger: r.ErrorLogger,
prefix: r.subPath(path),
HandleNotAllowed: false,
}
r.children = append(r.children, sr)
return sr
}
// Handle()
func (r *Router) Add(method, path string, handle interface{}, up *Upgrader) error {
path = r.subPath(path)
if path[0] != '/' {
return fmt.Errorf("path must begin with '/' in path '" + path + "'")
}
if r.trees == nil {
r.trees = make(map[string]*route)
}
switch handle.(type) {
case Handle, HandleWs:
root := r.trees[method]
if root == nil {
root = new(route)
r.trees[method] = root
}
err := root.addRoute(path, handle, up)
if err != nil {
return r.Errorf("%s", err)
}
default:
return r.Errorf("Unknown handler !")
}
return nil
}
// Before Handler
func (r *Router) BeforeHandle(middleware ...func(Handler) Handler) {
r.beforeHandle = append(r.beforeHandle, middleware...)
}
// After Handler
func (r *Router) AfterHandle(middleware ...func(Handler) Handler) {
r.afterHandle = append(r.afterHandle, middleware...)
}
// ServeFiles - http FileServer
func (r *Router) ServeFiles(path string, root string) error {
if len(path) < 10 || path[len(path)-10:] != "/*filepath" {
r.Errorf("path must end with /*filepath in path '" + path + "'")
}
fileServer := http.FileServer(http.Dir(root))
r.Add("GET", path, Handle(
func(c *Context) {
c.Request.URL.Path = c.Param("filepath")
fileServer.ServeHTTP(c.Response, c.Request)
},
), nil)
return nil
}
// get SubPath of route
func (r *Router) subPath(p string) string {
pre := r.prefix
if (pre == "/" || pre[:len(pre)-1] == "/") && p[:1] == "/" {
pre = pre[:len(pre)-1]
}
return pre + p
}
// ServeHTTP
func (r *Router) ServeHTTP(res http.ResponseWriter, req *http.Request) {
child := r.getLastChild(req.URL.Path, r)
hndlr := newRouter()
for i := len(child.afterHandle) - 1; i >= 0; i-- {
hndlr = child.afterHandle[i](hndlr)
}
hndlr = func(next Handler) Handler {
return HandlerFunc(func(c *Context) {
child.Handle(c)
next.Handle(c)
})
}(hndlr)
for i := len(child.beforeHandle) - 1; i >= 0; i-- {
hndlr = child.beforeHandle[i](hndlr)
}
context := newContext(req, res)
hndlr.Handle(context)
}
// Handle
func (r *Router) Handle(c *Context) {
if r.PanicHandler != nil {
defer r.recv(c)
}
if root := r.trees[c.RequestMethod]; root != nil {
path := c.Request.URL.Path
if handle, ps, tsr, up := root.getValue(path); handle != nil {
c.params = ps
switch handle.(type) {
case Handle:
handle.(Handle)(c)
case HandleWs:
ws, err := r.websocket(c, up)
if err != nil {
r.Errorf("%s", err)
return
}
handle.(HandleWs)(c, ws)
}
return
} else if c.RequestMethod != "CONNECT" && path != "/" {
code := 301
if c.RequestMethod != "GET" {
code = 307
}
if tsr {
if len(path) > 1 && path[len(path)-1] == '/' {
c.Request.URL.Path = path[:len(path)-1]
} else {
c.Request.URL.Path = path + "/"
}
c.Redirect(c.RequestURLString(), code)
return
}
fixedPath, found := root.findCaseInsensitivePath(
cleanPath(path),
true,
)
if found {
c.Request.URL.Path = string(fixedPath)
c.Redirect(c.RequestURLString(), code)
return
}
}
}
// Handle 405
if r.HandleNotAllowed {
for method := range r.trees {
if method == c.RequestMethod {
continue
}
handle, _, _, _ := r.trees[method].getValue(c.Request.URL.Path)
if handle != nil {
if r.MethodNotAllowed != nil {
r.MethodNotAllowed(c)
} else {
c.Error(
StatusText(StatusMethodNotAllowed),
StatusMethodNotAllowed,
)
}
return
}
}
}
if r.MethodNotFound != nil {
r.MethodNotFound(c)
} else {
http.NotFound(c.Response, c.Request)
}
}
// Handle Websocket
func (r *Router) websocket(c *Context, up Upgrader) (*Conn, error) {
if !c.RequestContainsToken(WebsocketVersion, "13") {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: version != 13")
}
if !c.RequestContainsToken(Connection, "upgrade") {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: missing header Connection:upgrade")
}
if !c.RequestContainsToken(Upgrade, "websocket") {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: missing header: Upgrade:websocket")
}
checkOrigin := up.CheckOrigin
if checkOrigin == nil {
checkOrigin = c.sameOrigin
}
//
u, err := c.UrlParse(c.Request.Header.Get(Origin))
if err != nil {
c.Error(StatusText(StatusForbidden), StatusForbidden)
return nil, r.Errorf("Websocket: cannot parse origin")
}
if !checkOrigin(u.Host, c.RequestURL().Host) {
c.Error(StatusText(StatusForbidden), StatusForbidden)
return nil, r.Errorf("Websocket: invalid origin")
}
challenge := c.Request.Header.Get(WebsocketKey)
if challenge == "" {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: missing header: Sec-Websocket-Key")
}
// WebsocketProtocol negocitation
subprotocol := ""
rsubprotocol := c.getSubProtocols()
if len(up.Subprotocols) > 0 && len(rsubprotocol) > 0 {
Scan:
for _, sproto := range up.Subprotocols {
for _, cproto := range rsubprotocol {
if cproto == sproto {
subprotocol = sproto
break Scan
}
}
}
if subprotocol == "" {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: server does not implement %s", rsubprotocol)
}
}
net, buf, err := c.Response.Hijack()
if err != nil {
c.Error(StatusText(StatusBadRequest), StatusBadRequest)
return nil, r.Errorf("Websocket: unable to hijack connection")
}
return up.upgrade(challenge, subprotocol, net, buf)
}
func (r *Router) getLastChild(path string, last *Router) *Router {
for _, child := range last.children {
if strings.HasPrefix(path, child.prefix) {
last = child.getLastChild(path, child)
}
}
return last
}
func (r *Router) Errorf(err string, a ...interface{}) error {
r.ErrorLogger.Printf("httpd.Router: "+err, a...)
return fmt.Errorf(err, a...)
}
func (r *Router) recv(c *Context) {
if rcv := recover(); rcv != nil {
//r.PanicHandler(c, rcv)
}
}
package session
import (
"crypto/sha1"
"fmt"
"sync"
"time"
)
type (
session struct {
sessions map[string]*Session
m sync.Mutex
Expire time.Duration
}
Session struct {
sid string
token string
Data interface{}
expire *time.Time
}
)
const (
sGarbage = time.Duration(10 * time.Minute)
sExpire = time.Duration(1 * time.Hour)
)
var (
std = &session{
sessions: make(map[string]*Session),
Expire: sExpire,
}
)
func init() {
time.AfterFunc(sGarbage, std.sessiongc)
}
// Set Global Session Expire Duration
func SetExpire(expire int64) {
std.m.Lock()
std.Expire = time.Duration(expire) * time.Second
std.m.Unlock()
}
// Get Session Data
func Get(name string) *Session {
if name == "" {
return nil
}
std.m.Lock()
ses, found := std.sessions[name]
std.m.Unlock()
// New session or session expired
if !found {
return nil
} else if ses.expired() {
delete(std.sessions, name)
return nil
}
return ses
}
// Get Session Id
func (s *Session) Id() string {
return s.sid
}
// Get CSRF token
func (s *Session) Token() string {
if s.token == "" {
s.token = newRandomId()
}
return s.token
}
// Validate Token
func (s *Session) Validate(token string) bool {
if token == s.token {
s.token = newRandomId()
return true
}
return false
}
// Set Session Data
func Set(data *Session) {
std.m.Lock()
expire := time.Now().Add(std.Expire)
data.expire = &expire
std.sessions[data.sid] = data
std.m.Unlock()
}
// Create new Session
func New() *Session {
expire := time.Now().Add(std.Expire)
return &Session{
sid: newRandomId(),
expire: &expire,
}
}
// Check if session is expired
func (s *Session) expired() bool {
if s.expire == nil {
return false
}
return s.expire.Before(time.Now())
}
// Garbage Collector
func (s *session) sessiongc() {
s.m.Lock()
for _, ses := range s.sessions {
if ses.expired() {
delete(s.sessions, ses.Id())
}
}
s.m.Unlock()
time.AfterFunc(sGarbage, s.sessiongc)
}
// New Random Id
func newRandomId() string {
h := sha1.New()
c := []byte(time.Now().String())
h.Write(c)
return fmt.Sprintf("%x", h.Sum(nil))
}
package session
import (
"testing"
"time"
)
type (
sData struct {
user int
}
)
func TestRandomness(t *testing.T) {
max := 1000
sessions := make(map[int]string)
for i := 0; i < max; i++ {
sessions[i] = newRandomId()
for j := 0; j < i; j++ {
if sessions[i] == sessions[j] {
t.Errorf("Duplicate session id found ")
}
}
}
}
func TestExire(t *testing.T) {
ses := New()
ses.expire = nil
if ses.expired() == true {
t.Errorf("Expected to not be expired")
}
nexp := time.Now().Truncate(3 * time.Second)
ses.expire = &nexp
if ses.expired() == false {
t.Errorf("Expected be expired")
}
}
func TestToken(t *testing.T) {
s1 := New()
if s1.Token() == "" {
t.Errorf("No token?")
}
old := s1.Token()
// Validate token
if s1.Validate(old) == false {
t.Errorf("Invalid token 1 ?")
}
// Token refresh
if s1.Token() == old {
t.Errorf("Same token?")
}
// Revalidate old token
if s1.Validate(old) == true {
t.Errorf("Same old token?")
}
Set(s1)
s2 := Get(s1.Id())
if s2.Validate(s2.Token()) != true {
t.Errorf("Invalid token 2 ?")
}
// Read from Session again
s3 := Get(s2.Id())
if s3.Token() != s2.Token() {
t.Errorf("Storage issue Invalid token 2 ?")
}
if s2.Validate(s3.Token()) == false {
t.Errorf("Storage issue Invalid token 3 ?")
}
if s3.Token() != s2.Token() {
t.Errorf("Storage issue Invalid token 2 ?")
}
}
func TestStorage(t *testing.T) {
s1 := New()
SetExpire(1)
if Get("") != nil {
t.Errorf("Storage for empy session ?")
}
if Get(s1.Id()) != nil {
t.Errorf("Allready in storage ?")
}
Set(s1)
if Get(s1.Id()) == nil {
t.Errorf("Missing from storage ?")
}
s1.Data = &sData{
user: 100,
}
Set(s1)
data := Get(s1.Id()).Data.(*sData)
if data.user != 100 {
t.Errorf("Invalid data")
}
// Force Expire
nexp := time.Now().Truncate(3 * time.Second)
std.sessions[s1.Id()].expire = &nexp
if Get(s1.Id()) != nil {
t.Errorf("Session not expired ?")
}
s2 := New()
Set(s2)
s3 := New()
Set(s3)
s4 := New()
Set(s4)
// Set Expire
s3exp := time.Now().Truncate(3 * time.Second)
std.sessions[s3.Id()].expire = &s3exp
std.sessiongc()
//time.Sleep(10 * time.Second)
}
package httpd
import (
"fmt"
"strings"
"unicode"
)
func min(a, b int) int {
if a <= b {
return a
}
return b
}
func countParams(path string) uint8 {
var n uint
for i := 0; i < len(path); i++ {
if path[i] != ':' && path[i] != '*' {
continue
}
n++
}
if n >= 255 {
return 255
}
return uint8(n)
}
type nodeType uint8
const (
static nodeType = 0
param nodeType = 1
catchAll nodeType = 2
)
func (n *route) incrementChildPrio(pos int) int {
n.children[pos].priority++
prio := n.children[pos].priority
// adjust position (move to front)
newPos := pos
for newPos > 0 && n.children[newPos-1].priority < prio {
// swap node positions
tmpN := n.children[newPos-1]
n.children[newPos-1] = n.children[newPos]
n.children[newPos] = tmpN
newPos--
}
// build new index char string
if newPos != pos {
n.indices = n.indices[:newPos] + // unchanged prefix, might be empty
n.indices[pos:pos+1] + // the index char we move
n.indices[newPos:pos] + n.indices[pos+1:] // rest without char at 'pos'
}
return newPos
}
func (n *route) addRoute(path string, handle interface{}, up *Upgrader) error {
fullPath := path
n.priority++
numParams := countParams(path)
// non-empty tree
if len(n.path) > 0 || len(n.children) > 0 {
walk:
for {
// Update maxParams of the current node
if numParams > n.maxParams {
n.maxParams = numParams
}
// Find the longest common prefix.
// This also implies that the common prefix contains no ':' or '*'
// since the existing key can't contain those chars.
i := 0
max := min(len(path), len(n.path))
for i < max && path[i] == n.path[i] {
i++
}
// Split edge
if i < len(n.path) {
child := route{
path: n.path[i:],
wildChild: n.wildChild,
indices: n.indices,
children: n.children,
handle: n.handle,
priority: n.priority - 1,
}
// Update maxParams (max of all children)
for i := range child.children {
if child.children[i].maxParams > child.maxParams {
child.maxParams = child.children[i].maxParams
}
}
n.children = []*route{&child}
// []byte for proper unicode char conversion, see #65
n.indices = string([]byte{n.path[i]})
n.path = path[:i]
n.handle = nil
n.wildChild = false
}
// Make new node a child of this node
if i < len(path) {
path = path[i:]
if n.wildChild {
n = n.children[0]
n.priority++
// Update maxParams of the child node
if numParams > n.maxParams {
n.maxParams = numParams
}
numParams--
// Check if the wildcard matches
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
// check for longer wildcard, e.g. :name and :names
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
continue walk
}
}
return fmt.Errorf(
"path segment '%s' conflicts with existing wildcard '%s' in path '%s'",
path,
n.path,
fullPath,
)
}
c := path[0]
// slash after param
if n.nType == param && c == '/' && len(n.children) == 1 {
n = n.children[0]
n.priority++
continue walk
}
// Check if a child with the next path byte exists
for i := 0; i < len(n.indices); i++ {
if c == n.indices[i] {
i = n.incrementChildPrio(i)
n = n.children[i]
continue walk
}
}
// Otherwise insert it
if c != ':' && c != '*' {
// []byte for proper unicode char conversion, see #65
n.indices += string([]byte{c})
child := &route{
maxParams: numParams,
}
n.children = append(n.children, child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
}
n.insertChild(numParams, path, fullPath, handle, up)
return nil
} else if i == len(path) { // Make node a (in-path) leaf
if n.handle != nil {
return fmt.Errorf("a handle is already registered for path ''" + fullPath + "'")
}
n.handle = handle
}
return nil
}
} else { // Empty tree
return n.insertChild(numParams, path, fullPath, handle, up)
}
}
func (n *route) insertChild(numParams uint8, path, fullPath string, handle interface{}, up *Upgrader) error {
var offset int // already handled bytes of the path
// find prefix until first wildcard (beginning with ':'' or '*'')
for i, max := 0, len(path); numParams > 0; i++ {
c := path[i]
if c != ':' && c != '*' {
continue
}
// find wildcard end (either '/' or path end)
end := i + 1
for end < max && path[end] != '/' {
switch path[end] {
// the wildcard name must not contain ':' and '*'
case ':', '*':
fmt.Errorf("only one wildcard per path segment is allowed, has: '" +
path[i:] + "' in path '" + fullPath + "'")
default:
end++
}
}
// check if this Node existing children which would be
// unreachable if we insert the wildcard here
if len(n.children) > 0 {
return fmt.Errorf("wildcard route '" + path[i:end] +
"' conflicts with existing children in path '" + fullPath + "'")
}
// check if the wildcard has a name
if end-i < 2 {
return fmt.Errorf("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
if c == ':' { // param
// split path at the beginning of the wildcard
if i > 0 {
n.path = path[offset:i]
offset = i
}
child := &route{
nType: param,
maxParams: numParams,
}
n.children = []*route{child}
n.wildChild = true
n = child
n.priority++
numParams--
// if the path doesn't end with the wildcard, then there
// will be another non-wildcard subpath starting with '/'
if end < max {
n.path = path[offset:end]
offset = end
child := &route{
maxParams: numParams,
priority: 1,
}
n.children = []*route{child}
n = child
}
} else { // catchAll
if end != max || numParams > 1 {
return fmt.Errorf("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
return fmt.Errorf("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
return fmt.Errorf("no / before catch-all in path '" + fullPath + "'")
}
n.path = path[offset:i]
// first node: catchAll node with empty path
child := &route{
wildChild: true,
nType: catchAll,
maxParams: 1,
}
n.children = []*route{child}
n.indices = string(path[i])
n = child
n.priority++
// second node: node holding the variable
child = &route{
path: path[i:],
nType: catchAll,
maxParams: 1,
handle: handle,
priority: 1,
}
n.children = []*route{child}
return nil
}
}
// insert remaining path part and handle to the leaf
n.path = path[offset:]
if up != nil {
n.upgrader = *up
}
n.handle = handle
return nil
}
// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *route) getValue(path string) (handle interface{}, p []routeParam, tsr bool, up Upgrader) {
walk: // Outer loop for walking the tree
for {
if len(path) > len(n.path) {
if path[:len(n.path)] == n.path {
path = path[len(n.path):]
// If this node does not have a wildcard (param or catchAll)
// child, we can just look up the next child node and continue
// to walk down the tree
if !n.wildChild {
c := path[0]
for i := 0; i < len(n.indices); i++ {
if c == n.indices[i] {
n = n.children[i]
continue walk
}
}
// Nothing found.
// We can recommend to redirect to the same URL without a
// trailing slash if a leaf exists for that path.
tsr = (path == "/" && n.handle != nil)
return
}
// handle wildcard child
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// save param value
if p == nil {
// lazy allocation
p = make([]routeParam, 0, n.maxParams)
}
i := len(p)
p = p[:i+1] // expand slice within preallocated capacity
p[i].Key = n.path[1:]
p[i].Value = path[:end]
// we need to go deeper!
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
tsr = (len(path) == end+1)
return
}
if handle = n.handle; handle != nil {
return
} else if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists for TSR recommendation
n = n.children[0]
tsr = (n.path == "/" && n.handle != nil)
}
return
case catchAll:
// save param value
if p == nil {
// lazy allocation
p = make([]routeParam, 0, n.maxParams)
}
i := len(p)
p = p[:i+1] // expand slice within preallocated capacity
p[i].Key = n.path[2:]
p[i].Value = path
handle = n.handle
return
default:
panic("invalid node type")
}
}
} else if path == n.path {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if handle = n.handle; handle != nil {
up = n.upgrader
return
}
// No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation
for i := 0; i < len(n.indices); i++ {
if n.indices[i] == '/' {
n = n.children[i]
tsr = (len(n.path) == 1 && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil)
return
}
}
return
}
// Nothing found. We can recommend to redirect to the same URL with an
// extra trailing slash if a leaf exists for that path
tsr = (path == "/") ||
(len(n.path) == len(path)+1 && n.path[len(path)] == '/' &&
path == n.path[:len(n.path)-1] && n.handle != nil)
return
}
}
// Makes a case-insensitive lookup of the given path and tries to find a handler.
// It can optionally also fix trailing slashes.
// It returns the case-corrected path and a bool indicating whether the lookup
// was successful.
func (n *route) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) {
ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory
// Outer loop for walking the tree
for len(path) >= len(n.path) && strings.ToLower(path[:len(n.path)]) == strings.ToLower(n.path) {
path = path[len(n.path):]
ciPath = append(ciPath, n.path...)
if len(path) > 0 {
// If this node does not have a wildcard (param or catchAll) child,
// we can just look up the next child node and continue to walk down
// the tree
if !n.wildChild {
r := unicode.ToLower(rune(path[0]))
for i, index := range n.indices {
// must use recursive approach since both index and
// ToLower(index) could exist. We must check both.
if r == unicode.ToLower(index) {
out, found := n.children[i].findCaseInsensitivePath(path, fixTrailingSlash)
if found {
return append(ciPath, out...), true
}
}
}
// Nothing found. We can recommend to redirect to the same URL
// without a trailing slash if a leaf exists for that path
found = (fixTrailingSlash && path == "/" && n.handle != nil)
return
}
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
k := 0
for k < len(path) && path[k] != '/' {
k++
}
// add param value to case insensitive path
ciPath = append(ciPath, path[:k]...)
// we need to go deeper!
if k < len(path) {
if len(n.children) > 0 {
path = path[k:]
n = n.children[0]
continue
}
// ... but we can't
if fixTrailingSlash && len(path) == k+1 {
return ciPath, true
}
return
}
if n.handle != nil {
return ciPath, true
} else if fixTrailingSlash && len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists
n = n.children[0]
if n.path == "/" && n.handle != nil {
return append(ciPath, '/'), true
}
}
return
case catchAll:
return append(ciPath, path...), true
default:
panic("invalid node type")
}
} else {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if n.handle != nil {
return ciPath, true
}
// No handle found.
// Try to fix the path by adding a trailing slash
if fixTrailingSlash {
for i := 0; i < len(n.indices); i++ {
if n.indices[i] == '/' {
n = n.children[i]
if (len(n.path) == 1 && n.handle != nil) ||
(n.nType == catchAll && n.children[0].handle != nil) {
return append(ciPath, '/'), true
}
return
}
}
}
return
}
}
// Nothing found.
// Try to fix the path by adding / removing a trailing slash
if fixTrailingSlash {
if path == "/" {
return ciPath, true
}
if len(path)+1 == len(n.path) && n.path[len(path)] == '/' &&
strings.ToLower(path) == strings.ToLower(n.path[:len(path)]) &&
n.handle != nil {
return append(ciPath, n.path...), true
}
}
return
}
package httpd
import (
"crypto/sha1"
"encoding/base64"
"path"
)
func cleanPath(p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
np += "/"
}
return np
}
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
package daemon
import (
"fmt"
"os"
"os/signal"
"path/filepath"
"syscall"
)
type (
daemon struct {
name string // Proccess name
uid uint32 // Switch user to
gid uint32 // Switch group to
chroot string // Chroot directory to
pid int // Current proccess pid
pidfile string // Pid file location
args []string // os.Args
env []string // Environment vars
signal chan os.Signal
// Signal Handlers
hupHandler SignalHandler // SIGHUP handler
intHandler SignalHandler // SIGINT handler
termHandler SignalHandler // SIGTERM handler
quitHandler SignalHandler // SIGQUIT handler
usr1Handler SignalHandler // SIGUSR1 handler
usr2Handler SignalHandler // SIGUSR2 handler
}
Config struct {
Uid uint32 // Child uid
Gid uint32 // Child gid
Chroot string // Child Chroot
Pidfile string // Child pidfie
HupHandler SignalHandler // SIGHUP handler
IntHandler SignalHandler // SIGINT handler
TermHandler SignalHandler // SIGTERM handler
QuitHandler SignalHandler // SIGQUIT handler
Usr1Handler SignalHandler // SIGUSR1 handler
Usr2Handler SignalHandler // SIGUSR2 handler
}
// SignalHandler must return true to cancel default signal behavior
SignalHandler func(int) bool
)
var std = new()
const (
self_proc = "/proc/self/exe"
childenv = "daemon-child-proccess"
SIGHUP = 1
SIGINT = 2
SIGQUIT = 3
SIGUSR1 = 10
SIGUSR2 = 12
SIGTERM = 15
)
//
// Public Interface
//
// Configure daemon with Settings
func Setup(s *Config) {
// Setup drop privs
std.uid = s.Uid
std.gid = s.Gid
std.chroot = s.Chroot
// Setup pidfile
std.pidfile = s.Pidfile
// Setup Handlers
std.hupHandler = s.HupHandler
std.intHandler = s.IntHandler
std.termHandler = s.TermHandler
std.quitHandler = s.QuitHandler
std.usr1Handler = s.Usr1Handler
std.usr2Handler = s.Usr2Handler
}
// Run process in background
func Background() error {
_, err := std.daemon()
if os.Getenv(childenv) == "" {
if err != nil {
return err
}
// Parrent exit
std.Exit(0)
}
return nil
}
// Set Enviroment value
func SetEnv(name, value string) {
std.setEnv(name, value)
}
// Get Enviroment value
func GetEnv(name string) string {
return std.getEnv(name)
}
// Add Argv to new Proccess
func SetArg(arg string) {
std.args = append(std.args, arg)
}
// Get current Args
func GetArgs() []string {
return os.Args
}
// Get current process pid
func GetPid() int {
return std.pid
}
// Is daemon chrooted ?
func IsChroot() bool {
return std.chroot != ""
}
// Create new Daemon
func new() *daemon {
procName := os.Getenv(childenv)
std := &daemon{
name: procName,
pid: os.Getpid(),
args: os.Args[:1],
env: os.Environ(),
signal: make(chan os.Signal),
}
go func(d *daemon) {
// Add Signal Notifiers
signal.Notify(d.signal, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
signal.Notify(d.signal, syscall.SIGHUP, syscall.SIGUSR1, syscall.SIGUSR2)
for {
select {
case recv := <-d.signal:
switch recv {
case syscall.SIGINT:
if d.intHandler != nil {
if d.intHandler(SIGINT) {
continue
}
}
// Default behavior
d.Exit(2)
case syscall.SIGTERM:
if d.termHandler != nil {
if d.termHandler(SIGTERM) {
continue
}
}
// Default behavior
d.Exit(0)
case syscall.SIGHUP:
if d.hupHandler != nil {
if d.hupHandler(SIGHUP) {
continue
}
}
// Default behavior
d.Exit(2)
case syscall.SIGUSR1:
if d.usr1Handler != nil {
d.usr1Handler(SIGUSR1)
}
case syscall.SIGUSR2:
if d.usr2Handler != nil {
d.usr2Handler(SIGUSR2)
}
case syscall.SIGQUIT:
if d.quitHandler != nil {
if d.quitHandler(SIGQUIT) {
continue
}
}
// Default behavior
d.Exit(2)
}
}
}
}(std)
return std
}
// Set Enviroment value
func (d *daemon) setEnv(name, value string) {
if os.Getenv(childenv) == "" {
d.env = append(d.env, fmt.Sprintf("%s=%s", name, value))
} else if name != childenv {
os.Setenv(name, value)
}
}
// Return Enviroment value
func (d *daemon) getEnv(name string) string {
return os.Getenv(name)
}
// New daemon
func (d *daemon) daemon() (int, error) {
// Code executed by parrent
if os.Getenv(childenv) == "" {
// Get process name
self, err := os.Readlink(self_proc)
if err != nil {
return 0, err
}
d.name = filepath.Base(self)
// Setup child env
d.setEnv(childenv, d.name)
// Process Attributes
sysProcAttr := &syscall.SysProcAttr{
Chroot: d.chroot,
Setsid: true,
}
// Switch User / Group
if d.uid > 0 || d.gid > 0 {
sysProcAttr.Credential = &syscall.Credential{
Uid: d.uid,
Gid: d.gid,
}
}
// Start Proccess
proc, err := os.StartProcess(
d.name,
d.args,
&os.ProcAttr{
Dir: filepath.Dir(self),
Env: d.env,
Files: []*os.File{
nil,
os.Stdout, // Read Stdout from child (debug)
os.Stderr, // Read Stderr from child (debug)
},
Sys: sysProcAttr,
},
)
if err != nil {
return 0, err
}
return proc.Pid, nil
}
// Code executed by child
if d.pidfile != "" {
pf, err := os.OpenFile(d.pidfile, os.O_CREATE|os.O_RDWR, os.FileMode(0640))
if err != nil {
return d.pid, err
}
defer pf.Close()
_, err = pf.Write([]byte(fmt.Sprintf("%d", d.pid)))
}
return d.pid, nil
}
func (d *daemon) Exit(code int) {
// Make sure we are child process
if os.Getenv(childenv) != "" && d.pidfile != "" {
_, err := os.Stat(d.pidfile)
if err == nil {
os.Remove(d.pidfile)
}
}
os.Exit(code)
}
package daemon
import (
"os"
)
func c() {
_ = os.Stderr
}
package logger
import (
"bytes"
"fmt"
"io"
"os"
"runtime"
"time"
)
func NewEntry(depth int, logger *Logger) *Entry {
return &Entry{
Logger: logger,
Level: logger.Level,
depth: depth,
}
}
func (entry *Entry) log(level Level, msg string) {
entry.Time = time.Now()
entry.Level = level
entry.Message = msg
//entry.File =
_, file, line, ok := runtime.Caller(entry.depth)
if !ok {
file = "???"
line = 0
}
short := file
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
short = file[i+1:]
break
}
}
entry.File = short
entry.Line = line
reader, err := entry.Reader()
if err != nil {
entry.Logger.mu.Lock()
fmt.Fprintf(os.Stderr, "Failed to obtain reader, %v\n", err)
entry.Logger.mu.Unlock()
}
entry.Logger.mu.Lock()
defer entry.Logger.mu.Unlock()
_, err = io.Copy(entry.Logger.Out, reader)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", err)
}
// To avoid Entry#log() returning a value that only would make sense for
// panic() to use in Entry#Panic(), we avoid the allocation by checking
// directly here.
if level <= PanicLevel {
panic(entry)
}
}
// Returns a reader for the entry, which is a proxy to the formatter.
func (entry *Entry) Reader() (*bytes.Buffer, error) {
serialized, err := entry.Logger.Formatter.Format(entry)
return bytes.NewBuffer(serialized), err
}
func (entry *Entry) Debug(args ...interface{}) {
if entry.Level >= DebugLevel {
entry.log(DebugLevel, fmt.Sprint(args...))
}
}
func (entry *Entry) Print(args ...interface{}) {
entry.Info(args...)
}
func (entry *Entry) Info(args ...interface{}) {
if entry.Level >= InfoLevel {
entry.log(InfoLevel, fmt.Sprint(args...))
}
}
func (entry *Entry) Warn(args ...interface{}) {
if entry.Level >= WarnLevel {
entry.log(WarnLevel, fmt.Sprint(args...))
}
}
func (entry *Entry) Warning(args ...interface{}) {
entry.Warn(args...)
}
func (entry *Entry) Error(args ...interface{}) {
if entry.Level >= ErrorLevel {
entry.log(ErrorLevel, fmt.Sprint(args...))
}
}
func (entry *Entry) Fatal(args ...interface{}) {
if entry.Level >= FatalLevel {
entry.log(FatalLevel, fmt.Sprint(args...))
}
os.Exit(1)
}
func (entry *Entry) Panic(args ...interface{}) {
if entry.Level >= PanicLevel {
entry.log(PanicLevel, fmt.Sprint(args...))
}
panic(fmt.Sprint(args...))
}
// Entry Printf family functions
func (entry *Entry) Debugf(format string, args ...interface{}) {
if entry.Level >= DebugLevel {
entry.Debug(fmt.Sprintf(format, args...))
}
}
func (entry *Entry) Infof(format string, args ...interface{}) {
if entry.Level >= InfoLevel {
entry.Info(fmt.Sprintf(format, args...))
}
}
func (entry *Entry) Printf(format string, args ...interface{}) {
entry.Infof(format, args...)
}
func (entry *Entry) Warnf(format string, args ...interface{}) {
if entry.Level >= WarnLevel {
entry.Warn(fmt.Sprintf(format, args...))
}
}
func (entry *Entry) Warningf(format string, args ...interface{}) {
entry.Warnf(format, args...)
}
func (entry *Entry) Errorf(format string, args ...interface{}) {
if entry.Level >= ErrorLevel {
entry.Error(fmt.Sprintf(format, args...))
}
}
func (entry *Entry) Fatalf(format string, args ...interface{}) {
if entry.Level >= FatalLevel {
entry.Fatal(fmt.Sprintf(format, args...))
}
os.Exit(1)
}
func (entry *Entry) Panicf(format string, args ...interface{}) {
if entry.Level >= PanicLevel {
entry.Panic(fmt.Sprintf(format, args...))
}
}
package logger
import (
"bytes"
"encoding/json"
"fmt"
)
func (f *LogFormatter) Format(entry *Entry) ([]byte, error) {
timestampFormat := f.TimestampFormat
if timestampFormat == "" {
timestampFormat = DefaultTimestampFormat
}
b := &bytes.Buffer{}
fmt.Fprintf(b, "%s %s:%d [%s] \"%s\"",
entry.Time.Format(timestampFormat),
entry.File,
entry.Line,
entry.Level.String(),
entry.Message,
)
b.WriteByte('\n')
return b.Bytes(), nil
}
func (f *JSONFormatter) Format(entry *Entry) ([]byte, error) {
data := make(Fields, 3)
timestampFormat := f.TimestampFormat
if timestampFormat == "" {
timestampFormat = DefaultTimestampFormat
}
data["time"] = entry.Time.Format(timestampFormat)
data["msg"] = entry.Message
data["level"] = entry.Level.String()
serialized, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("Failed to marshal fields to JSON, %v", err)
}
return append(serialized, '\n'), nil
}
package logger
import (
"io"
"os"
"sync"
"time"
)
var (
std = New()
)
type (
Level uint8
Logger struct {
Out io.Writer
Level Level
Formatter Formatter
mu sync.Mutex
}
Fields map[string]interface{}
Entry struct {
Logger *Logger
Time time.Time
Level Level
File string
Line int
depth int
Message string
}
Formatter interface {
Format(*Entry) ([]byte, error)
}
LogFormatter struct {
TimestampFormat string
}
JSONFormatter struct {
TimestampFormat string
}
)
const (
PanicLevel Level = iota
FatalLevel
ErrorLevel
WarnLevel
InfoLevel
DebugLevel
// Default Format
DefaultTimestampFormat = time.Stamp
)
// Convert the Level to a string. E.g. PanicLevel becomes "panic".
func (level Level) String() string {
switch level {
case DebugLevel:
return "debug"
case InfoLevel:
return "info"
case WarnLevel:
return "warning"
case ErrorLevel:
return "error"
case FatalLevel:
return "fatal"
case PanicLevel:
return "panic"
}
return "unknown"
}
func New() *Logger {
return &Logger{
Out: os.Stderr,
Formatter: new(LogFormatter),
Level: InfoLevel,
}
}
// SetOutput sets the standard logger output.
func SetOutput(out io.Writer) {
std.SetOutput(out)
}
func OpenLog(logfile string) (*os.File, error) {
return OpenLogm(logfile, 0640)
}
func OpenLogm(logfile string, mode uint32) (*os.File, error) {
lf, err := os.OpenFile(
logfile,
os.O_CREATE|os.O_APPEND|os.O_RDWR,
os.FileMode(mode),
)
if err != nil {
return nil, err
}
return lf, nil
}
// SetFormatter sets the standard logger formatter.
func SetFormatter(formatter Formatter) {
std.SetFormatter(formatter)
}
// SetLevel sets the standard logger level.
func SetLevel(level Level) {
std.SetLevel(level)
}
// GetLevel returns the standard logger level.
func GetLevel() Level {
return std.GetLevel()
}
// Debug logs a message at level Debug on the standard logger.
func Debug(args ...interface{}) {
std.Debug(args...)
}
// Print logs a message at level Info on the standard logger.
func Print(args ...interface{}) {
std.Print(args...)
}
// Info logs a message at level Info on the standard logger.
func Info(args ...interface{}) {
std.Info(args...)
}
// Warn logs a message at level Warn on the standard logger.
func Warn(args ...interface{}) {
std.Warn(args...)
}
// Warning logs a message at level Warn on the standard logger.
func Warning(args ...interface{}) {
std.Warning(args...)
}
// Error logs a message at level Error on the standard logger.
func Error(args ...interface{}) {
std.Error(args...)
}
// Panic logs a message at level Panic on the standard logger.
func Panic(args ...interface{}) {
std.Panic(args...)
}
// Fatal logs a message at level Fatal on the standard logger.
func Fatal(args ...interface{}) {
std.Fatal(args...)
}
// Debugf logs a message at level Debug on the standard logger.
func Debugf(format string, args ...interface{}) {
std.Debugf(format, args...)
}
// Printf logs a message at level Info on the standard logger.
func Printf(format string, args ...interface{}) {
std.Printf(format, args...)
}
// Infof logs a message at level Info on the standard logger.
func Infof(format string, args ...interface{}) {
std.Infof(format, args...)
}
// Warnf logs a message at level Warn on the standard logger.
func Warnf(format string, args ...interface{}) {
std.Warnf(format, args...)
}
// Warningf logs a message at level Warn on the standard logger.
func Warningf(format string, args ...interface{}) {
std.Warningf(format, args...)
}
// Errorf logs a message at level Error on the standard logger.
func Errorf(format string, args ...interface{}) {
std.Errorf(format, args...)
}
// Panicf logs a message at level Panic on the standard logger.
func Panicf(format string, args ...interface{}) {
std.Panicf(format, args...)
}
// Fatalf logs a message at level Fatal on the standard logger.
func Fatalf(format string, args ...interface{}) {
std.Fatalf(format, args...)
}
func (logger *Logger) SetOutput(out io.Writer) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Out = out
}
func (logger *Logger) SetFormatter(formatter Formatter) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Formatter = formatter
}
func (logger *Logger) SetLevel(level Level) {
logger.mu.Lock()
defer logger.mu.Unlock()
logger.Level = level
}
func (logger *Logger) GetLevel() Level {
logger.mu.Lock()
defer logger.mu.Unlock()
return logger.Level
}
func (logger *Logger) Debugf(format string, args ...interface{}) {
if logger.Level >= DebugLevel {
NewEntry(5, logger).Debugf(format, args...)
}
}
func (logger *Logger) Infof(format string, args ...interface{}) {
if logger.Level >= InfoLevel {
NewEntry(5, logger).Infof(format, args...)
}
}
func (logger *Logger) Printf(format string, args ...interface{}) {
NewEntry(6, logger).Printf(format, args...)
}
func (logger *Logger) Warnf(format string, args ...interface{}) {
if logger.Level >= WarnLevel {
NewEntry(5, logger).Warnf(format, args...)
}
}
func (logger *Logger) Warningf(format string, args ...interface{}) {
if logger.Level >= WarnLevel {
NewEntry(5, logger).Warnf(format, args...)
}
}
func (logger *Logger) Errorf(format string, args ...interface{}) {
if logger.Level >= ErrorLevel {
NewEntry(5, logger).Errorf(format, args...)
}
}
func (logger *Logger) Fatalf(format string, args ...interface{}) {
if logger.Level >= FatalLevel {
NewEntry(5, logger).Fatalf(format, args...)
}
os.Exit(1)
}
func (logger *Logger) Panicf(format string, args ...interface{}) {
if logger.Level >= PanicLevel {
NewEntry(4, logger).Panicf(format, args...)
}
}
func (logger *Logger) Debug(args ...interface{}) {
if logger.Level >= DebugLevel {
NewEntry(4, logger).Debug(args...)
}
}
func (logger *Logger) Info(args ...interface{}) {
if logger.Level >= InfoLevel {
NewEntry(4, logger).Info(args...)
}
}
func (logger *Logger) Print(args ...interface{}) {
NewEntry(4, logger).Info(args...)
}
func (logger *Logger) Warn(args ...interface{}) {
if logger.Level >= WarnLevel {
NewEntry(4, logger).Warn(args...)
}
}
func (logger *Logger) Warning(args ...interface{}) {
if logger.Level >= WarnLevel {
NewEntry(4, logger).Warn(args...)
}
}
func (logger *Logger) Error(args ...interface{}) {
if logger.Level >= ErrorLevel {
NewEntry(4, logger).Error(args...)
}
}
func (logger *Logger) Fatal(args ...interface{}) {
if logger.Level >= FatalLevel {
NewEntry(4, logger).Fatal(args...)
}
os.Exit(1)
}
func (logger *Logger) Panic(args ...interface{}) {
if logger.Level >= PanicLevel {
NewEntry(4, logger).Panic(args...)
}
}
package cfg
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"
)
// ConfigFile is the representation of configuration settings.
// The public interface is entirely through methods.
type ConfigFile struct {
data map[string]map[string]string // Maps sections to options to values.
}
var (
DefaultSection = "default" // Default section name (must be lower-case).
// Maximum allowed depth when recursively substituing variable names.
DepthValues = 200
// Strings accepted as bool.
BoolStrings = map[string]bool{
"0": false,
"1": true,
"f": false,
"false": false,
"n": false,
"no": false,
"off": false,
"on": true,
"t": true,
"true": true,
"y": true,
"yes": true,
}
varRegExp = regexp.MustCompile(`%\(([a-zA-Z0-9_.\-]+)\)s`)
)
// AddSection adds a new section to the configuration.
// It returns true if the new section was inserted, and false if the section
// already existed.
func (c *ConfigFile) AddSection(section string) bool {
section = strings.ToLower(section)
if _, ok := c.data[section]; ok {
return false
}
c.data[section] = make(map[string]string)
return true
}
// RemoveSection removes a section from the configuration.
// It returns true if the section was removed, and false if section did not
// exist.
func (c *ConfigFile) RemoveSection(section string) bool {
section = strings.ToLower(section)
switch _, ok := c.data[section]; {
case !ok:
return false
case section == DefaultSection:
return false // default section cannot be removed
default:
for o, _ := range c.data[section] {
delete(c.data[section], o)
}
delete(c.data, section)
}
return true
}
// AddOption adds a new option and value to the configuration.
// It returns true if the option and value were inserted, and false if the
// value was overwritten.
// If the section does not exist in advance, it is created.
func (c *ConfigFile) AddOption(section string, option string, value string) bool {
c.AddSection(section) // make sure section exists
section = strings.ToLower(section)
option = strings.ToLower(option)
_, ok := c.data[section][option]
c.data[section][option] = value
return !ok
}
// RemoveOption removes a option and value from the configuration.
// It returns true if the option and value were removed, and false otherwise,
// including if the section did not exist.
func (c *ConfigFile) RemoveOption(section string, option string) bool {
section = strings.ToLower(section)
option = strings.ToLower(option)
if _, ok := c.data[section]; !ok {
return false
}
_, ok := c.data[section][option]
delete(c.data[section], option)
return ok
}
// NewConfigFile creates an empty configuration representation.
// This representation can be filled with AddSection and AddOption and then
// saved to a file using WriteConfigFile.
func NewConfig() *ConfigFile {
c := new(ConfigFile)
c.data = make(map[string]map[string]string)
c.AddSection(DefaultSection) // default section always exists
return c
}
func stripComments(l string) string {
// comments are preceded by space or TAB
for _, c := range []string{" ;", "\t;", " #", "\t#"} {
if i := strings.Index(l, c); i != -1 {
l = l[0:i]
}
}
return l
}
func firstIndex(s string, delim []byte) int {
for i := 0; i < len(s); i++ {
for j := 0; j < len(delim); j++ {
if s[i] == delim[j] {
return i
}
}
}
return -1
}
func (c *ConfigFile) read(buf *bufio.Reader) error {
var section, option string
for {
l, err := buf.ReadString('\n') // parse line-by-line
if err == io.EOF {
if len(l) == 0 {
break
}
} else if err != nil {
return err
}
l = strings.TrimSpace(l)
// switch written for readability (not performance)
switch {
case len(l) == 0: // empty line
continue
case l[0] == '#': // comment
continue
case l[0] == ';': // comment
continue
case len(l) >= 3 && strings.ToLower(l[0:3]) == "rem":
// comment (for windows users)
continue
case l[0] == '[' && l[len(l)-1] == ']': // new section
option = "" // reset multi-line value
section = strings.TrimSpace(l[1 : len(l)-1])
c.AddSection(section)
case section == "": // not new section and no section defined so far
return errors.New("Section not found: must start with section")
default: // other alternatives
i := firstIndex(l, []byte{'=', ':'})
switch {
case i > 0: // option and value
i := firstIndex(l, []byte{'=', ':'})
option = strings.TrimSpace(l[0:i])
value := strings.TrimSpace(stripComments(l[i+1:]))
c.AddOption(section, option, value)
case section != "" && option != "":
// continuation of multi-line value
prev, _ := c.GetRawString(section, option)
value := strings.TrimSpace(stripComments(l))
c.AddOption(section, option, prev+"\n"+value)
default:
return errors.New(fmt.Sprintf("Could not parse line: %s", l))
}
}
}
return nil
}
// ReadConfigFile reads a file and returns a new configuration representation.
// This representation can be queried with GetString, etc.
func ReadFile(fname string) (*ConfigFile, error) {
file, err := os.Open(fname)
if err != nil {
return nil, err
}
c := NewConfig()
if err := c.read(bufio.NewReader(file)); err != nil {
return nil, err
}
if err := file.Close(); err != nil {
return nil, err
}
return c, nil
}
func (c *ConfigFile) write(buf *bufio.Writer, header string) error {
if header != "" {
_, err := buf.WriteString(fmt.Sprintf("# %s\n", header))
if err != nil {
return err
}
}
for section, sectionmap := range c.data {
if section == DefaultSection && len(sectionmap) == 0 {
continue // skip default section if empty
}
_, err := buf.WriteString(fmt.Sprintf("[%s]\n", section))
if err != nil {
return err
}
for option, value := range sectionmap {
_, err := buf.WriteString(fmt.Sprintf("%s = %s\n", option, value))
if err != nil {
return err
}
}
if _, err := buf.WriteString("\n"); err != nil {
return err
}
}
return nil
}
// WriteConfigFile saves the configuration representation to a file.
// The desired file permissions must be passed as in os.Open.
// The header is a string that is saved as a comment in the first line of the file.
func (c *ConfigFile) WriteFile(fname string, perm uint32, header string) error {
var file *os.File
file, err := os.OpenFile(fname, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(perm))
if err != nil {
return err
}
buf := bufio.NewWriter(file)
if err := c.write(buf, header); err != nil {
return err
}
buf.Flush()
return file.Close()
}
// GetSections returns the list of sections in the configuration.
// (The default section always exists.)
func (c *ConfigFile) GetSections() (sections []string) {
sections = make([]string, len(c.data))
i := 0
for s, _ := range c.data {
sections[i] = s
i++
}
return sections
}
// HasSection checks if the configuration has the given section.
// (The default section always exists.)
func (c *ConfigFile) HasSection(section string) bool {
_, ok := c.data[strings.ToLower(section)]
return ok
}
// GetOptions returns the list of options available in the given section.
// It returns an error if the section does not exist and an empty list if the
// section is empty.
// Options within the default section are also included.
func (c *ConfigFile) GetOptions(section string) ([]string, error) {
section = strings.ToLower(section)
if _, ok := c.data[section]; !ok {
return nil, errors.New(
fmt.Sprintf("Section not found: %s", section),
)
}
options := make([]string, len(c.data[DefaultSection])+len(c.data[section]))
i := 0
for s, _ := range c.data[DefaultSection] {
options[i] = s
i++
}
for s, _ := range c.data[section] {
options[i] = s
i++
}
return options, nil
}
// HasOption checks if the configuration has the given option in the section.
// It returns false if either the option or section do not exist.
func (c *ConfigFile) HasOption(section string, option string) bool {
section = strings.ToLower(section)
option = strings.ToLower(option)
if _, ok := c.data[section]; !ok {
return false
}
_, okd := c.data[DefaultSection][option]
_, oknd := c.data[section][option]
return okd || oknd
}
// GetRawString gets the (raw) string value for the given option in the
// section.
// The raw string value is not subjected to unfolding, which was illustrated
// in the beginning of this documentation.
// It returns an error if either the section or the option do not exist.
func (c *ConfigFile) GetRawString(section string, option string) (string, error) {
section = strings.ToLower(section)
option = strings.ToLower(option)
if _, ok := c.data[section]; ok {
if value, ok := c.data[section][option]; ok {
return value, nil
}
return "", errors.New(fmt.Sprintf("Option not found: %s", option))
}
return "", errors.New(fmt.Sprintf("Section not found: %s", section))
}
// GetString gets the string value for the given option in the section.
// If the value needs to be unfolded (see e.g. %(host)s example in the
// beginning of this documentation),
// then GetString does this unfolding automatically, up to DepthValues number
// of iterations.
// It returns an error if either the section or the option do not exist, or
// the unfolding cycled.
func (c *ConfigFile) GetString(section string, option string) (string, error) {
value, err := c.GetRawString(section, option)
if err != nil {
return "", err
}
section = strings.ToLower(section)
var i int
for i = 0; i < DepthValues; i++ { // keep a sane depth
vr := varRegExp.FindStringSubmatchIndex(value)
if len(vr) == 0 {
break
}
noption := value[vr[2]:vr[3]]
noption = strings.ToLower(noption)
// search variable in default section
nvalue, _ := c.data[DefaultSection][noption]
if _, ok := c.data[section][noption]; ok {
nvalue = c.data[section][noption]
}
if nvalue == "" {
return "", errors.New(fmt.Sprintf("Option not found: %s", noption))
}
// substitute by new value and take off leading '%(' and trailing ')s'
value = value[0:vr[2]-2] + nvalue + value[vr[3]+2:]
}
if i == DepthValues {
return "",
errors.New(
fmt.Sprintf(
"Possible cycle while unfolding variables: max depth of %d reached",
strconv.Itoa(DepthValues),
),
)
}
return value, nil
}
// GetInt has the same behaviour as GetString but converts the response to int.
func (c *ConfigFile) GetInt64(section string, option string) (int64, error) {
sv, err := c.GetString(section, option)
if err != nil {
return 0, err
}
value, err := strconv.ParseInt(sv, 10, 64)
if err != nil {
return 0, err
}
return value, nil
}
// GetFloat has the same behaviour as GetString but converts the response to
// float.
func (c *ConfigFile) GetFloat(section string, option string) (float64, error) {
sv, err := c.GetString(section, option)
if err != nil {
return float64(0), err
}
value, err := strconv.ParseFloat(sv, 64)
if err != nil {
return float64(0), err
}
return value, nil
}
// GetBool has the same behaviour as GetString but converts the response to
// bool.
// See constant BoolStrings for string values converted to bool.
func (c *ConfigFile) GetBool(section string, option string) (bool, error) {
sv, err := c.GetString(section, option)
if err != nil {
return false, err
}
value, ok := BoolStrings[strings.ToLower(sv)]
if ok == false {
return false, errors.New(
fmt.Sprintf("Could not parse bool value: %s", sv),
)
}
return value, nil
}
package main
import (
"html/template"
"io"
"sync"
// log "os/logger"
)
type (
PageData struct {
Title string
JS []string
CSS []string
Content interface{}
}
HTML struct {
pages map[string]*template.Template
m sync.Mutex
}
)
func NewHTML() *HTML {
return &HTML{
pages: make(map[string]*template.Template),
}
}
func NewPage(title string, data interface{}) *PageData {
return &PageData{
Title: title,
CSS: []string{
"//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap.min.css",
"//maxcdn.bootstrapcdn.com/bootstrap/3.2.0/css/bootstrap-theme.min.css",
"/assets/css/main.css",
},
Content: data,
}
}
func (p *PageData) AddJS(name string) *PageData {
p.JS = append(p.JS, name)
return p
}
func (t *HTML) Execute(name string, w io.Writer, data *PageData) error {
t.m.Lock()
page, found := t.pages[name]
t.m.Unlock()
// New session or session expired
if !found {
var err error
page, err = t.load(name)
if err != nil {
return err
}
}
page.ExecuteTemplate(w, "layout", data)
return nil
}
// Load Template and Cache it
func (t *HTML) load(name string) (*template.Template, error) {
page, err := template.ParseFiles(
t.file("layout.html"),
t.file("views/"+name+".html"),
)
if err != nil {
return nil, err
}
t.m.Lock()
t.pages[name] = page
t.m.Unlock()
return page, nil
}
func (t *HTML) file(name string) string {
fullpath := server.c.http_root + name
return fullpath
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment