getting started
This commit is contained in:
parent
0ba2ba3db1
commit
991950bdd4
428
mvwa/sqlite/database.go
Normal file
428
mvwa/sqlite/database.go
Normal file
@ -0,0 +1,428 @@
|
||||
package sqlite // name the package as you see fit
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// This is the data type to exchange data with the database
|
||||
type Record = map[string]any
|
||||
|
||||
type Database struct {
|
||||
databaseName string
|
||||
database *sql.DB
|
||||
}
|
||||
|
||||
type Transaction struct {
|
||||
tx *sql.Tx
|
||||
err error
|
||||
}
|
||||
|
||||
type Action func(tx *sql.Tx) error
|
||||
|
||||
func New(DBName string) (*Database, error) {
|
||||
return &Database{databaseName: DBName}, nil
|
||||
}
|
||||
|
||||
func (d *Database) Close() error {
|
||||
return d.database.Close()
|
||||
}
|
||||
|
||||
// provides access to the internal database object
|
||||
func (d *Database) DB() *sql.DB {
|
||||
return d.database
|
||||
}
|
||||
|
||||
func (d *Database) Name() string {
|
||||
return d.databaseName
|
||||
}
|
||||
|
||||
func (d *Database) Open() (err error) {
|
||||
d.database, err = openSqliteDB(d.databaseName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *Database) OpenInMemory() (err error) {
|
||||
d.database, err = sql.Open("sqlite", ":memory:")
|
||||
return err
|
||||
}
|
||||
|
||||
func openSqliteDB(databasefilename string) (*sql.DB, error) {
|
||||
|
||||
_, err := os.Stat(databasefilename)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return createDB(databasefilename)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.Open("sqlite", databasefilename)
|
||||
|
||||
}
|
||||
|
||||
func createDB(dbfileName string) (*sql.DB, error) {
|
||||
|
||||
query := `
|
||||
PRAGMA page_size = 4096;
|
||||
PRAGMA synchronous = off;
|
||||
PRAGMA foreign_keys = off;
|
||||
PRAGMA journal_mode = WAL;
|
||||
PRAGMA user_version = 1;
|
||||
`
|
||||
db, err := sql.Open("sqlite", dbfileName)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
}
|
||||
_, err = db.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (d *Database) TableList() (result []Record, err error) {
|
||||
return d.ReadRecords("select name from sqlite_master where type='table';")
|
||||
}
|
||||
|
||||
func (d *Database) ReadTable(tablename string) (result []Record, err error) {
|
||||
|
||||
return d.ReadRecords(fmt.Sprintf("select * from '%s';", tablename))
|
||||
}
|
||||
|
||||
func (d *Database) ReadRecords(query string, args ...any) (result []Record, err error) {
|
||||
|
||||
rows, err := d.DB().Query(query, args...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return Rows2records(rows)
|
||||
|
||||
}
|
||||
|
||||
func (d *Database) GetRecord(tablename string, idfield string, key any) (result Record, err error) {
|
||||
|
||||
query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield)
|
||||
res, err := d.DB().Query(query, key)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer res.Close()
|
||||
return Rows2record(res)
|
||||
|
||||
}
|
||||
|
||||
func (d *Database) UpsertRecord(tablename string, idfield string, record Record) (result Record, err error) {
|
||||
|
||||
return upsert(d.DB(), tablename, idfield, record)
|
||||
|
||||
}
|
||||
|
||||
func (d *Database) DeleteRecord(tablename string, idfield string, id any) (err error) {
|
||||
|
||||
return deleteRecord(d.DB(), tablename, idfield, id)
|
||||
|
||||
}
|
||||
|
||||
// *sql.DB and *sql.Tx both have a method named 'Query',
|
||||
// this way they can both be passed into upsert and deleteRecord function
|
||||
type iquery interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func upsert(t iquery, tablename string, idfield string, record Record) (result Record, err error) {
|
||||
|
||||
fields := []string{}
|
||||
data := []any{}
|
||||
for k, v := range record {
|
||||
fields = append(fields, k)
|
||||
data = append(data, v)
|
||||
}
|
||||
query, err := buildUpsertCommand(tablename, idfield, fields)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
res, err := t.Query(query, data...) // res contains the full record - see SQLite: RETURNING *
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer res.Close()
|
||||
return Rows2record(res)
|
||||
|
||||
}
|
||||
|
||||
func deleteRecord(t iquery, tablename string, idfield string, id any) (err error) {
|
||||
|
||||
query := fmt.Sprintf("DELETE FROM \"%s\" WHERE \"%s\" = ?;", tablename, idfield)
|
||||
_, err = t.Query(query, id)
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func buildUpsertCommand(tablename string, idfield string, fields []string) (string, error) {
|
||||
var sb strings.Builder
|
||||
sb.Grow(256 + len(fields)*20) // rough preallocation
|
||||
|
||||
// INSERT INTO
|
||||
sb.WriteString(`INSERT INTO "`)
|
||||
sb.WriteString(tablename)
|
||||
sb.WriteString(`"(`)
|
||||
for i, f := range fields {
|
||||
sb.WriteString(` "`)
|
||||
sb.WriteString(f)
|
||||
sb.WriteByte('"')
|
||||
if i < len(fields)-1 {
|
||||
sb.WriteByte(',')
|
||||
}
|
||||
}
|
||||
sb.WriteString(")\n\tVALUES(")
|
||||
|
||||
// VALUES
|
||||
for i := 0; i < len(fields); i++ {
|
||||
sb.WriteString(" ?")
|
||||
sb.Write(strconv.AppendInt(nil, int64(i+1), 10))
|
||||
if i < len(fields)-1 {
|
||||
sb.WriteByte(',')
|
||||
}
|
||||
}
|
||||
sb.WriteString(")\n\tON CONFLICT(\"")
|
||||
sb.WriteString(tablename)
|
||||
sb.WriteString(`"."`)
|
||||
sb.WriteString(idfield)
|
||||
sb.WriteString("\")\n\tDO UPDATE SET ")
|
||||
|
||||
// UPDATE SET
|
||||
for i, f := range fields {
|
||||
sb.WriteByte('"')
|
||||
sb.WriteString(f)
|
||||
sb.WriteString(`"= ?`)
|
||||
sb.Write(strconv.AppendInt(nil, int64(i+1), 10))
|
||||
if i < len(fields)-1 {
|
||||
sb.WriteByte(',')
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n\tRETURNING *;")
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// func buildUpsertCommand(tablename string, idfield string, fields []string) (result string, err error) {
|
||||
|
||||
// pname := map[string]string{} // assign correct index for parameter name
|
||||
// // parameter position, starts at 1 in sql! So it needs to be calculated by function pname inside template
|
||||
|
||||
// for i, k := range fields {
|
||||
// pname[k] = strconv.Itoa(i + 1)
|
||||
// }
|
||||
// funcMap := template.FuncMap{
|
||||
// "pname": func(fieldname string) string {
|
||||
// return pname[fieldname]
|
||||
// },
|
||||
// }
|
||||
// tableDef := struct {
|
||||
// Tablename string
|
||||
// KeyField string
|
||||
// LastField int
|
||||
// FieldNames []string
|
||||
// }{
|
||||
// Tablename: tablename,
|
||||
// KeyField: idfield,
|
||||
// LastField: len(fields) - 1,
|
||||
// FieldNames: fields,
|
||||
// }
|
||||
// var templString = `{{$last := .LastField}}INSERT INTO "{{ .Tablename }}"({{ range $i,$el := .FieldNames }} "{{$el}}"{{if ne $i $last}},{{end}}{{end}})
|
||||
// VALUES({{ range $i,$el := .FieldNames }} ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}})
|
||||
// ON CONFLICT("{{ .Tablename }}"."{{.KeyField}}")
|
||||
// DO UPDATE SET {{ range $i,$el := .FieldNames }}"{{$el}}"= ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}}
|
||||
// RETURNING *;`
|
||||
|
||||
// dbTempl, err := template.New("upsertDB").Funcs(funcMap).Parse(templString)
|
||||
// if err != nil {
|
||||
// return result, err
|
||||
// }
|
||||
// var templBytes bytes.Buffer
|
||||
// err = dbTempl.Execute(&templBytes, tableDef)
|
||||
// if err != nil {
|
||||
// return result, err
|
||||
// }
|
||||
// return templBytes.String(), nil
|
||||
// }
|
||||
|
||||
func Rows2record(rows *sql.Rows) (Record, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values := make([]any, len(columns))
|
||||
valuePtrs := make([]any, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
result := Record{}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i, col := range columns {
|
||||
result[col] = values[i]
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errors.New("no rows found")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func Rows2records(rows *sql.Rows) ([]Record, error) {
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
recLength := len(columns)
|
||||
results := []Record{}
|
||||
for rows.Next() {
|
||||
values := make([]any, recLength)
|
||||
valuePtrs := make([]any, recLength)
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
record := Record{}
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i, col := range columns {
|
||||
record[col] = values[i]
|
||||
}
|
||||
results = append(results, record)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, errors.New("no rows found")
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (d *Database) Version() (string, error) {
|
||||
result := ""
|
||||
sqliteversion, err := d.ReadRecords("SELECT sqlite_version();")
|
||||
if len(sqliteversion) == 1 {
|
||||
result = sqliteversion[0]["sqlite_version()"].(string)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (d *Database) UserVersion() (int64, error) {
|
||||
|
||||
var result int64
|
||||
userversion, err := d.ReadRecords("PRAGMA user_version;")
|
||||
if len(userversion) == 1 {
|
||||
result = userversion[0]["user_version"].(int64)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (d *Database) Begin() *Transaction {
|
||||
tx, err := d.database.Begin()
|
||||
return &Transaction{tx, err}
|
||||
}
|
||||
|
||||
func (t *Transaction) Next(action Action) *Transaction {
|
||||
if t.err != nil {
|
||||
return t
|
||||
}
|
||||
t.err = action(t.tx)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Transaction) End() error {
|
||||
|
||||
if t.err != nil {
|
||||
err := t.tx.Rollback()
|
||||
if err != nil {
|
||||
t.err = errors.Join(t.err, err)
|
||||
}
|
||||
return t.err
|
||||
}
|
||||
t.err = t.tx.Commit()
|
||||
return t.err
|
||||
}
|
||||
|
||||
func (t *Transaction) GetRecord(tablename string, idfield string, key any, output Record) *Transaction {
|
||||
|
||||
if t.err != nil {
|
||||
return t
|
||||
}
|
||||
query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield)
|
||||
res, err := t.tx.Query(query, key)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return t
|
||||
}
|
||||
defer res.Close()
|
||||
result, err := Rows2record(res)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return t
|
||||
}
|
||||
for k := range output {
|
||||
delete(output, k)
|
||||
}
|
||||
for k, v := range result {
|
||||
output[k] = v
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Transaction) UpsertRecord(tablename string, idfield string, record Record, output Record) *Transaction {
|
||||
|
||||
if t.err != nil {
|
||||
return t
|
||||
}
|
||||
result, err := upsert(t.tx, tablename, idfield, record)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
return t
|
||||
}
|
||||
for k := range output {
|
||||
delete(output, k)
|
||||
}
|
||||
for k, v := range result {
|
||||
output[k] = v
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Transaction) DeleteRecord(tablename string, idfield string, id any) *Transaction {
|
||||
|
||||
if t.err != nil {
|
||||
return t
|
||||
}
|
||||
err := deleteRecord(t.tx, tablename, idfield, id)
|
||||
if err != nil {
|
||||
t.err = err
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// returns a value of the provided type, if the field exist and if it can be cast into the provided type parameter
|
||||
func Value[T any](rec Record, field string) (value T, ok bool) {
|
||||
var v any
|
||||
if v, ok = rec[field]; ok {
|
||||
value, ok = v.(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// don't report an error if there are simply just 'no rows found'
|
||||
func NoRowsOk(recs []Record, err error) ([]Record, error) {
|
||||
if err != nil && err.Error() != "no rows found" {
|
||||
return recs, err
|
||||
}
|
||||
return recs, nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user