Commit 3496093b by Bogdan Ungureanu

GPSD First release

parents
all:
GOPATH=$(shell pwd) go build
\ No newline at end of file
[server]
port=8080
httpdocs=/home/bogdan/Projects/gpsd/public_html/
[database]
hostname=46.102.175.70
database=gpsd
username=gpsd
password=IzpZsFjCxeiKWndOu90w
package main
import (
"bufio"
"os"
)
func parse_logfile(file string) error {
fd, err := os.Open(file)
if err != nil {
return err
}
defer fd.Close()
scanner := bufio.NewScanner(fd)
scanner.Split(bufio.ScanLines)
adapter := &TK103{}
adapter.onPing = parseonPing
for scanner.Scan() {
line := scanner.Text()
line = line[44 : len(line)-1]
if line[:7] == "Message" {
data := line[9:]
adapter.Handle(data)
}
}
return nil
}
func parseonPing(devstring string, event EventData) {
if event.Active {
if event.Latitude > 0 && event.Longitude > 0 {
logdata := &Log{
Time: event.Time,
Did: 1,
Active: true,
Latitude: event.Latitude,
Longitude: event.Longitude,
Speed: int(event.Speed),
Angle: event.Angle,
}
server.db.Create(&logdata)
}
}
}
body {
padding-top: 70px;
}
#map-canvas {
border: 1px solid #000;
width: 100%;
height: 1024px;
}
.tmap {
width: 640px;
height: 480px;
margin: 0;
padding: 0;
}
\ No newline at end of file
{{ define "layout" }}
<!DOCTYPE html>
<head>
<title>{{ .Title }}</title>
{{range $css_link := .CSS}}
<link rel="stylesheet" href="{{ $css_link }}">{{ end }}
{{range $js_link := .JS}}
<script type="text/javascript" src="{{ $js_link }}"></script>{{ end }}
</head>
<body>
{{ template "content" .Content }}
</body>
</html>
{{ end }}
{{ define "content" }} {{ end}}
\ No newline at end of file
{{ define "content" }}
<!-- Navigation -->
<nav class="navbar navbar-inverse navbar-fixed-top" role="navigation">
<div class="container">
<!-- Brand and toggle get grouped for better mobile display -->
<div class="navbar-header">
<button type="button" class="navbar-toggle" data-toggle="collapse" data-target="#bs-example-navbar-collapse-1">
<span class="sr-only">Toggle navigation</span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
<span class="icon-bar"></span>
</button>
<a class="navbar-brand" href="#">Start Bootstrap</a>
</div>
<!-- Collect the nav links, forms, and other content for toggling -->
<div class="collapse navbar-collapse" id="bs-example-navbar-collapse-1">
<ul class="nav navbar-nav">
<li>
<a href="#">About</a>
</li>
<li>
<a href="#">Services</a>
</li>
<li>
<a href="/logout">Logout</a>
</li>
</ul>
</div>
<!-- /.navbar-collapse -->
</div>
<!-- /.container -->
</nav>
<div class="container">
<div class="row">
<div id="map-canvas"></div>
</div>
<div class="row">
<select name="date" id="event-map">
{{range $date := .LogEvents}} <option value="{{ $date }}">{{ $date }}</option> {{end}}
</select>
<button type="button" class="btn btn-success" onclick="show_track(1)">Track</button>
</div>
<script type="text/javascript">
var map;
var poly;
var bounds ;
function init() {
var mapCanvas = document.getElementById('map-canvas')
var mapOptions = {
center: new google.maps.LatLng(44.4333333, 26.1),
zoom: 7,
mapTypeId: google.maps.MapTypeId.ROADMAP,
};
map = new google.maps.Map(mapCanvas,mapOptions);
var polyOptions = {
strokeColor: '#000000',
strokeOpacity: 1.0,
strokeWeight: 3
};
poly = new google.maps.Polyline(polyOptions);
bounds = new google.maps.LatLngBounds();
poly.setMap(map);
//track(1);
/*
var flightPlanCoordinates = [
new google.maps.LatLng(37.772323, -122.214897),
new google.maps.LatLng(21.291982, -157.821856),
new google.maps.LatLng(-18.142599, 178.431),
new google.maps.LatLng(-27.46758, 153.027892)
];
var flightPath = new google.maps.Polyline({
path: flightPlanCoordinates,
geodesic: true,
strokeColor: '#FF0000',
strokeOpacity: 1.0,
strokeWeight: 2
});
flightPath.setMap(map);
}
*/
}
function show_track(id) {
var event = document.getElementById('event-map')
//var x = document.getElementById("mySelect").selectedIndex;
//var y = document.getElementById("mySelect").options;
//alert("Index: " + y[x].index + " is " + y[x].text);
track(id,event.options[event.selectedIndex].value)
}
function track (id,events) {
$.ajax({
url: "/track/" + id + "/" + events,
dataType: "json",
// if the data is succesfully found
success: function (data) {
var path = poly.getPath();
// Reset poly and bounds
for (var i = 0; i < path.length; i++) {
path.pop();
}
bounds = new google.maps.LatLngBounds();
for (i = 0, len = data.length; i < len; ++i) {
//console.log('got' + data[i].Latitude + data[i].Longitude);
var LatLng = new google.maps.LatLng(data[i].Latitude , data[i].Longitude);
path.push(LatLng);
bounds.extend(LatLng);
/*
var marker = new google.maps.Marker({
position: LatLng,
title: '#' + data[i].Id + "-"+ path.getLength() + " speed "+ data[i].Speed + "km/h",
map: map
});
*/
}
map.fitBounds(bounds);
}
})
}
google.maps.event.addDomListener(window, 'load', init);
</script>
<!-- /.container -->
{{ end}}
{{ define "content" }}
<div class="container">
<div id="loginbox" style="margin-top:50px;" class="mainbox col-md-6 col-md-offset-3 col-sm-8 col-sm-offset-2">
<div class="panel panel-info" >
<div class="panel-heading">
<div class="panel-title">Sign In</div>
<div style="float:right; font-size: 80%; position: relative; top:-10px"><a href="#">Forgot password?</a></div>
</div>
<div style="padding-top:30px" class="panel-body" >
<div style="display:none" id="login-alert" class="alert alert-danger col-sm-12"></div>
<form id="loginform" class="form-horizontal" role="form" method="POST" action="/login">
<div style="margin-bottom: 25px" class="input-group">
<span class="input-group-addon"><i class="glyphicon glyphicon-user"></i></span>
<input id="login-username" type="text" class="form-control" name="username" value="" placeholder="username or email">
</div>
<div style="margin-bottom: 25px" class="input-group">
<span class="input-group-addon"><i class="glyphicon glyphicon-lock"></i></span>
<input id="login-password" type="password" class="form-control" name="password" placeholder="password">
</div>
<div class="input-group">
<div class="checkbox">
<label>
<input id="login-remember" type="checkbox" name="remember" value="1"> Remember me
</label>
</div>
</div>
<div style="margin-top:10px" class="form-group">
<div class="col-sm-12 controls">
<button class="btn btn-success" type="submit">Login</button>
</div>
</div>
</form>
</div>
</div>
</div>
</div>
{{ end }}
\ No newline at end of file
package main
import (
_ "database/driver/mysql"
"database/orm/gorm"
"flag"
"net"
"net/httpd"
"os"
"os/daemon"
log "os/logger"
"sync"
"time"
)
var (
server = &GPSServer{
c: &GPSServerConfig{ // Default config
cfgfile: "/etc/gpsd.conf",
db_host: "localhost",
db_port: 3306,
db_db: "gpsd",
db_user: "root",
db_pass: "",
http_root: "public_html/",
http_port: 8080,
},
clients: make(map[string]*TDeviceInfo),
}
html = NewHTML()
sidname = "sid"
sidexpire = 1 * time.Hour
logfd *os.File
)
type (
// GPS Server - handle client Connections
GPSServer struct {
clients map[string]*TDeviceInfo
c *GPSServerConfig // server_config.go
parselog string
logfile string
db gorm.DB
mutex sync.Mutex
wait sync.WaitGroup
shutdown bool
}
// GPS Logger - handle data logging
EventData struct {
Active bool
AlarmCode int
AlarmMsg string
Time time.Time
Latitude float64
Longitude float64
Speed float64
Miles int64
Angle float64
Power bool // GPS Power
Acc bool // Motor Ignition
}
// Some Handlers
LoginHandler func(string) bool
EventHandler func(string, EventData)
// Tracking device - GPS
TDeviceInfo struct {
Id string
IpAddr string
LastSeen *time.Time // Last seen online
adapter DeviceAdapter
}
// SQL database Users
User struct {
Id int32 `sql:"type:int AUTO_INCREMENT",gorm:"primary_key"`
Name string `sql:"type:varchar(30)"`
Password string `sql:"type:varchar(60)"`
}
// SQL database Devices
Device struct {
Id int64 `sql:"type:int AUTO_INCREMENT;",gorm:"primary_key"`
Uid int64 `sql:"type:int"`
Imei string `sql:"type:varchar(30)"`
}
Presence struct {
id int64
time int64
}
// SQL database Logs
Log struct {
Id int64 `sql:"type:int AUTO_INCREMENT;",gorm:"primary_key"`
Time time.Time `sql:"DEFAULT:NULL"`
Did int64 `sql:"type:int;not null;"`
Active bool `sql:"type:boolean"`
Latitude float64 `sql:"type:float(9,6);not null;"`
Longitude float64 `sql:"type:float(9,6);not null;"`
Speed int `sql:"type:int;not null"`
Angle float64 `sql:"type:float(9,6);not null;"`
}
Session struct {
user string
}
)
func init() {
log.SetLevel(log.DebugLevel)
daemon.Setup(&daemon.Config{
//Uid: 1000,
//Gid: 1000,
//Chroot: "/home/gspd/",
Pidfile: "/var/run/gpsd.pid",
IntHandler: statsHandler,
//TermHandler: signals,
//QuitHandler: signals,
//Usr1Handler: signals,
//Usr2Handler: signals,
HupHandler: hupHandler,
})
flag.StringVar(&server.parselog, "p", "", "Description")
flag.Parse()
if server.parselog == "" {
err := daemon.Background()
if err != nil {
log.Error(err)
}
logfd, err := log.OpenLog("/var/log/gpsd.log")
if err != nil {
log.Errorf("Open log file error: %s", err)
return
}
log.SetOutput(logfd)
}
}
// Process SIGHUP Handler
func hupHandler(sig int) bool {
// Cancel standard behavior
return true
}
// Process SIGINT Handler
func statsHandler(sig int) bool {
log.Info("SIGINT received")
server.Shutdown()
// Cancel standard behavior
server.wait.Wait()
return false
}
func main() {
err := server.c.LoadConfig()
if err != nil {
log.Fatalf("LoadConfig error : %s", err)
}
err = server.InitDatabase()
if err != nil {
log.Fatalf("Database config error : %s", err)
}
if server.parselog != "" {
err := parse_logfile(server.parselog)
if err != nil {
log.Errorf("parse_logfile error: %s", err)
}
return
}
// Start Listening on Ports
err = server.ListenTCP(":9090")
if err != nil {
log.Fatalf("ListenTCP error : %s", err)
}
router, err := gpsd_router()
if err != nil {
log.Error(err)
return
}
err = httpd.ListenAndServe(server.c.httpAddr(), router)
if err != nil {
log.Error(err)
return
}
}
// Server Listen TCP
func (s *GPSServer) ListenTCP(addr string) error {
l, err := net.Listen("tcp", addr)
if err != nil {
log.Error("Error listening: %s", err)
return err
}
log.Info("GPS service listening on " + addr)
go func(l net.Listener, addr string) {
s.wait.Add(1)
for {
// Listen for an incoming connection.
l.(*net.TCPListener).SetDeadline(time.Now().Add(100 * time.Millisecond))
conn, err := l.Accept()
// Shutdown request ?
if server.shutdown == true {
l.Close()
log.Infof("Shutdown listener socket %s", addr)
s.wait.Done()
return
}
if err != nil {
netErr, ok := err.(net.Error)
if ok && netErr.Timeout() && netErr.Temporary() {
continue
}
return
}
go server.handleTCP(conn)
}
}(l, addr)
return nil
}
// Handle New TCP connection
func (s *GPSServer) handleTCP(c net.Conn) {
d := &TDeviceInfo{}
for {
buf := make([]byte, 1024)
mlen, err := c.Read(buf)
if err != nil {
log.Errorf("Error reading: %s", err)
return
}
rcvd := string(buf)[:mlen]
if d.Id == "" {
log.Infof("New Connection from %s", c.RemoteAddr().String())
d, err = GetDeviceAdapter(rcvd, c)
if err != nil {
c.Close()
log.Errorf("No adapter found for device !")
return
}
// Give ServerHandlers to Device
d.onLogin(s.onLogin)
d.onAlarm(s.onAlarm)
d.onPing(s.onPing)
// Save to clients list
s.mutex.Lock()
s.clients[d.Id] = d
s.mutex.Unlock()
}
// Make sure addapter can Handle message
err = d.Handle(rcvd)
if err != nil {
log.Errorf("Unable to parse message %s", rcvd)
c.Close()
return
}
//
}
}
func (s *GPSServer) sqlDevice(device string) *Device {
devsql := &Device{}
server.db.Where("imei = ? ", device).First(&devsql)
if devsql.Id == 0 {
devsql.Imei = device
devsql.Uid = 1 // Default
server.db.Create(&devsql)
server.db.Where("imei = ? ", device).First(&devsql)
}
return devsql
}
// On device Login Request
func (s *GPSServer) onLogin(device string) bool {
return true
}
// On device Allarm Request
func (s *GPSServer) onAlarm(device string, event EventData) {
}
// Normal GPS Data
func (s *GPSServer) onPing(device string, event EventData) {
devsql := s.sqlDevice(device)
if event.Active {
// Send log to Database
logdata := &Log{
Time: event.Time,
Did: devsql.Id,
Active: true,
Latitude: event.Latitude,
Longitude: event.Longitude,
Speed: int(event.Speed),
Angle: event.Angle,
}
server.db.Create(&logdata)
}
}
// Shutdown Server
func (s *GPSServer) Shutdown() {
log.Info("Shutdown request via SIGTERM")
s.shutdown = true
s.wait.Wait()
}
// Initialize Database
func (s *GPSServer) InitDatabase() error {
db, err := gorm.Open("mysql", s.c.DbString())
if err != nil {
return err
}
s.db = db
db.AutoMigrate(&User{}, &Device{}, &Log{})
// Multiple column index
//db.Model(&User{}).AddIndex("idx_user_name_age", "name", "age")
return nil
}
package main
import (
"net"
log "os/logger"
"strings"
)
type (
//
// Interface for GPS adapters
//
DeviceAdapter interface {
Handle(string) error
Send(string, string)
OnLogin(LoginHandler)
OnAlarm(EventHandler)
OnPing(EventHandler)
}
// Implement DeviceAdapter interface
TK103 struct {
DeviceId string
c net.Conn
onLogin LoginHandler
onAlarm EventHandler
onPing EventHandler
}
)
// Handles incoming requests.
func (a *TK103) Handle(rcvd string) error {
cmd_start := strings.Index(rcvd, "B")
if cmd_start > 13 {
log.Error("DeviceId too long")
return nil
}
a.DeviceId = rcvd[1:cmd_start]
cmd := rcvd[cmd_start : cmd_start+4]
switch cmd {
case "BP05": // Login Request
if a.onLogin != nil {
if a.onLogin(a.DeviceId) == true {
a.Send("AP05", "")
}
return nil
}
// Auto accept login
a.Send("AP05", "")
case "BP00": // Handshake Request
a.Send("AP01", "HSO")
case "BR00": // Ping Request
res := a._parse(rcvd[cmd_start+4 : len(rcvd)-1])
if a.onPing != nil {
a.onPing(a.DeviceId, *res)
}
case "BO01": // Alarm Request
alrm := rcvd[cmd_start+4 : cmd_start+5]
res := a._parse(rcvd[cmd_start+5 : len(rcvd)-1])
code := -1
message := ""
switch alrm {
case "0":
code = 0
message = "Vehicle Power Off"
case "1":
code = 1
message = "The vehicle suffers an acciden"
case "2":
code = 2
message = "Driver sends a S.O.S."
case "3":
code = 3
message = "The alarm of the vehicle is activated"
case "4":
code = 4
message = "Vehicle is below the min speed setted"
case "5":
code = 5
message = "Vehicle is over the max speed setted"
case "6":
code = 6
message = "Out of geo fence"
}
res.AlarmCode = code
res.AlarmMsg = message
if a.onAlarm != nil {
a.onAlarm(a.DeviceId, *res)
}
a.Send("AS01", alrm)
default:
log.Infof("Unknown message from %s", a.DeviceId)
log.Printf("Message :%s", rcvd)
}
return nil
}
// Setup Login Handler
func (a *TK103) OnLogin(handler LoginHandler) {
a.onLogin = handler
}
// Setup Alarm Hanlder
func (a *TK103) OnAlarm(handler EventHandler) {
a.onAlarm = handler
}
// Setup Ping Handler
func (a *TK103) OnPing(handler EventHandler) {
a.onPing = handler
}
func (tk *TK103) _parse(data string) *EventData {
yy := data[0:2] // YYMMDD - Year
mm := data[2:4] // YYMMDD - Month
dd := data[4:6] // YYMMDD - Month
active := data[6:7] // The availability of GPS data
latstr := data[7:9] // Latitude - Minute
latmstr := data[9:16] // latitude - Seconds
latind := data[16:17] // Latitude indicator "N" or "S"
longstr := data[17:20] // Longitue - Minute
longmstr := data[20:27] // Longitude - Seconds
longind := data[27:28] // Longture indicator "E" or "V"
speedstr := data[28:33] // The unit is km/h
timeh := data[33:35] //HHMMSS - hours
timem := data[35:37] //HHMMSS - minutes
times := data[37:39] //HHMMSS - seconds
anglestr := data[39:44] // Orientation
//io := data[44:52] // IO State
//mil := data[52:53] // Milepost
//milstr := data[53:61] // Km
// Process Latitude
latitude, err := ToDecimalDegrees(latstr, latmstr)
if err != nil {
log.Errorf("Val: %s'%s,Err: %s", latstr, latmstr, err)
}
if latind == "S" {
latitude = -latitude
}
// Process Longitude
longitude, err := ToDecimalDegrees(longstr, longmstr)
if err != nil {
log.Errorf("Val: %s'%s,Err: %s", longstr, longmstr, err)
}
if longind == "W" {
longitude = -longitude
}
// Process Speed
speed, err := ToDecimal(speedstr)
if err != nil {
log.Errorf("Val: %s,Err: %s", speedstr, err)
}
// Process Time
time, err := FormatTimeShort(yy, mm, dd, timeh, timem, times)
if err != nil {
log.Errorf("Val: 20%s/%s/%s %s:%s:%s, Err: %s",
yy, mm, dd, timeh, timem, times, err)
}
// Process Orientation
angle, err := ToDecimal(anglestr)
if err != nil {
log.Errorf("Val: %s,Err: %s", anglestr, err)
}
return &EventData{
Active: (active == "A"),
Time: time,
Latitude: latitude,
Longitude: longitude,
Speed: speed,
Angle: angle,
}
}
func (tk *TK103) Send(cmd, data string) {
message := "(" + strings.Join([]string{tk.DeviceId, cmd, data}, "") + ")"
_, err := tk.c.Write([]byte(message))
if err != nil {
log.Error("Failed to send cmd " + cmd + " to " + tk.DeviceId)
}
}
package main
import (
"fmt"
"util/cfg"
)
//
// Server Configuration
//
type GPSServerConfig struct { // GPS Server Config
cfgfile string
db_user string
db_pass string
db_host string
db_port int
db_db string
http_addr string
http_port int
http_root string
}
// Load and parse config File
func (c *GPSServerConfig) LoadConfig() error {
conf, err := cfg.ReadFile(c.cfgfile)
if err != nil {
return err
}
if conf.HasOption("database", "hostname") {
c.db_host, err = conf.GetString("database", "hostname")
}
if conf.HasOption("database", "database") {
c.db_db, err = conf.GetString("database", "database")
}
if conf.HasOption("database", "username") {
c.db_user, err = conf.GetString("database", "username")
}
if conf.HasOption("database", "password") {
c.db_pass, err = conf.GetString("database", "password")
}
if conf.HasOption("server", "address") {
c.http_addr, err = conf.GetString("server", "address")
}
if conf.HasOption("server", "httpdocs") {
c.http_root, err = conf.GetString("server", "httpdocs")
}
return nil
}
func (c *GPSServerConfig) httpAddr() string {
return fmt.Sprintf("%s:%d", c.http_addr, c.http_port)
}
func (c *GPSServerConfig) DbString() string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8&parseTime=true",
c.db_user,
c.db_pass,
c.db_host,
c.db_port,
c.db_db,
)
}
package main
import (
"errors"
"net"
)
// Device Helpers
func (d *TDeviceInfo) Handle(msg string) error {
return d.adapter.Handle(msg)
}
func (d *TDeviceInfo) onLogin(handler LoginHandler) {
d.adapter.OnLogin(handler)
}
func (d *TDeviceInfo) onAlarm(handler EventHandler) {
d.adapter.OnAlarm(handler)
}
func (d *TDeviceInfo) onPing(handler EventHandler) {
d.adapter.OnPing(handler)
}
// Define all Known Parsers
func GetDeviceAdapter(msg string, conn net.Conn) (*TDeviceInfo, error) {
// TK102
if msg[:1] == "(" && msg[len(msg)-1:] == ")" {
if len(msg) > 13 && msg[13:14] == "B" {
device := &TDeviceInfo{
Id: msg[1:13],
IpAddr: conn.RemoteAddr().String(),
adapter: &TK103{c: conn},
}
return device, nil
}
}
return nil, errors.New("GetDeviceAdapter: No Adapter found")
}
package main
import (
"net/httpd"
"net/httpd/session"
log "os/logger"
"time"
)
func doLogin(username, password string) (*Session, bool) {
user := &User{}
server.db.Where("name = ? and password = ?", username, password).First(&user)
log.Printf("db : %x ", user)
if user.Id == 0 {
return nil, false
} else {
return &Session{
user: user.Name,
}, true
}
}
func AuthHandler(next httpd.Handler) httpd.Handler {
return httpd.HandlerFunc(func(c *httpd.Context) {
log.Printf("USER auth handler requested !")
s, err := c.Request.Cookie(sidname)
if err != nil {
//Setup new Session
ses := session.New()
session.Set(ses)
c.SetCookie(sidname, ses.Id(), sidexpire)
c.Redirect("/login", 302)
return
}
ses := session.Get(s.Value)
if ses == nil {
ses := session.New()
session.Set(ses)
c.SetCookie(sidname, ses.Id(), sidexpire)
c.Redirect("/login", 302)
return
}
// Handle login
if c.RequestURI == "/login" && c.RequestMethod == "POST" {
usr := c.Request.FormValue("username")
pwd := c.Request.FormValue("password")
sdata, login := doLogin(usr, pwd)
if login == false {
c.Redirect("/login", 302)
return
}
// Store the session
ses.Data = sdata
session.Set(ses)
c.Redirect("/", 302)
return
}
if ses.Data == nil && c.RequestURI != "/login" {
c.Redirect("/login", 302)
return
}
c.Session = ses
next.Handle(c)
})
}
func indexPage(c *httpd.Context) {
data := struct {
Test string
LogEvents []string
ProxyTotal int64
}{}
//SELECT DATE_FORMAT(`time`, '%Y-%m-%d') FROM `logs` WHERE 1 GROUP BY DATE_FORMAT(`time`,'%Y%m%d')
rows, err := server.db.Raw("SELECT DATE_FORMAT(`time`, '%Y-%m-%d') FROM `logs` WHERE 1 GROUP BY DATE_FORMAT( `time` , '%Y%m%d' ) ").Rows()
if err != nil {
log.Errorf("%s", err)
return
}
defer rows.Close()
for rows.Next() {
tm := ""
rows.Scan(&tm)
data.LogEvents = append(data.LogEvents, tm)
}
// Setup new Page
page := NewPage("Real Time GPS Tracker", data)
page.AddJS("https://maps.googleapis.com/maps/api/js?key=AIzaSyBn7yp_UcuCe5U1IZHKcfyJ5wJTMS8YIIM")
page.AddJS("https://ajax.googleapis.com/ajax/libs/jquery/2.1.4/jquery.min.js")
html.Execute("index",
c.Response,
page,
)
}
func loginPage(c *httpd.Context) {
err := html.Execute(
"login",
c.Response,
NewPage("Login", nil),
)
if err != nil {
log.Error(err)
}
}
// Track device
func trackHandler(c *httpd.Context) {
var logs []Log
device := c.Param("device")
date, err := ParseDate(c.Param("date"))
if err != nil {
log.Printf("trackHandler error parsing time %s", err)
c.Error("Invalid date", 400)
return
}
log := &Log{Time: date}
server.db.Order("time asc").Where("did = ? AND time > ? AND time < ? ", device, log.Time, log.Time.Add(24*time.Hour)).Find(&logs)
c.JSON(200, logs)
}
// Logout
func logoutPage(c *httpd.Context) {
sess := c.Session.(*session.Session)
sess.Data = nil
session.Set(sess)
log.Printf("logging out from sid %s", sess.Id())
c.Redirect("/login", 302)
}
// GPSD ROUTER
func gpsd_router() (*httpd.Router, error) {
http := httpd.NewRouter()
http.BeforeHandle(AuthHandler)
http.Get("/", indexPage)
http.Get("/track/:device/:date", trackHandler)
http.Get("/login", loginPage)
http.Get("/logout", logoutPage)
// Assets Handlers
assets := http.Subrouter("/assets")
/*
pwd, err := os.Getwd()
if err != nil {
//fmt.Println(err)
os.Exit(1)
}
*/
assets.ServeFiles("/*filepath", server.c.http_root+"assets/")
return http, nil
}
package main
import (
// log "os/logger"
"strconv"
"strings"
"time"
)
// Speed to Deciamal
// speed = float(speed)*1.852
//strconv.FormatFloat(input_num, 'f', 6, 64)
// Degree to Decimal
func ToDecimalDegrees(minute, seconds string) (float64, error) {
min, err := strconv.ParseFloat(minute, 64)
if err != nil {
return float64(0), err
}
sec, err := strconv.ParseFloat(seconds, 64)
if err != nil {
return min, err
}
// Format to 6 char precision
format := strconv.FormatFloat(min+sec/60, 'f', 6, 64)
decimal, err := strconv.ParseFloat(format, 64)
if err != nil {
return min, err
}
return decimal, nil
}
// Convert To Decimal
func ToDecimal(speed string) (float64, error) {
s, err := strconv.ParseFloat(speed, 64)
if err != nil {
return float64(0), err
}
return s, nil
}
func ParseDate(date string) (time.Time, error) {
dstr := strings.Split(date, "-")
loc, err := time.LoadLocation("Europe/Bucharest")
if err != nil {
return time.Now(), err
}
year, err := strconv.ParseInt(dstr[0], 10, 64)
if err != nil {
//return , err
}
month, err := strconv.ParseInt(dstr[1], 10, 64)
if err != nil {
//return , err
}
day, err := strconv.ParseInt(dstr[2], 10, 64)
if err != nil {
//return , err
}
return time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, loc), nil
}
// Format Time
func FormatTimeShort(yy, mm, dd, h, m, s string) (time.Time, error) {
year, err := strconv.ParseInt("20"+yy, 10, 64)
if err != nil {
//return , err
}
month, err := strconv.ParseInt(mm, 10, 64)
if err != nil {
//return , err
}
day, err := strconv.ParseInt(dd, 10, 64)
if err != nil {
//return , err
}
hour, err := strconv.ParseInt(h, 10, 64)
if err != nil {
//return , err
}
minute, err := strconv.ParseInt(m, 10, 64)
if err != nil {
//return , err
}
second, err := strconv.ParseInt(s, 10, 64)
if err != nil {
//return , err
}
t := time.Date(
int(year),
time.Month(int(month)),
int(day),
int(hour),
int(minute),
int(second),
0,
time.UTC).Add(time.Hour)
return t, nil
}
## Version 1.1 (2013-11-02)
Changes:
- Go-MySQL-Driver now requires Go 1.1
- Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
- Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
- `byte(nil)` is now treated as a NULL value. Before, it was treated like an empty string / `[]byte("")`
- DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
- Use the IO buffer also for writing. This results in zero allocations (by the driver) for most queries
- Optimized the buffer for reading
- stmt.Query now caches column metadata
- New Logo
- Changed the copyright header to include all contributors
- Improved the LOAD INFILE documentation
- The driver struct is now exported to make the driver directly accessible
- Refactored the driver tests
- Added more benchmarks and moved all to a separate file
- Other small refactoring
New Features:
- Added *old_passwords* support: Required in some cases, but must be enabled by adding `allowOldPasswords=true` to the DSN since it is insecure
- Added a `clientFoundRows` parameter: Return the number of matching rows instead of the number of rows changed on UPDATEs
- Added TLS/SSL support: Use a TLS/SSL encrypted connection to the server. Custom TLS configs can be registered and used
Bugfixes:
- Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
- Convert to DB timezone when inserting `time.Time`
- Splitted packets (more than 16MB) are now merged correctly
- Fixed false positive `io.EOF` errors when the data was fully read
- Avoid panics on reuse of closed connections
- Fixed empty string producing false nil values
- Fixed sign byte for positive TIME fields
## Version 1.0 (2013-05-14)
Initial Release
# Contributing Guidelines
## Reporting Issues
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed).
Please provide the following minimum information:
* Your Go-MySQL-Driver version (or git SHA)
* Your Go version (run `go version` in your console)
* A detailed issue description
* Error Log if present
* If possible, a short example
## Contributing Code
By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file.
Don't forget to add yourself to the AUTHORS file.
### Pull Requests Checklist
Please check the following points before submitting your pull request:
- [x] Code compiles correctly
- [x] Created tests, if possible
- [x] All tests pass
- [x] Extended the README / documentation, if necessary
- [x] Added yourself to the AUTHORS file
### Code Review
Everyone is invited to review and comment on pull requests.
If it looks fine to you, comment with "LGTM" (Looks good to me).
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
Before merging the Pull Request, at least one [team member](https://github.com/go-sql-driver?tab=members) must have commented with "LGTM".
## Development Ideas
If you are looking for ideas for code contributions, please check our [Development Ideas](https://github.com/go-sql-driver/mysql/wiki/Development-Ideas) Wiki page.
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"database/sql"
"strings"
"sync"
"sync/atomic"
"testing"
)
type TB testing.B
func (tb *TB) check(err error) {
if err != nil {
tb.Fatal(err)
}
}
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
tb.check(err)
return db
}
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
tb.check(err)
return rows
}
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
tb.check(err)
return stmt
}
func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("Error on %q: %v", query, err)
}
}
return db
}
const concurrencyLevel = 10
func BenchmarkQuery(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
var got string
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Errorf("query = %q; want one", got)
wg.Done()
return
}
}
}()
}
}
func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("DO 1"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
if _, err := stmt.Exec(); err != nil {
b.Fatal(err.Error())
}
}
}()
}
}
// data, but no db writes
var roundtripSample []byte
func initRoundtripBenchmarks() ([]byte, int, int) {
if roundtripSample == nil {
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
}
return roundtripSample, 16, len(roundtripSample)
}
func BenchmarkRoundtripTxt(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
b.StartTimer()
var result string
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sampleString[0:length]
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if result != test {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkRoundtripBin(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
b.StartTimer()
var result sql.RawBytes
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if !bytes.Equal(result, test) {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import "io"
const defaultBufSize = 4096
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
rd io.Reader
idx int
length int
}
func newBuffer(rd io.Reader) *buffer {
var b [defaultBufSize]byte
return &buffer{
buf: b[:],
rd: rd,
}
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
// move existing data to the beginning
if b.length > 0 && b.idx > 0 {
copy(b.buf[0:b.length], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
b.idx = 0
for {
n, err := b.rd.Read(b.buf[b.length:])
b.length += n
if err == nil {
if b.length < need {
continue
}
return nil
}
if b.length >= need && err == io.EOF {
return nil
}
return err
}
}
// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
return nil, err
}
}
offset := b.idx
b.idx += need
b.length -= need
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
}
return make([]byte, length)
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
}
return nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length == 0 {
return b.buf
}
return nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"database/sql/driver"
"errors"
"net"
"strings"
"time"
)
type mysqlConn struct {
buf *buffer
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *config
maxPacketAllowed int
maxWriteSize int
flags clientFlag
sequence uint8
parseTime bool
strict bool
}
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
// Handles parameters set in DSN
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.params {
switch param {
// Charset
case "charset":
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
if err == nil {
break
}
}
if err != nil {
return
}
// time.Time parsing
case "parseTime":
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Strict mode
case "strict":
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Compression
case "compress":
err = errors.New("Compression not implemented yet")
return
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
if err != nil {
return
}
}
}
return
}
func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err
}
return nil, err
}
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
mc.writeCommandPacket(comQuit)
mc.netConn.Close()
mc.netConn = nil
}
mc.cfg = nil
mc.buf = nil
return
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
return nil, err
}
stmt := &mysqlStmt{
mc: mc,
}
// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if columnCount > 0 {
err = mc.readUntilEOF()
}
}
return stmt, err
}
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
mc.affectedRows = 0
mc.insertId = 0
err := mc.exec(query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, err
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err != nil {
return err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil && resLen > 0 {
if err = mc.readUntilEOF(); err != nil {
return err
}
err = mc.readUntilEOF()
}
return err
}
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
if mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) == 0 { // no args, fastpath
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
rows.columns, err = mc.readColumns(resLen)
}
return rows, err
}
}
return nil, err
}
// with args, must use prepared stmt
return nil, driver.ErrSkip
}
// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
}
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
}
dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
}
}
return nil, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
const (
minProtocolVersion byte = 10
maxPacketSize = 1<<24 - 1
timeFormat = "2006-01-02 15:04:05"
)
// MySQL constants documentation:
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
const (
iOK byte = 0x00
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)
type clientFlag uint32
const (
clientLongPassword clientFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
clientNoSchema
clientCompress
clientODBC
clientLocalFiles
clientIgnoreSpace
clientProtocol41
clientInteractive
clientSSL
clientIgnoreSIGPIPE
clientTransactions
clientReserved
clientSecureConn
clientMultiStatements
clientMultiResults
)
const (
comQuit byte = iota + 1
comInitDB
comQuery
comFieldList
comCreateDB
comDropDB
comRefresh
comShutdown
comStatistics
comProcessInfo
comConnect
comProcessKill
comDebug
comPing
comTime
comDelayedInsert
comChangeUser
comBinlogDump
comTableDump
comConnectOut
comRegiserSlave
comStmtPrepare
comStmtExecute
comStmtSendLongData
comStmtClose
comStmtReset
comSetOption
comStmtFetch
)
const (
fieldTypeDecimal byte = iota
fieldTypeTiny
fieldTypeShort
fieldTypeLong
fieldTypeFloat
fieldTypeDouble
fieldTypeNULL
fieldTypeTimestamp
fieldTypeLongLong
fieldTypeInt24
fieldTypeDate
fieldTypeTime
fieldTypeDateTime
fieldTypeYear
fieldTypeNewDate
fieldTypeVarChar
fieldTypeBit
)
const (
fieldTypeNewDecimal byte = iota + 0xf6
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB
fieldTypeMediumBLOB
fieldTypeLongBLOB
fieldTypeBLOB
fieldTypeVarString
fieldTypeString
fieldTypeGeometry
)
type fieldFlag uint16
const (
flagNotNULL fieldFlag = 1 << iota
flagPriKey
flagUniqueKey
flagMultipleKey
flagBLOB
flagUnsigned
flagZeroFill
flagBinary
flagEnum
flagAutoIncrement
flagTimestamp
flagSet
flagUnknown1
flagUnknown2
flagUnknown3
flagUnknown4
)
const (
collation_ascii_general_ci byte = 11
collation_utf8_general_ci byte = 33
collation_utf8mb4_general_ci byte = 45
collation_utf8mb4_bin byte = 46
collation_latin1_general_ci byte = 48
collation_binary byte = 63
collation_utf8mb4_unicode_ci byte = 224
)
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
// See https://github.com/go-sql-driver/mysql#usage for details
package mysql
import (
"database/sql"
"database/sql/driver"
"net"
)
// This struct is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = parseDSN(dsn)
if err != nil {
return nil, err
}
// Connect to Server
nd := net.Dialer{Timeout: mc.cfg.timeout}
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
if err != nil {
return nil, err
}
mc.buf = newBuffer(mc.netConn)
// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
mc.Close()
return nil, err
}
// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
// Read Result Packet
err = mc.readResultOK()
if err != nil {
// Retry with old authentication method, if allowed
if mc.cfg.allowOldPasswords && err == errOldPassword {
if err = mc.writeOldAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err
}
}
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxPacketAllowed = stringToInt(maxap) - 1
if mc.maxPacketAllowed < maxPacketSize {
mc.maxWriteSize = mc.maxPacketAllowed
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"errors"
"fmt"
"io"
)
var (
errInvalidConn = errors.New("Invalid Connection")
errMalformPkt = errors.New("Malformed Packet")
errNoTLS = errors.New("TLS encryption requested but server does not support TLS")
errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
errPktSync = errors.New("Commands out of sync. You can't run this command now")
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
)
// error type which represents a single MySQL error
type MySQLError struct {
Number uint16
Message string
}
func (me *MySQLError) Error() string {
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
}
// error type which represents a group of one or more MySQL warnings
type MySQLWarnings []mysqlWarning
func (mws MySQLWarnings) Error() string {
var msg string
for i, warning := range mws {
if i > 0 {
msg += "\r\n"
}
msg += fmt.Sprintf("%s %s: %s", warning.Level, warning.Code, warning.Message)
}
return msg
}
// error type which represents a single MySQL warning
type mysqlWarning struct {
Level string
Code string
Message string
}
func (mc *mysqlConn) getWarnings() (err error) {
rows, err := mc.Query("SHOW WARNINGS", []driver.Value{})
if err != nil {
return
}
var warnings = MySQLWarnings{}
var values = make([]driver.Value, 3)
var warning mysqlWarning
var raw []byte
var ok bool
for {
err = rows.Next(values)
switch err {
case nil:
warning = mysqlWarning{}
if raw, ok = values[0].([]byte); ok {
warning.Level = string(raw)
} else {
warning.Level = fmt.Sprintf("%s", values[0])
}
if raw, ok = values[1].([]byte); ok {
warning.Code = string(raw)
} else {
warning.Code = fmt.Sprintf("%s", values[1])
}
if raw, ok = values[2].([]byte); ok {
warning.Message = string(raw)
} else {
warning.Message = fmt.Sprintf("%s", values[0])
}
warnings = append(warnings, warning)
case io.EOF:
return warnings
default:
rows.Close()
return
}
}
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"fmt"
"io"
"os"
"strings"
)
var (
fileRegister map[string]bool
readerRegister map[string]func() io.Reader
)
func init() {
fileRegister = make(map[string]bool)
readerRegister = make(map[string]func() io.Reader)
}
// RegisterLocalFile adds the given file to the file whitelist,
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
// Alternatively you can allow the use of all local files with
// the DSN parameter 'allowAllFiles=true'
//
// filePath := "/home/gopher/data.csv"
// mysql.RegisterLocalFile(filePath)
// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterLocalFile(filePath string) {
fileRegister[strings.Trim(filePath, `"`)] = true
}
// DeregisterLocalFile removes the given filepath from the whitelist.
func DeregisterLocalFile(filePath string) {
delete(fileRegister, strings.Trim(filePath, `"`))
}
// RegisterReaderHandler registers a handler function which is used
// to receive a io.Reader.
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
// If the handler returns a io.ReadCloser Close() is called when the
// request is finished.
//
// mysql.RegisterReaderHandler("data", func() io.Reader {
// var csvReader io.Reader // Some Reader that returns CSV data
// ... // Open Reader here
// return csvReader
// })
// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
// if err != nil {
// ...
//
func RegisterReaderHandler(name string, handler func() io.Reader) {
readerRegister[name] = handler
}
// DeregisterReaderHandler removes the ReaderHandler function with
// the given name from the registry.
func DeregisterReaderHandler(name string) {
delete(readerRegister, name)
}
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
data := make([]byte, 4+mc.maxWriteSize)
if strings.HasPrefix(name, "Reader::") { // io.Reader
name = name[8:]
handler, inMap := readerRegister[name]
if handler != nil {
rdr = handler()
}
if rdr == nil {
if !inMap {
err = fmt.Errorf("Reader '%s' is not registered", name)
} else {
err = fmt.Errorf("Reader '%s' is <nil>", name)
}
}
} else { // File
name = strings.Trim(name, `"`)
if mc.cfg.allowAllFiles || fileRegister[name] {
rdr, err = os.Open(name)
} else {
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
}
}
if rdc, ok := rdr.(io.ReadCloser); ok {
defer func() {
if err == nil {
err = rdc.Close()
} else {
rdc.Close()
}
}()
}
// send content packets
var ioErr error
if err == nil {
var n int
for err == nil && ioErr == nil {
n, err = rdr.Read(data[4:])
if n > 0 {
data[0] = byte(n)
data[1] = byte(n >> 8)
data[2] = byte(n >> 16)
data[3] = mc.sequence
ioErr = mc.writePacket(data[:4+n])
}
}
if err == io.EOF {
err = nil
}
if ioErr != nil {
errLog.Print(ioErr.Error())
return driver.ErrBadConn
}
}
// send empty packet (termination)
ioErr = mc.writePacket([]byte{
0x00,
0x00,
0x00,
mc.sequence,
})
if ioErr != nil {
errLog.Print(ioErr.Error())
return driver.ErrBadConn
}
// read OK packet
if err == nil {
return mc.readResultOK()
} else {
mc.readPacket()
}
return err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlResult struct {
affectedRows int64
insertId int64
}
func (res *mysqlResult) LastInsertId() (int64, error) {
return res.insertId, nil
}
func (res *mysqlResult) RowsAffected() (int64, error) {
return res.affectedRows, nil
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
"io"
)
type mysqlField struct {
fieldType byte
flags fieldFlag
name string
}
type mysqlRows struct {
mc *mysqlConn
columns []mysqlField
}
type binaryRows struct {
mysqlRows
}
type textRows struct {
mysqlRows
}
func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
for i := range columns {
columns[i] = rows.columns[i].name
}
return columns
}
func (rows *mysqlRows) Close() error {
mc := rows.mc
if mc == nil {
return nil
}
if mc.netConn == nil {
return errInvalidConn
}
// Remove unread packets from stream
err := mc.readUntilEOF()
rows.mc = nil
return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return errInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
if mc.netConn == nil {
return errInvalidConn
}
// Fetch next row from stream
if err := rows.readRow(dest); err != io.EOF {
return err
}
rows.mc = nil
}
return io.EOF
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"database/sql/driver"
)
type mysqlStmt struct {
mc *mysqlConn
id uint32
paramCount int
columns []mysqlField // cached from the first query
}
func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return driver.ErrBadConn
}
err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
stmt.mc = nil
return err
}
func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
mc.affectedRows = 0
mc.insertId = 0
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err == nil {
if resLen > 0 {
// Columns
err = mc.readUntilEOF()
if err != nil {
return nil, err
}
// Rows
err = mc.readUntilEOF()
}
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
insertId: int64(mc.insertId),
}, nil
}
}
return nil, err
}
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
if stmt.mc.netConn == nil {
errLog.Print(errInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := stmt.writeExecutePacket(args)
if err != nil {
return nil, err
}
mc := stmt.mc
// Read Result
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
rows := new(binaryRows)
rows.mc = mc
if resLen > 0 {
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {
rows.columns, err = mc.readColumns(resLen)
stmt.columns = rows.columns
} else {
rows.columns = stmt.columns
err = mc.readUntilEOF()
}
}
return rows, err
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
type mysqlTx struct {
mc *mysqlConn
}
func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return errInvalidConn
}
err = tx.mc.exec("COMMIT")
tx.mc = nil
return
}
func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.netConn == nil {
return errInvalidConn
}
err = tx.mc.exec("ROLLBACK")
tx.mc = nil
return
}
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"fmt"
"testing"
"time"
)
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
}
func TestDSNParser(t *testing.T) {
var cfg *config
var err error
var res string
for i, tst := range testDSNs {
cfg, err = parseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
res = fmt.Sprintf("%+v", cfg)
if res != fmt.Sprintf(tst.out, tst.loc) {
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := parseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := parseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
error bool
valid bool
time time.Time
}{
{tDate, false, true, tDate},
{sDate, false, true, tDate},
{[]byte(sDate), false, true, tDate},
{tDateTime, false, true, tDateTime},
{sDateTime, false, true, tDateTime},
{[]byte(sDateTime), false, true, tDateTime},
{tDate0, false, true, tDate0},
{sDate0, false, true, tDate0},
{[]byte(sDate0), false, true, tDate0},
{sDateTime0, false, true, tDate0},
{[]byte(sDateTime0), false, true, tDate0},
{"", true, false, tDate0},
{"1234", true, false, tDate0},
{0, true, false, tDate0},
}
var nt = NullTime{}
var err error
for _, tst := range scanTests {
err = nt.Scan(tst.in)
if (err != nil) != tst.error {
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
}
if nt.Valid != tst.valid {
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
}
if nt.Time != tst.time {
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
}
}
}
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
package gorm
import (
"errors"
"fmt"
"reflect"
)
type Association struct {
Scope *Scope
PrimaryKey interface{}
Column string
Error error
Field *Field
}
func (association *Association) setErr(err error) *Association {
if err != nil {
association.Error = err
}
return association
}
func (association *Association) Find(value interface{}) *Association {
association.Scope.related(value, association.Column)
return association.setErr(association.Scope.db.Error)
}
func (association *Association) Append(values ...interface{}) *Association {
scope := association.Scope
field := association.Field
for _, value := range values {
reflectvalue := reflect.Indirect(reflect.ValueOf(value))
if reflectvalue.Kind() == reflect.Struct {
field.Set(reflect.Append(field.Field, reflectvalue))
} else if reflectvalue.Kind() == reflect.Slice {
field.Set(reflect.AppendSlice(field.Field, reflectvalue))
} else {
association.setErr(errors.New("invalid association type"))
}
}
scope.Search.Select(association.Column)
scope.callCallbacks(scope.db.parent.callback.updates)
return association.setErr(scope.db.Error)
}
func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} {
primaryKeys := []interface{}{}
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
} else if reflectValue.Kind() == reflect.Struct {
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
}
return primaryKeys
}
func (association *Association) Delete(values ...interface{}) *Association {
primaryKeys := association.getPrimaryKeys(values...)
if len(primaryKeys) == 0 {
association.setErr(errors.New("no primary key found"))
} else {
scope := association.Scope
relationship := association.Field.Relationship
// many to many
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i)
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
var included = false
for _, primaryKey := range primaryKeys {
if equalAsString(primaryKey, primaryField.Field.Interface()) {
included = true
}
}
if !included {
leftValues = reflect.Append(leftValues, value)
}
}
}
association.Field.Set(leftValues)
}
} else {
association.setErr(errors.New("delete only support many to many"))
}
}
return association
}
func (association *Association) Replace(values ...interface{}) *Association {
relationship := association.Field.Relationship
scope := association.Scope
if relationship.Kind == "many_to_many" {
field := association.Field.Field
oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.Append(values...)
newPrimaryKeys := association.getPrimaryKeys(field.Interface())
var addedPrimaryKeys = []interface{}{}
for _, newKey := range newPrimaryKeys {
hasEqual := false
for _, oldKey := range oldPrimaryKeys {
if reflect.DeepEqual(newKey, oldKey) {
hasEqual = true
break
}
}
if !hasEqual {
addedPrimaryKeys = append(addedPrimaryKeys, newKey)
}
}
for _, primaryKey := range association.getPrimaryKeys(values...) {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
}
if len(addedPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
}
} else {
association.setErr(errors.New("replace only support many to many"))
}
return association
}
func (association *Association) Clear() *Association {
relationship := association.Field.Relationship
scope := association.Scope
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey)
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else {
association.setErr(err)
}
} else {
association.setErr(errors.New("clear only support many to many"))
}
return association
}
func (association *Association) Count() int {
count := -1
relationship := association.Field.Relationship
scope := association.Scope
newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
if relationship.PolymorphicType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
}
countScope.Count(&count)
} else if relationship.Kind == "belongs_to" {
if v, ok := scope.FieldByName(association.Column); ok {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count)
}
}
return count
}
package gorm_test
import (
"fmt"
"testing"
)
func TestHasOneAndHasManyAssociation(t *testing.T) {
DB.DropTable(Category{})
DB.DropTable(Post{})
DB.DropTable(Comment{})
DB.CreateTable(Category{})
DB.CreateTable(Post{})
DB.CreateTable(Comment{})
post := Post{
Title: "post 1",
Body: "body 1",
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
Category: Category{Name: "Category 1"},
MainCategory: Category{Name: "Main Category 1"},
}
if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post")
}
if DB.First(&Category{}, "name = ?", "Category 1").Error != nil {
t.Errorf("Category should be saved")
}
var p Post
DB.First(&p, post.Id)
if post.CategoryId.Int64 == 0 || p.CategoryId.Int64 == 0 || post.MainCategoryId == 0 || p.MainCategoryId == 0 {
t.Errorf("Category Id should exist")
}
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
t.Errorf("Comment 1 should be saved")
}
if post.Comments[0].PostId == 0 {
t.Errorf("Comment Should have post id")
}
var comment Comment
if DB.First(&comment, "content = ?", "Comment 2").Error != nil {
t.Errorf("Comment 2 should be saved")
}
if comment.PostId == 0 {
t.Errorf("Comment 2 Should have post id")
}
comment3 := Comment{Content: "Comment 3", Post: Post{Title: "Title 3", Body: "Body 3"}}
DB.Save(&comment3)
}
func TestRelated(t *testing.T) {
user := User{
Name: "jinzhu",
BillingAddress: Address{Address1: "Billing Address - Address 1"},
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
CreditCard: CreditCard{Number: "1234567890"},
Company: Company{Name: "company1"},
}
DB.Save(&user)
if user.CreditCard.ID == 0 {
t.Errorf("After user save, credit card should have id")
}
if user.BillingAddress.ID == 0 {
t.Errorf("After user save, billing address should have id")
}
if user.Emails[0].Id == 0 {
t.Errorf("After user save, billing address should have id")
}
var emails []Email
DB.Model(&user).Related(&emails)
if len(emails) != 2 {
t.Errorf("Should have two emails")
}
var emails2 []Email
DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
if len(emails2) != 1 {
t.Errorf("Should have two emails")
}
var user1 User
DB.Model(&user).Related(&user1.Emails)
if len(user1.Emails) != 2 {
t.Errorf("Should have only one email match related condition")
}
var address1 Address
DB.Model(&user).Related(&address1, "BillingAddressId")
if address1.Address1 != "Billing Address - Address 1" {
t.Errorf("Should get billing address from user correctly")
}
user1 = User{}
DB.Model(&address1).Related(&user1, "BillingAddressId")
if DB.NewRecord(user1) {
t.Errorf("Should get user from address correctly")
}
var user2 User
DB.Model(&emails[0]).Related(&user2)
if user2.Id != user.Id || user2.Name != user.Name {
t.Errorf("Should get user from email correctly")
}
var creditcard CreditCard
var user3 User
DB.First(&creditcard, "number = ?", "1234567890")
DB.Model(&creditcard).Related(&user3)
if user3.Id != user.Id || user3.Name != user.Name {
t.Errorf("Should get user from credit card correctly")
}
if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
t.Errorf("RecordNotFound for Related")
}
var company Company
if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
t.Errorf("RecordNotFound for Related")
}
}
func TestManyToMany(t *testing.T) {
DB.Raw("delete from languages")
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
user := User{Name: "Many2Many", Languages: languages}
DB.Save(&user)
// Query
var newLanguages []Language
DB.Model(&user).Related(&newLanguages, "Languages")
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Query many to many relations")
}
DB.Model(&user).Association("Languages").Find(&newLanguages)
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Should be able to find many to many relations")
}
if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
t.Errorf("Count should return correct result")
}
// Append
DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
t.Errorf("New record should be saved when append")
}
languageA := Language{Name: "AA"}
DB.Save(&languageA)
DB.Model(&User{Id: user.Id}).Association("Languages").Append(languageA)
languageC := Language{Name: "CC"}
DB.Save(&languageC)
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
t.Errorf("All appended languages should be saved")
}
// Delete
user.Languages = []Language{}
DB.Model(&user).Association("Languages").Find(&user.Languages)
var language Language
DB.Where("name = ?", "EE").First(&language)
DB.Model(&user).Association("Languages").Delete(language, &language)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
t.Errorf("Language EE should not be deleted")
}
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
user2 := User{Name: "Many2Many_User2", Languages: languages}
DB.Save(&user2)
DB.Model(&user).Association("Languages").Delete(languages, &languages)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
t.Errorf("Relations should be deleted with Delete")
}
if DB.Model(&user2).Association("Languages").Count() == 0 {
t.Errorf("Other user's relations should not be deleted")
}
// Replace
var languageB Language
DB.Where("name = ?", "BB").First(&languageB)
DB.Model(&user).Association("Languages").Replace(languageB)
if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
t.Errorf("Relations should be replaced")
}
DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
t.Errorf("Relations should be replaced")
}
// Clear
DB.Model(&user).Association("Languages").Clear()
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
t.Errorf("Relations should be cleared")
}
}
func TestForeignKey(t *testing.T) {
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
for _, foreignKey := range []string{"UserId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
for _, foreignKey := range []string{"PostId"} {
if structField.Name == foreignKey && !structField.IsForeignKey {
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
}
}
}
}
package gorm
import (
"fmt"
)
type callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*callbackProcessor
}
type callbackProcessor struct {
name string
before string
after string
replace bool
remove bool
typ string
processor *func(scope *Scope)
callback *callback
}
func (c *callback) addProcessor(typ string) *callbackProcessor {
cp := &callbackProcessor{typ: typ, callback: c}
c.processors = append(c.processors, cp)
return cp
}
func (c *callback) clone() *callback {
return &callback{
creates: c.creates,
updates: c.updates,
deletes: c.deletes,
queries: c.queries,
processors: c.processors,
}
}
func (c *callback) Create() *callbackProcessor {
return c.addProcessor("create")
}
func (c *callback) Update() *callbackProcessor {
return c.addProcessor("update")
}
func (c *callback) Delete() *callbackProcessor {
return c.addProcessor("delete")
}
func (c *callback) Query() *callbackProcessor {
return c.addProcessor("query")
}
func (c *callback) RowQuery() *callbackProcessor {
return c.addProcessor("row_query")
}
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
cp.before = name
return cp
}
func (cp *callbackProcessor) After(name string) *callbackProcessor {
cp.after = name
return cp
}
func (cp *callbackProcessor) Register(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.callback.sort()
}
func (cp *callbackProcessor) Remove(name string) {
fmt.Printf("[info] removing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
cp.remove = true
cp.callback.sort()
}
func (cp *callbackProcessor) Replace(name string, fc func(scope *Scope)) {
fmt.Printf("[info] replacing callback `%v` from %v\n", name, fileWithLineNum())
cp.name = name
cp.processor = &fc
cp.replace = true
cp.callback.sort()
}
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callbackProcessor)
var names, sortedNames = []string{}, []string{}
for _, cp := range cps {
if index := getRIndex(names, cp.name); index > -1 {
if !cp.replace && !cp.remove {
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
}
}
names = append(names, cp.name)
}
sortCallbackProcessor = func(c *callbackProcessor) {
if getRIndex(sortedNames, c.name) > -1 {
return
}
if len(c.before) > 0 {
if index := getRIndex(sortedNames, c.before); index > -1 {
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
} else if index := getRIndex(names, c.before); index > -1 {
sortedNames = append(sortedNames, c.name)
sortCallbackProcessor(cps[index])
} else {
sortedNames = append(sortedNames, c.name)
}
}
if len(c.after) > 0 {
if index := getRIndex(sortedNames, c.after); index > -1 {
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
} else if index := getRIndex(names, c.after); index > -1 {
cp := cps[index]
if len(cp.before) == 0 {
cp.before = c.name
}
sortCallbackProcessor(cp)
} else {
sortedNames = append(sortedNames, c.name)
}
}
if getRIndex(sortedNames, c.name) == -1 {
sortedNames = append(sortedNames, c.name)
}
}
for _, cp := range cps {
sortCallbackProcessor(cp)
}
var funcs = []*func(scope *Scope){}
var sortedFuncs = []*func(scope *Scope){}
for _, name := range sortedNames {
index := getRIndex(names, name)
if !cps[index].remove {
sortedFuncs = append(sortedFuncs, cps[index].processor)
}
}
for _, cp := range cps {
if sindex := getRIndex(sortedNames, cp.name); sindex == -1 {
if !cp.remove {
funcs = append(funcs, cp.processor)
}
}
}
return append(sortedFuncs, funcs...)
}
func (c *callback) sort() {
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
for _, processor := range c.processors {
switch processor.typ {
case "create":
creates = append(creates, processor)
case "update":
updates = append(updates, processor)
case "delete":
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
c.creates = sortProcessors(creates)
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
package gorm
import (
"fmt"
"strings"
)
func BeforeCreate(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeCreate")
}
func UpdateTimeStampWhenCreate(scope *Scope) {
if !scope.HasError() {
now := NowFunc()
scope.SetColumn("CreatedAt", now)
scope.SetColumn("UpdatedAt", now)
}
}
func Create(scope *Scope) {
defer scope.Trace(NowFunc())
if !scope.HasError() {
// set create sql
var sqls, columns []string
fields := scope.Fields()
for _, field := range fields {
if scope.changeableField(field) {
if field.IsNormal {
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
if !field.IsBlank || !field.HasDefaultValue {
columns = append(columns, scope.Quote(field.DBName))
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
}
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
}
}
}
}
returningKey := "*"
primaryField := scope.PrimaryField()
if primaryField != nil {
returningKey = scope.Quote(primaryField.DBName)
}
if len(columns) == 0 {
scope.Raw(fmt.Sprintf("INSERT INTO %v DEFAULT VALUES %v",
scope.QuotedTableName(),
scope.Dialect().ReturningStr(scope.TableName(), returningKey),
))
} else {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v) %v",
scope.QuotedTableName(),
strings.Join(columns, ","),
strings.Join(sqls, ","),
scope.Dialect().ReturningStr(scope.TableName(), returningKey),
))
}
// execute create sql
if scope.Dialect().SupportLastInsertId() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
}
}
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
}
}
}
}
func AfterCreate(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterCreate")
scope.CallMethodWithErrorCheck("AfterSave")
}
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Create().Register("gorm:before_create", BeforeCreate)
DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
DefaultCallback.Create().Register("gorm:create", Create)
DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm
import "fmt"
func BeforeDelete(scope *Scope) {
scope.CallMethodWithErrorCheck("BeforeDelete")
}
func Delete(scope *Scope) {
if !scope.HasError() {
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
scope.Raw(
fmt.Sprintf("UPDATE %v SET deleted_at=%v %v",
scope.QuotedTableName(),
scope.AddToVars(NowFunc()),
scope.CombinedConditionSql(),
))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.QuotedTableName(), scope.CombinedConditionSql()))
}
scope.Exec()
}
}
func AfterDelete(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterDelete")
}
func init() {
DefaultCallback.Delete().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Delete().Register("gorm:before_delete", BeforeDelete)
DefaultCallback.Delete().Register("gorm:delete", Delete)
DefaultCallback.Delete().Register("gorm:after_delete", AfterDelete)
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm
import (
"errors"
"fmt"
"reflect"
)
func Query(scope *Scope) {
defer scope.Trace(NowFunc())
var (
isSlice bool
isPtr bool
anyRecordFound bool
destType reflect.Type
)
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
}
}
var dest = scope.IndirectValue()
if value, ok := scope.Get("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value))
}
if kind := dest.Kind(); kind == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType))))
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()
}
} else if kind != reflect.Struct {
scope.Err(errors.New("unsupported destination, should be slice or struct"))
return
}
scope.prepareQuerySql()
if !scope.HasError() {
rows, err := scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
scope.db.RowsAffected = 0
if scope.Err(err) != nil {
return
}
defer rows.Close()
columns, _ := rows.Columns()
for rows.Next() {
scope.db.RowsAffected++
anyRecordFound = true
elem := dest
if isSlice {
elem = reflect.New(destType).Elem()
}
var values = make([]interface{}, len(columns))
fields := scope.New(elem.Addr().Interface()).Fields()
for index, column := range columns {
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
values[index] = field.Field.Addr().Interface()
} else {
values[index] = reflect.New(reflect.PtrTo(field.Field.Type())).Interface()
}
} else {
var value interface{}
values[index] = &value
}
}
scope.Err(rows.Scan(values...))
for index, column := range columns {
value := values[index]
if field, ok := fields[column]; ok {
if field.Field.Kind() == reflect.Ptr {
field.Field.Set(reflect.ValueOf(value).Elem())
} else if v := reflect.ValueOf(value).Elem().Elem(); v.IsValid() {
field.Field.Set(v)
}
}
}
if isSlice {
if isPtr {
dest.Set(reflect.Append(dest, elem.Addr()))
} else {
dest.Set(reflect.Append(dest, elem))
}
}
}
if !anyRecordFound && !isSlice {
scope.Err(RecordNotFound)
}
}
}
func AfterQuery(scope *Scope) {
scope.CallMethodWithErrorCheck("AfterFind")
}
func init() {
DefaultCallback.Query().Register("gorm:query", Query)
DefaultCallback.Query().Register("gorm:after_query", AfterQuery)
DefaultCallback.Query().Register("gorm:preload", Preload)
}
package gorm
import "reflect"
func BeginTransaction(scope *Scope) {
scope.Begin()
}
func CommitOrRollbackTransaction(scope *Scope) {
scope.CommitOrRollback()
}
func SaveBeforeAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if relationship.ForeignFieldName != "" {
scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue()))
}
}
}
}
}
func SaveAfterAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil &&
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
value := field.Field
switch value.Kind() {
case reflect.Slice:
for i := 0; i < value.Len(); i++ {
newDB := scope.NewDB()
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
}
}
default:
elem := value.Addr().Interface()
newScope := scope.New(elem)
if relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
}
if relationship.PolymorphicType != "" {
scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName()))
}
scope.Err(scope.NewDB().Save(elem).Error)
}
}
}
}
}
package gorm
import (
"reflect"
"runtime"
"strings"
"testing"
)
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
names = append(names, fnames[len(fnames)-1])
}
return reflect.DeepEqual(names, fnames)
}
func create(s *Scope) {}
func beforeCreate1(s *Scope) {}
func beforeCreate2(s *Scope) {}
func afterCreate1(s *Scope) {}
func afterCreate2(s *Scope) {}
func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("before_create2", beforeCreate2)
callback.Create().Register("create", create)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Register("after_create2", afterCreate2)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback")
}
}
func TestRegisterCallbackWithOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Create().Register("before_create1", beforeCreate1)
callback1.Create().Register("create", create)
callback1.Create().Register("after_create1", afterCreate1)
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Update().Register("create", create)
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
callback2.Update().Register("after_create2", afterCreate2)
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
t.Errorf("register callback with order")
}
}
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
var callback1 = &callback{processors: []*callbackProcessor{}}
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
callback1.Query().Register("before_create1", beforeCreate1)
callback1.Query().Register("after_create1", afterCreate1)
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
t.Errorf("register callback with order")
}
var callback2 = &callback{processors: []*callbackProcessor{}}
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
callback2.Delete().Register("after_create1", afterCreate1)
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
t.Errorf("register callback with order")
}
}
func replaceCreate(s *Scope) {}
func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Replace("create", replaceCreate)
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
t.Errorf("replace callback")
}
}
func TestRemoveCallback(t *testing.T) {
var callback = &callback{processors: []*callbackProcessor{}}
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
callback.Create().Register("before_create1", beforeCreate1)
callback.Create().Register("after_create1", afterCreate1)
callback.Create().Remove("create")
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
t.Errorf("remove callback")
}
}
package gorm
import (
"fmt"
"strings"
)
func AssignUpdateAttributes(scope *Scope) {
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
protected, ok := scope.Get("gorm:ignore_protected_attrs")
_, updateColumn := scope.Get("gorm:update_column")
updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
if updateColumn {
scope.InstanceSet("gorm:update_attrs", maps)
} else if len(updateAttrs) > 0 {
scope.InstanceSet("gorm:update_attrs", updateAttrs)
} else if !hasUpdate {
scope.SkipLeft()
return
}
}
}
}
func BeforeUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("BeforeSave")
scope.CallMethodWithErrorCheck("BeforeUpdate")
}
}
func UpdateTimeStampWhenUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.SetColumn("UpdatedAt", NowFunc())
}
}
func Update(scope *Scope) {
if !scope.HasError() {
var sqls []string
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
for key, value := range updateAttrs.(map[string]interface{}) {
if scope.changeableDBColumn(key) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
}
}
} else {
fields := scope.Fields()
for _, field := range fields {
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
if !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
if !relationField.IsBlank {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())))
}
}
}
}
}
if len(sqls) > 0 {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
}
}
}
func AfterUpdate(scope *Scope) {
if _, ok := scope.Get("gorm:update_column"); !ok {
scope.CallMethodWithErrorCheck("AfterUpdate")
scope.CallMethodWithErrorCheck("AfterSave")
}
}
func init() {
DefaultCallback.Update().Register("gorm:assign_update_attributes", AssignUpdateAttributes)
DefaultCallback.Update().Register("gorm:begin_transaction", BeginTransaction)
DefaultCallback.Update().Register("gorm:before_update", BeforeUpdate)
DefaultCallback.Update().Register("gorm:save_before_associations", SaveBeforeAssociations)
DefaultCallback.Update().Register("gorm:update_time_stamp_when_update", UpdateTimeStampWhenUpdate)
DefaultCallback.Update().Register("gorm:update", Update)
DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
package gorm_test
import (
"errors"
"github.com/jinzhu/gorm"
"reflect"
"testing"
)
func (s *Product) BeforeCreate() (err error) {
if s.Code == "Invalid" {
err = errors.New("invalid product")
}
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
return
}
func (s *Product) BeforeUpdate() (err error) {
if s.Code == "dont_update" {
err = errors.New("can't update")
}
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
return
}
func (s *Product) BeforeSave() (err error) {
if s.Code == "dont_save" {
err = errors.New("can't save")
}
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
return
}
func (s *Product) AfterFind() {
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
}
func (s *Product) AfterCreate(tx *gorm.DB) {
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
}
func (s *Product) AfterUpdate() {
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
}
func (s *Product) AfterSave() (err error) {
if s.Code == "after_save_error" {
err = errors.New("can't save")
}
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
return
}
func (s *Product) BeforeDelete() (err error) {
if s.Code == "dont_delete" {
err = errors.New("can't delete")
}
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
return
}
func (s *Product) AfterDelete() (err error) {
if s.Code == "after_delete_error" {
err = errors.New("can't delete")
}
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
return
}
func (s *Product) GetCallTimes() []int64 {
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
}
func TestRunCallbacks(t *testing.T) {
p := Product{Code: "unique_code", Price: 100}
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
}
p.Price = 200
DB.Save(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
}
var products []Product
DB.Find(&products, "code = ?", "unique_code")
if products[0].AfterFindCallTimes != 2 {
t.Errorf("AfterFind callbacks should work with slice")
}
DB.Where("Code = ?", "unique_code").First(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
}
DB.Delete(&p)
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
}
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
t.Errorf("Can't find a deleted record")
}
}
func TestCallbacksWithErrors(t *testing.T) {
p := Product{Code: "Invalid", Price: 100}
if DB.Save(&p).Error == nil {
t.Errorf("An error from before create callbacks happened when create with invalid value")
}
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
t.Errorf("Should not save record that have errors")
}
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
t.Errorf("An error from after create callbacks happened when create with invalid value")
}
p2 := Product{Code: "update_callback", Price: 100}
DB.Save(&p2)
p2.Code = "dont_update"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before update callbacks happened when update with invalid value")
}
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
t.Errorf("Record Should not be updated due to errors happened in before update callback")
}
p2.Code = "dont_save"
if DB.Save(&p2).Error == nil {
t.Errorf("An error from before save callbacks happened when update with invalid value")
}
p3 := Product{Code: "dont_delete", Price: 100}
DB.Save(&p3)
if DB.Delete(&p3).Error == nil {
t.Errorf("An error from before delete callbacks happened when delete")
}
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
t.Errorf("An error from before delete callbacks happened")
}
p4 := Product{Code: "after_save_error", Price: 100}
DB.Save(&p4)
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
t.Errorf("Record should be reverted if get an error in after save callback")
}
p5 := Product{Code: "after_delete_error", Price: 100}
DB.Save(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record should be found")
}
DB.Delete(&p5)
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
}
}
package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type commonDialect struct{}
func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (commonDialect) HasTop() bool {
return false
}
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("VARCHAR(%d)", size)
}
return "VARCHAR(65532)"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size)
}
return "BINARY(65532)"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}
package gorm_test
import (
"reflect"
"testing"
"time"
)
func TestCreate(t *testing.T) {
float := 35.03554004971999
user := User{Name: "CreateUser", Age: 18, Birthday: time.Now(), UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
t.Error("User should be new record before create")
}
if count := DB.Save(&user).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
if DB.NewRecord(user) || DB.NewRecord(&user) {
t.Error("User should not new record after save")
}
var newUser User
DB.First(&newUser, user.Id)
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
t.Errorf("User's PasswordHash should be saved ([]byte)")
}
if newUser.Age != 18 {
t.Errorf("User's Age should be saved (int)")
}
if newUser.UserNum != Num(111) {
t.Errorf("User's UserNum should be saved (custom type)")
}
if newUser.Latitude != float {
t.Errorf("Float64 should not be changed after save")
}
if user.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
if newUser.CreatedAt.IsZero() {
t.Errorf("Should have created_at after create")
}
DB.Model(user).Update("name", "create_user_new_name")
DB.First(&user, user.Id)
if user.CreatedAt != newUser.CreatedAt {
t.Errorf("CreatedAt should not be changed after update")
}
}
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
jt := JoinTable{From: 1, To: 2}
err := DB.Create(&jt).Error
if err != nil {
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
}
}
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
animal := Animal{Name: "Ferdinand"}
if DB.Save(&animal).Error != nil {
t.Errorf("No error should happen when create a record without std primary key")
}
if animal.Counter == 0 {
t.Errorf("No std primary key should be filled value after create")
}
if animal.Name != "Ferdinand" {
t.Errorf("Default value should be overrided")
}
// Test create with default value not overrided
an := Animal{From: "nerdz"}
if DB.Save(&an).Error != nil {
t.Errorf("No error should happen when create an record without std primary key")
}
// We must fetch the value again, to have the default fields updated
// (We can't do this in the update statements, since sql default can be expressions
// And be different from the fields' type (eg. a time.Time fiels has a default value of "now()"
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
if an.Name != "galeone" {
t.Errorf("Default value should fill the field. But got %v", an.Name)
}
}
func TestAnonymousScanner(t *testing.T) {
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_scanner")
if user2.Role.Name != "admin" {
t.Errorf("Should be able to get anonymous scanner")
}
if !user2.IsAdmin() {
t.Errorf("Should be able to get anonymous scanner")
}
}
func TestAnonymousField(t *testing.T) {
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
DB.Save(&user)
var user2 User
DB.First(&user2, "name = ?", "anonymous_field")
DB.Model(&user2).Related(&user2.Company)
if user2.Company.Name != "company" {
t.Errorf("Should be able to get anonymous field")
}
}
func TestSelectWithCreate(t *testing.T) {
user := getPreparedUser("select_user", "select_with_create")
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name != user.Name || queryuser.Age == user.Age {
t.Errorf("Should only create users with name column")
}
if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
t.Errorf("Should only create selected relationships")
}
}
func TestOmitWithCreate(t *testing.T) {
user := getPreparedUser("omit_user", "omit_with_create")
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
var queryuser User
DB.Preload("BillingAddress").Preload("ShippingAddress").
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
if queryuser.Name == user.Name || queryuser.Age != user.Age {
t.Errorf("Should only create users with age column")
}
if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
t.Errorf("Should not create omited relationships")
}
}
package gorm_test
import (
"testing"
"time"
)
type CustomizeColumn struct {
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
Name string `gorm:"column:mapped_name"`
Date time.Time `gorm:"column:mapped_time"`
}
// Make sure an ignored field does not interfere with another field's custom
// column name that matches the ignored field.
type CustomColumnAndIgnoredFieldClash struct {
Body string `sql:"-"`
RawBody string `gorm:"column:body"`
}
func TestCustomizeColumn(t *testing.T) {
col := "mapped_name"
DB.DropTable(&CustomizeColumn{})
DB.AutoMigrate(&CustomizeColumn{})
scope := DB.NewScope(&CustomizeColumn{})
if !scope.Dialect().HasColumn(scope, scope.TableName(), col) {
t.Errorf("CustomizeColumn should have column %s", col)
}
col = "mapped_id"
if scope.PrimaryKey() != col {
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
}
expected := "foo"
cc := CustomizeColumn{ID: 666, Name: expected, Date: time.Now()}
if count := DB.Create(&cc).RowsAffected; count != 1 {
t.Error("There should be one record be affected when create record")
}
var cc1 CustomizeColumn
DB.First(&cc1, 666)
if cc1.Name != expected {
t.Errorf("Failed to query CustomizeColumn")
}
cc.Name = "bar"
DB.Save(&cc)
var cc2 CustomizeColumn
DB.First(&cc2, 666)
if cc2.Name != "bar" {
t.Errorf("Failed to query CustomizeColumn")
}
}
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
t.Errorf("Should not raise error: %s", err)
}
}
package gorm_test
import (
"testing"
"time"
)
func TestDelete(t *testing.T) {
user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&user1).Error != nil {
t.Errorf("No error should happen when delete a record")
}
if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("Other users that not deleted should be found-able")
}
}
func TestInlineDelete(t *testing.T) {
user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
DB.Save(&user1)
DB.Save(&user2)
if DB.Delete(&User{}, user1.Id).Error != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
if DB.Delete(&User{}, "name = ?", user2.Name).Error != nil {
t.Errorf("No error should happen when delete a record")
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
t.Errorf("User can't be found after delete")
}
}
func TestSoftDelete(t *testing.T) {
type User struct {
Id int64
Name string
DeletedAt time.Time
}
DB.AutoMigrate(&User{})
user := User{Name: "soft_delete"}
DB.Save(&user)
DB.Delete(&user)
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
t.Errorf("Can't find a soft deleted record")
}
if DB.Unscoped().First(&User{}, "name = ?", user.Name).Error != nil {
t.Errorf("Should be able to find soft deleted record with Unscoped")
}
DB.Unscoped().Delete(&user)
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
t.Errorf("Can't find permanently deleted record")
}
}
package gorm
import (
"fmt"
"reflect"
)
type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
SelectFromDummyTable() string
Quote(key string) string
HasTable(scope *Scope, tableName string) bool
HasColumn(scope *Scope, tableName string, columnName string) bool
HasIndex(scope *Scope, tableName string, indexName string) bool
RemoveIndex(scope *Scope, indexName string)
}
func NewDialect(driver string) Dialect {
var d Dialect
switch driver {
case "foundation":
d = &foundation{}
case "mysql":
d = &mysql{}
case "sqlite3":
d = &sqlite3{}
case "mssql":
d = &mssql{}
default:
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", driver)
d = &commonDialect{}
}
return d
}
package gorm_test
import "testing"
type BasePost struct {
Id int64
Title string
URL string
}
type HNPost struct {
BasePost
Upvotes int32
}
type EngadgetPost struct {
BasePost BasePost `gorm:"embedded"`
ImageUrl string
}
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
var news HNPost
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if news.Title != "hn_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
var egNews EngadgetPost
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
} else if egNews.BasePost.Title != "engadget_news" {
t.Errorf("embedded struct's value should be scanned correctly")
}
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
t.Errorf("primary key with embedded struct should works")
}
for _, field := range DB.NewScope(&HNPost{}).Fields() {
if field.Name == "BasePost" {
t.Errorf("scope Fields should not contain embedded struct")
}
}
}
package gorm
import "errors"
var (
RecordNotFound = errors.New("record not found")
InvalidSql = errors.New("invalid sql")
NoNewAttrs = errors.New("no new attributes")
NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("can't start transaction")
)
package gorm
import (
"database/sql"
"errors"
"reflect"
)
type Field struct {
*StructField
IsBlank bool
Field reflect.Value
}
func (field *Field) Set(value interface{}) error {
if !field.Field.IsValid() {
return errors.New("field value not valid")
}
if !field.Field.CanAddr() {
return errors.New("unaddressable value")
}
if rvalue, ok := value.(reflect.Value); ok {
value = rvalue.Interface()
}
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
if err := scanner.Scan(v.Interface()); err != nil {
return err
}
} else {
if err := scanner.Scan(value); err != nil {
return err
}
}
} else {
reflectValue, ok := value.(reflect.Value)
if !ok {
reflectValue = reflect.ValueOf(value)
}
if reflectValue.Type().ConvertibleTo(field.Field.Type()) {
field.Field.Set(reflectValue.Convert(field.Field.Type()))
} else {
return errors.New("could not convert argument")
}
}
field.IsBlank = isBlank(field.Field)
return nil
}
// Fields get value's fields
func (scope *Scope) Fields() map[string]*Field {
if scope.fields == nil {
fields := map[string]*Field{}
structFields := scope.GetStructFields()
indirectValue := scope.IndirectValue()
isStruct := indirectValue.Kind() == reflect.Struct
for _, structField := range structFields {
if isStruct {
fields[structField.DBName] = getField(indirectValue, structField)
} else {
fields[structField.DBName] = &Field{StructField: structField, IsBlank: true}
}
}
scope.fields = fields
}
return scope.fields
}
func getField(indirectValue reflect.Value, structField *StructField) *Field {
field := &Field{StructField: structField}
for _, name := range structField.Names {
indirectValue = reflect.Indirect(indirectValue).FieldByName(name)
}
field.Field = indirectValue
field.IsBlank = isBlank(indirectValue)
return field
}
package gorm
import (
"fmt"
"reflect"
"time"
)
type foundation struct {
commonDialect
}
func (foundation) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (foundation) SupportLastInsertId() bool {
return false
}
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "clob"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
}
func (f foundation) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key)
}
func (foundation) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count)
return count > 0
}
func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (f foundation) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName)))
}
func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
package gorm
import "database/sql"
type sqlCommon interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type sqlDb interface {
Begin() (*sql.Tx, error)
}
type sqlTx interface {
Commit() error
Rollback() error
}
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey
}
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
}
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
}
}
return values
}
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
assignColumns = append(assignColumns, key)
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
for _, value := range values {
values = append(values, value)
}
quotedTable := handler.Table(db)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(),
quotedTable,
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
}
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var conditions []string
var values []interface{}
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
values = append(values, value)
}
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
quotedTable := handler.Table(db)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
var joinConditions []string
var queryConditions []string
var values []interface{}
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Source.ForeignKeys {
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...)
} else {
db.Error = errors.New("wrong source type for join table handler")
return db
}
}
This diff is collapsed. Click to expand it.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment