From bf12e35ec9f3976bbe85561e60c8b2ee08988226 Mon Sep 17 00:00:00 2001 From: Kevin Chabowski Date: Wed, 28 Aug 2013 15:52:03 +0200 Subject: Added MySQL implementation of database model. --- model/mysql/jobs.go | 202 +++++++++++++++++++++++++++++++++++++++++++++++++ model/mysql/mysql.go | 84 ++++++++++++++++++++ model/mysql/queries.go | 51 +++++++++++++ model/mysql/users.go | 181 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 518 insertions(+) create mode 100644 model/mysql/jobs.go create mode 100644 model/mysql/mysql.go create mode 100644 model/mysql/queries.go create mode 100644 model/mysql/users.go diff --git a/model/mysql/jobs.go b/model/mysql/jobs.go new file mode 100644 index 0000000..d5abb1b --- /dev/null +++ b/model/mysql/jobs.go @@ -0,0 +1,202 @@ +package mysql + +import ( + "database/sql" + "fmt" + "kch42.de/gostuff/mailremind/chronos" + "kch42.de/gostuff/mailremind/model" + "log" + "time" +) + +type Job struct { + con *MySQLDBCon + + id DBID + user DBID + subject string + content []byte + next time.Time + chron []chronos.Chronos +} + +func jobFromSQL(con *MySQLDBCon, s scanner) (*Job, error) { + var _id, _user uint64 + var subject string + var content []byte + var _next int64 + var _mchron string + + if err := s.Scan(&_id, &_user, &subject, &content, &_next, &_mchron); err != nil { + return nil, err + } + + chron, err := chronos.ParseMultiChronos(_mchron) + if err != nil { + return nil, err + } + + return &Job{ + con: con, + id: DBID(_id), + user: DBID(_user), + subject: subject, + content: content, + next: time.Unix(_next, 0), + chron: chron, + }, nil +} + +func (u *User) CountJobs() (c int) { + row := u.con.stmt[qCountJobs].QueryRow(uint64(u.id)) + if err := row.Scan(&c); err != nil { + log.Printf("Failed counting user's (%d) jobs: %s", u.id, err) + c = 0 + } + return +} + +func (u *User) Jobs() []model.Job { + rows, err := u.con.stmt[qJobsOfUser].Query(uint64(u.id)) + if err != nil { + log.Printf("Failed getting jobs of user %d: %s", u.id, err) + return nil + } + + jobs := make([]model.Job, 0) + for rows.Next() { + job, err := jobFromSQL(u.con, rows) + if err != nil { + log.Printf("Failed getting all jobs of user %d: %s", u.id, err) + break + } + jobs = append(jobs, job) + } + + return jobs +} + +func (u *User) JobByID(_id model.DBID) (model.Job, error) { + id := _id.(DBID) + + row := u.con.stmt[qJobFromUserAndID].QueryRow(uint64(u.id), uint64(id)) + switch job, err := jobFromSQL(u.con, row); err { + case nil: + return job, nil + case sql.ErrNoRows: + return nil, model.NotFound + default: + return nil, err + } +} + +func (u *User) AddJob(subject string, content []byte, chron chronos.MultiChronos, next time.Time) (model.Job, error) { + tx, err := u.con.con.Begin() + if err != nil { + return nil, err + } + + insjob := tx.Stmt(u.con.stmt[qInsertJob]) + + res, err := insjob.Exec(uint64(u.id), subject, content, next.Unix(), chron.String()) + if err != nil { + tx.Rollback() + return nil, err + } + + _id, err := res.LastInsertId() + if err != nil { + tx.Rollback() + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return &Job{ + con: u.con, + id: DBID(_id), + user: u.id, + subject: subject, + content: content, + next: next, + chron: chron, + }, nil +} + +func (j *Job) ID() model.DBID { return j.id } +func (j *Job) Subject() string { return j.subject } +func (j *Job) Content() []byte { return j.content } +func (j *Job) Chronos() chronos.MultiChronos { return j.chron } +func (j *Job) Next() time.Time { return j.next } + +func (j *Job) User() model.User { + u, err := j.con.UserByID(j.user) + if err != nil { + // TODO: Should we really panic here? If yes, we need to recover panics! + panic(fmt.Errorf("Could not get user (%d) of Job %d: %s", j.user, j.id, err)) + } + + return u +} + +func (j *Job) SetSubject(sub string) error { + if _, err := j.con.stmt[qSetSubject].Exec(sub, uint64(j.id)); err != nil { + return err + } + + j.subject = sub + return nil +} + +func (j *Job) SetContent(cont []byte) error { + if _, err := j.con.stmt[qSetContent].Exec(cont, uint64(j.id)); err != nil { + return err + } + + j.content = cont + return nil +} + +func (j *Job) SetChronos(chron chronos.MultiChronos) error { + if _, err := j.con.stmt[qSetChronos].Exec(chron.String(), uint64(j.id)); err != nil { + return err + } + + j.chron = chron + return nil +} + +func (j *Job) SetNext(next time.Time) error { + if _, err := j.con.stmt[qSetNext].Exec(next.Unix(), uint64(j.id)); err != nil { + return err + } + + j.next = next + return nil +} + +func (j *Job) Delete() error { + _, err := j.con.stmt[qDelJob].Exec(j.id) + return err +} + +func (con *MySQLDBCon) JobsBefore(t time.Time) []model.DBID { + rows, err := con.stmt[qJobsBefore].Query(t.Unix()) + if err != nil { + log.Fatalf("Could not get jobs before %s: %s", t, err) // TODO: Really fatal? + } + + ids := make([]model.DBID, 0) + for rows.Next() { + var _id uint64 + if err := rows.Scan(&_id); err != nil { + log.Printf("Could not get all jobs before %s: %s", t, err) + break + } + ids = append(ids, DBID(_id)) + } + + return ids +} diff --git a/model/mysql/mysql.go b/model/mysql/mysql.go new file mode 100644 index 0000000..08c55bc --- /dev/null +++ b/model/mysql/mysql.go @@ -0,0 +1,84 @@ +package mysql + +import ( + "database/sql" + "fmt" + _ "github.com/go-sql-driver/mysql" + "kch42.de/gostuff/mailremind/model" + "strconv" +) + +type scanner interface { + Scan(dest ...interface{}) error +} + +type DBID uint64 + +func (id DBID) String() string { + return strconv.FormatUint(uint64(id), 16) +} + +func parseDBID(s string) (model.DBID, error) { + _id, err := strconv.ParseUint(s, 16, 64) + return DBID(_id), err +} + +type MySQLDBCon struct { + con *sql.DB + stmt []*sql.Stmt +} + +func connect(dbconf string) (model.DBCon, error) { + con, err := sql.Open("mysql", dbconf) + if err != nil { + return nil, err + } + + dbc := &MySQLDBCon{ + con: con, + stmt: make([]*sql.Stmt, qEnd), + } + + for i := 0; i < qEnd; i++ { + stmt, err := con.Prepare(queries[i]) + if err != nil { + con.Close() + return nil, fmt.Errorf("Failed to prepare statement %d : <%s>: %s", i, queries[i], err) + } + dbc.stmt[i] = stmt + } + + return dbc, nil +} + +func init() { + model.Register("mysql", model.DBInfo{ + Connect: connect, + ParseDBID: parseDBID, + }) +} + +func (con *MySQLDBCon) Close() { + con.con.Close() +} + +func rollbackAfterFail(err error, tx *sql.Tx) error { + if rberr := tx.Rollback(); rberr != nil { + return fmt.Errorf("Rollback error: <%s>, Original error: %s", rberr, err) + } + return err +} + +func i2b(i int) bool { + if i == 0 { + return false + } + return true +} + +func b2i(b bool) int { + if b { + return 1 + } + return 0 +} diff --git a/model/mysql/queries.go b/model/mysql/queries.go new file mode 100644 index 0000000..da2a554 --- /dev/null +++ b/model/mysql/queries.go @@ -0,0 +1,51 @@ +package mysql + +const ( + qUserByID = iota + qUserByEmail + qSetPWHash + qSetActive + qSetAcCode + qDelUsersJobs + qDelUser + qGetOldInactiveUsers + qCountJobs + qJobsOfUser + qJobFromUserAndID + qSetSubject + qSetContent + qSetNext + qDelJob + qJobsBefore + qInsertJob + qInsertUser + qSetChronos + qEnd +) + +const ( + qfragSelUser = "SELECT `id`, `email`, `passwd`, `location`, `active`, `activationcode`, `added` FROM `users` " + qfragSelJob = "SELECT `id`, `user`, `subject`, `content`, `next`, `chronos` FROM `jobs` " +) + +var queries = map[int]string{ + qUserByID: qfragSelUser + "WHERE `id` = ?", + qUserByEmail: qfragSelUser + "WHERE `email` = ?", + qSetPWHash: "UPDATE `users` SET `passwd` = ? WHERE `id` = ?", + qSetActive: "UPDATE `users` SET `active` = ? WHERE `id` = ?", + qSetAcCode: "UPDATE `users` SET `activationcode` = ? WHERE `id` = ?", + qDelUsersJobs: "DELETE FROM `jobs` WHERE `user` = ?", + qDelUser: "DELETE FROM `users` WHERE `id` = ?", + qGetOldInactiveUsers: "SELECT `id` FROM `users` WHERE `active` = 0 AND `added` < ?", + qCountJobs: "SELECT COUNT(*) FROM `jobs` WHERE `user` = ?", + qJobsOfUser: qfragSelJob + "WHERE `user` = ?", + qJobFromUserAndID: qfragSelJob + "WHERE `user` = ? AND `id` = ?", + qSetSubject: "UPDATE `jobs` SET `subject` = ? WHERE `id` = ?", + qSetContent: "UPDATE `jobs` SET `content` = ? WHERE `id` = ?", + qSetNext: "UPDATE `jobs` SET `next` = ? WHERE `id` = ?", + qDelJob: "DELETE FROM `jobs` WHERE `id` = ?", + qJobsBefore: "SELECT `id` FROM `jobs` WHERE `next` <= ?", + qInsertJob: "INSERT INTO `jobs` (`user`, `subject`, `content`, `next`, `chronos`) VALUES (?, ?, ?, ?, ?)", + qInsertUser: "INSERT INTO `users` (`email`, `passwd`, `location`, `active`, `activationcode`, `added`) VALUES (?, ?, ?, ?, ?, ?)", + qSetChronos: "UPDATE `jobs` SET `chronos` = ? WHERE `id` = ?", +} diff --git a/model/mysql/users.go b/model/mysql/users.go new file mode 100644 index 0000000..f07f472 --- /dev/null +++ b/model/mysql/users.go @@ -0,0 +1,181 @@ +package mysql + +import ( + "database/sql" + "kch42.de/gostuff/mailremind/model" + "log" + "time" +) + +type User struct { + con *MySQLDBCon + + id DBID + email, passwd, acCode string + location *time.Location + added time.Time + active bool +} + +func userFromSQL(con *MySQLDBCon, s scanner) (*User, error) { + var id uint64 + var added int64 + var email, passwd, _loc, acCode string + var active int + + switch err := s.Scan(&id, &email, &passwd, &_loc, &active, &acCode, &added); err { + case nil: + case sql.ErrNoRows: + return nil, model.NotFound + default: + return nil, err + } + + user := &User{ + con: con, + id: DBID(id), + email: email, + passwd: passwd, + acCode: acCode, + added: time.Unix(added, 0), + active: i2b(active), + } + + loc, err := time.LoadLocation(_loc) + if err != nil { + loc = time.UTC + } + user.location = loc + + return user, nil +} + +func (con *MySQLDBCon) UserByID(_id model.DBID) (model.User, error) { + id := _id.(DBID) + + row := con.stmt[qUserByID].QueryRow(uint64(id)) + return userFromSQL(con, row) +} + +func (con *MySQLDBCon) UserByMail(email string) (model.User, error) { + row := con.stmt[qUserByEmail].QueryRow(email) + return userFromSQL(con, row) +} + +func (u *User) ID() model.DBID { return u.id } +func (u *User) Email() string { return u.email } +func (u *User) PWHash() []byte { return []byte(u.PWHash()) } +func (u *User) Active() bool { return u.active } +func (u *User) ActivationCode() string { return u.acCode } + +func (u *User) SetPWHash(_pwhash []byte) error { + pwhash := string(_pwhash) + + if _, err := u.con.stmt[qSetPWHash].Query(pwhash, uint64(u.id)); err != nil { + return err + } + + u.passwd = string(_pwhash) + return nil +} + +func (u *User) SetActive(b bool) error { + if _, err := u.con.stmt[qSetActive].Query(b2i(b), uint64(u.id)); err != nil { + return err + } + + u.active = b + return nil +} + +func (u *User) SetActivationCode(c string) error { + if _, err := u.con.stmt[qSetAcCode].Query(c, uint64(u.id)); err != nil { + return err + } + + u.acCode = c + return nil +} + +func (u *User) Delete() error { + tx, err := u.con.con.Begin() + if err != nil { + return err + } + + id := uint64(u.id) + + deljobs := tx.Stmt(u.con.stmt[qDelUsersJobs]) + deluser := tx.Stmt(u.con.stmt[qDelUser]) + + if _, err := deljobs.Query(id); err != nil { + return rollbackAfterFail(err, tx) + } + + if _, err := deluser.Query(id); err != nil { + return rollbackAfterFail(err, tx) + } + + return tx.Commit() +} + +func (con *MySQLDBCon) InactiveUsers(olderthan time.Time) []model.DBID { + ids := make([]model.DBID, 0) + + rows, err := con.stmt[qGetOldInactiveUsers].Query(olderthan.Unix()) + if err != nil { + log.Printf("Failed to get old, inactive users: %s", err) + return ids + } + + for rows.Next() { + var _id uint64 + + if err := rows.Scan(&_id); err != nil { + log.Printf("Failed to get old, inactive users: %s", err) + return ids + } + + ids = append(ids, DBID(_id)) + } + + return ids +} + +func (con *MySQLDBCon) AddUser(email string, pwhash []byte, location *time.Location, active bool, acCode string) (model.User, error) { + now := time.Now() + + tx, err := con.con.Begin() + if err != nil { + return nil, err + } + + insjob := tx.Stmt(con.stmt[qInsertUser]) + + res, err := insjob.Exec(email, string(pwhash), location.String(), b2i(active), acCode, now.Unix()) + if err != nil { + tx.Rollback() + return nil, err + } + + _id, err := res.LastInsertId() + if err != nil { + tx.Rollback() + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return &User{ + con: con, + id: DBID(_id), + email: email, + passwd: string(pwhash), + acCode: acCode, + location: location, + added: now, + active: active, + }, nil +} -- cgit v1.2.3-54-g00ecf