441 lines
9.8 KiB
Go
441 lines
9.8 KiB
Go
package sqlite // name the package as you see fit
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// This is the data type to exchange data with the database
|
|
type Record = map[string]any
|
|
|
|
type Database struct {
|
|
databaseName string
|
|
database *sql.DB
|
|
}
|
|
|
|
type Transaction struct {
|
|
tx *sql.Tx
|
|
err error
|
|
}
|
|
|
|
type Action func(tx *sql.Tx) error
|
|
|
|
func New(DBName string) (*Database, error) {
|
|
return &Database{databaseName: DBName}, nil
|
|
}
|
|
|
|
func (d *Database) Close() error {
|
|
query := `
|
|
PRAGMA analysis_limit = 400;
|
|
PRAGMA optimize;
|
|
`
|
|
_, err := d.DB().Exec(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return d.database.Close()
|
|
}
|
|
|
|
// DB exposes the underlying *sql.DB for queries the wrapper does not cover.
// It is nil until Open or OpenInMemory has been called.
func (d *Database) DB() *sql.DB {
	pool := d.database
	return pool
}
|
|
|
|
func (d *Database) Name() string {
|
|
return d.databaseName
|
|
}
|
|
|
|
func (d *Database) Open() (err error) {
|
|
d.database, err = openSqliteDB(d.databaseName)
|
|
return err
|
|
}
|
|
|
|
func (d *Database) OpenInMemory() (err error) {
|
|
d.database, err = sql.Open("sqlite", ":memory:")
|
|
return err
|
|
}
|
|
|
|
// openSqliteDB opens databasefilename. When the file does not exist yet it is
// created and initialised through createDB; any other Stat failure is
// returned unchanged. An existing file is opened without running any PRAGMAs.
func openSqliteDB(databasefilename string) (*sql.DB, error) {
	_, statErr := os.Stat(databasefilename)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return createDB(databasefilename)
	case statErr != nil:
		return nil, statErr
	default:
		return sql.Open("sqlite", databasefilename)
	}
}
|
|
|
|
func createDB(dbfileName string) (*sql.DB, error) {
|
|
|
|
query := `
|
|
PRAGMA synchronous = off;
|
|
PRAGMA foreign_keys = on;
|
|
PRAGMA journal_mode = WAL;
|
|
PRAGMA busy_timeout = 5000;
|
|
PRAGMA cache_size = 2000;
|
|
PRAGMA temp_store = memory;
|
|
PRAGMA mmap_size = 30000000000;
|
|
PRAGMA page_size = 4096;
|
|
PRAGMA user_version = 1;
|
|
`
|
|
db, err := sql.Open("sqlite", dbfileName)
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
}
|
|
_, err = db.Exec(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return db, nil
|
|
}
|
|
|
|
func (d *Database) TableList() (result []Record, err error) {
|
|
return d.ReadRecords("select name from sqlite_master where type='table';")
|
|
}
|
|
|
|
// ReadTable returns every row of tablename via ReadRecords.
//
// The name is emitted as a double-quoted SQL identifier with any embedded
// double quotes doubled, so it cannot break out of the identifier. The old
// single-quoted form was a string literal that SQLite only accepts as a table
// name by legacy fallback, and a name containing ' could inject SQL.
func (d *Database) ReadTable(tablename string) (result []Record, err error) {
	quoted := `"` + strings.ReplaceAll(tablename, `"`, `""`) + `"`
	return d.ReadRecords(fmt.Sprintf("select * from %s;", quoted))
}
|
|
|
|
func (d *Database) ReadRecords(query string, args ...any) (result []Record, err error) {
|
|
|
|
rows, err := d.DB().Query(query, args...)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
defer rows.Close()
|
|
return Rows2records(rows)
|
|
|
|
}
|
|
|
|
// GetRecord returns the row of tablename whose idfield column equals key.
//
// Both names are written as double-quoted identifiers with embedded quotes
// doubled, the same way deleteRecord and buildUpsertCommand quote names. This
// keeps reserved words usable and stops a crafted name from injecting SQL.
// When no row matches, the "no rows found" error from Rows2record is returned.
func (d *Database) GetRecord(tablename string, idfield string, key any) (result Record, err error) {
	quote := func(name string) string {
		return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	}
	query := fmt.Sprintf("select * from %s where %s = ?;", quote(tablename), quote(idfield))
	rows, err := d.DB().Query(query, key)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return Rows2record(rows)
}
|
|
|
|
func (d *Database) UpsertRecord(tablename string, idfield string, record Record) (result Record, err error) {
|
|
|
|
return upsert(d.DB(), tablename, idfield, record)
|
|
|
|
}
|
|
|
|
func (d *Database) DeleteRecord(tablename string, idfield string, id any) (err error) {
|
|
|
|
return deleteRecord(d.DB(), tablename, idfield, id)
|
|
|
|
}
|
|
|
|
// *sql.DB and *sql.Tx both have a method named 'Query',
|
|
// this way they can both be passed into upsert and deleteRecord function
|
|
type iquery interface {
|
|
Query(query string, args ...any) (*sql.Rows, error)
|
|
}
|
|
|
|
func upsert(t iquery, tablename string, idfield string, record Record) (result Record, err error) {
|
|
|
|
fields := []string{}
|
|
data := []any{}
|
|
for k, v := range record {
|
|
fields = append(fields, k)
|
|
data = append(data, v)
|
|
}
|
|
query, err := buildUpsertCommand(tablename, idfield, fields)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
res, err := t.Query(query, data...) // res contains the full record - see SQLite: RETURNING *
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
defer res.Close()
|
|
return Rows2record(res)
|
|
|
|
}
|
|
|
|
// deleteRecord deletes the rows of tablename whose idfield column equals id.
//
// Names are written as double-quoted identifiers with embedded quotes
// doubled. The statement is run with Exec whenever t provides it (both
// *sql.DB and *sql.Tx do). The *sql.Rows returned by Query must be closed,
// or its connection stays checked out of the pool, and a statement that is
// never stepped may not run at all.
func deleteRecord(t iquery, tablename string, idfield string, id any) (err error) {
	quote := func(name string) string {
		return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	}
	query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?;", quote(tablename), quote(idfield))

	if execer, ok := t.(interface {
		Exec(query string, args ...any) (sql.Result, error)
	}); ok {
		_, err = execer.Exec(query, id)
		return err
	}

	// Fallback for iquery implementations without Exec: step the statement to
	// completion and always release the rows.
	rows, err := t.Query(query, id)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
	}
	return rows.Err()
}
|
|
|
|
// buildUpsertCommand returns an SQLite UPSERT statement for tablename that
// inserts the given fields with positional parameters ?1..?N. On a conflict
// with idfield, the statement updates every field from the same parameters,
// and it finally returns the stored row (RETURNING *).
//
// All identifiers are double-quoted with embedded double quotes doubled, so
// unusual names cannot terminate the identifier early. An empty fields slice
// is rejected, because "INSERT INTO t() VALUES()" is not valid SQL.
func buildUpsertCommand(tablename string, idfield string, fields []string) (string, error) {
	if len(fields) == 0 {
		return "", errors.New("upsert needs at least one field")
	}

	var sb strings.Builder
	sb.Grow(256 + len(fields)*20) // rough preallocation

	writeIdent := func(name string) {
		sb.WriteByte('"')
		sb.WriteString(strings.ReplaceAll(name, `"`, `""`))
		sb.WriteByte('"')
	}

	// INSERT INTO "table"( "f1", "f2", ...)
	sb.WriteString("INSERT INTO ")
	writeIdent(tablename)
	sb.WriteByte('(')
	for i, f := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteByte(' ')
		writeIdent(f)
	}

	// VALUES( ?1, ?2, ...) - SQL parameter positions are 1-based.
	sb.WriteString(")\n\tVALUES(")
	for i := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(" ?")
		sb.WriteString(strconv.Itoa(i + 1))
	}

	// ON CONFLICT("table"."idfield")
	sb.WriteString(")\n\tON CONFLICT(")
	writeIdent(tablename)
	sb.WriteByte('.')
	writeIdent(idfield)

	// DO UPDATE SET "f1"= ?1,"f2"= ?2, ... - reuses the INSERT parameters.
	sb.WriteString(")\n\tDO UPDATE SET ")
	for i, f := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		writeIdent(f)
		sb.WriteString("= ?")
		sb.WriteString(strconv.Itoa(i + 1))
	}

	sb.WriteString("\n\tRETURNING *;")
	return sb.String(), nil
}
|
|
|
|
// func buildUpsertCommand(tablename string, idfield string, fields []string) (result string, err error) {
|
|
|
|
// pname := map[string]string{} // assign correct index for parameter name
|
|
// // parameter position, starts at 1 in sql! So it needs to be calculated by function pname inside template
|
|
|
|
// for i, k := range fields {
|
|
// pname[k] = strconv.Itoa(i + 1)
|
|
// }
|
|
// funcMap := template.FuncMap{
|
|
// "pname": func(fieldname string) string {
|
|
// return pname[fieldname]
|
|
// },
|
|
// }
|
|
// tableDef := struct {
|
|
// Tablename string
|
|
// KeyField string
|
|
// LastField int
|
|
// FieldNames []string
|
|
// }{
|
|
// Tablename: tablename,
|
|
// KeyField: idfield,
|
|
// LastField: len(fields) - 1,
|
|
// FieldNames: fields,
|
|
// }
|
|
// var templString = `{{$last := .LastField}}INSERT INTO "{{ .Tablename }}"({{ range $i,$el := .FieldNames }} "{{$el}}"{{if ne $i $last}},{{end}}{{end}})
|
|
// VALUES({{ range $i,$el := .FieldNames }} ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}})
|
|
// ON CONFLICT("{{ .Tablename }}"."{{.KeyField}}")
|
|
// DO UPDATE SET {{ range $i,$el := .FieldNames }}"{{$el}}"= ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}}
|
|
// RETURNING *;`
|
|
|
|
// dbTempl, err := template.New("upsertDB").Funcs(funcMap).Parse(templString)
|
|
// if err != nil {
|
|
// return result, err
|
|
// }
|
|
// var templBytes bytes.Buffer
|
|
// err = dbTempl.Execute(&templBytes, tableDef)
|
|
// if err != nil {
|
|
// return result, err
|
|
// }
|
|
// return templBytes.String(), nil
|
|
// }
|
|
|
|
// Rows2record scans rows into a single Record keyed by column name.
//
// If the result set has several rows, each later row overwrites the earlier
// values, so the last row wins. An error raised while iterating (rows.Err) is
// returned instead of being mistaken for an empty result. When no row was
// read, the error "no rows found" is returned; NoRowsOk matches on that text.
func Rows2record(rows *sql.Rows) (Record, error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	result := make(Record, len(columns))
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		for i, col := range columns {
			result[col] = values[i]
		}
	}
	// rows.Next returning false can mean "done" or "failed"; distinguish them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, errors.New("no rows found")
	}
	return result, nil
}
|
|
|
|
func Rows2records(rows *sql.Rows) ([]Record, error) {
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
recLength := len(columns)
|
|
results := []Record{}
|
|
for rows.Next() {
|
|
values := make([]any, recLength)
|
|
valuePtrs := make([]any, recLength)
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
record := Record{}
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
return nil, err
|
|
}
|
|
for i, col := range columns {
|
|
record[col] = values[i]
|
|
}
|
|
results = append(results, record)
|
|
}
|
|
if len(results) == 0 {
|
|
return nil, errors.New("no rows found")
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (d *Database) Version() (string, error) {
|
|
result := ""
|
|
sqliteversion, err := d.ReadRecords("SELECT sqlite_version();")
|
|
if len(sqliteversion) == 1 {
|
|
result = sqliteversion[0]["sqlite_version()"].(string)
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func (d *Database) UserVersion() (int64, error) {
|
|
|
|
var result int64
|
|
userversion, err := d.ReadRecords("PRAGMA user_version;")
|
|
if len(userversion) == 1 {
|
|
result = userversion[0]["user_version"].(int64)
|
|
}
|
|
return result, err
|
|
}
|
|
|
|
func (d *Database) Begin() *Transaction {
|
|
tx, err := d.database.Begin()
|
|
return &Transaction{tx, err}
|
|
}
|
|
|
|
func (t *Transaction) Next(action Action) *Transaction {
|
|
if t.err != nil {
|
|
return t
|
|
}
|
|
t.err = action(t.tx)
|
|
return t
|
|
}
|
|
|
|
// End finishes the transaction and returns (and stores) its final error.
//
// If any earlier step failed, the transaction is rolled back and the step's
// error is returned, joined with the rollback error if there is one.
// Otherwise the transaction is committed. When there is no *sql.Tx at all
// (Begin failed, or the Transaction is a zero value), nothing is rolled back
// or committed: the stored error is returned, or a "transaction not started"
// error if there is none.
func (t *Transaction) End() error {
	if t.tx == nil {
		// Calling Rollback/Commit on a nil *sql.Tx would panic.
		if t.err == nil {
			t.err = errors.New("transaction not started")
		}
		return t.err
	}
	if t.err != nil {
		if rbErr := t.tx.Rollback(); rbErr != nil {
			t.err = errors.Join(t.err, rbErr)
		}
		return t.err
	}
	t.err = t.tx.Commit()
	return t.err
}
|
|
|
|
// GetRecord reads, inside the transaction, the row of tablename whose idfield
// column equals key. It does nothing if an earlier step failed.
//
// On success, output is cleared and refilled with the row's columns. A nil
// output is allowed and simply discards the row (writing into a nil map would
// panic). Names are double-quoted identifiers with embedded quotes doubled, as
// in deleteRecord. Query and scan errors, including "no rows found", are
// stored as the transaction's error.
func (t *Transaction) GetRecord(tablename string, idfield string, key any, output Record) *Transaction {
	if t.err != nil {
		return t
	}
	quote := func(name string) string {
		return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	}
	query := fmt.Sprintf("select * from %s where %s = ?;", quote(tablename), quote(idfield))
	rows, err := t.tx.Query(query, key)
	if err != nil {
		t.err = err
		return t
	}
	defer rows.Close()
	result, err := Rows2record(rows)
	if err != nil {
		t.err = err
		return t
	}
	if output == nil {
		return t
	}
	for k := range output {
		delete(output, k)
	}
	for k, v := range result {
		output[k] = v
	}
	return t
}
|
|
|
|
func (t *Transaction) UpsertRecord(tablename string, idfield string, record Record, output Record) *Transaction {
|
|
|
|
if t.err != nil {
|
|
return t
|
|
}
|
|
result, err := upsert(t.tx, tablename, idfield, record)
|
|
if err != nil {
|
|
t.err = err
|
|
return t
|
|
}
|
|
for k := range output {
|
|
delete(output, k)
|
|
}
|
|
for k, v := range result {
|
|
output[k] = v
|
|
}
|
|
return t
|
|
}
|
|
|
|
func (t *Transaction) DeleteRecord(tablename string, idfield string, id any) *Transaction {
|
|
|
|
if t.err != nil {
|
|
return t
|
|
}
|
|
err := deleteRecord(t.tx, tablename, idfield, id)
|
|
if err != nil {
|
|
t.err = err
|
|
}
|
|
return t
|
|
}
|
|
|
|
// returns a value of the provided type, if the field exist and if it can be cast into the provided type parameter
|
|
func Value[T any](rec Record, field string) (value T, ok bool) {
|
|
var v any
|
|
if v, ok = rec[field]; ok {
|
|
value, ok = v.(T)
|
|
}
|
|
return
|
|
}
|
|
|
|
// don't report an error if there are simply just 'no rows found'
|
|
func NoRowsOk(recs []Record, err error) ([]Record, error) {
|
|
if err != nil && err.Error() != "no rows found" {
|
|
return recs, err
|
|
}
|
|
return recs, nil
|
|
}
|