2025-11-11 06:30:02 +01:00

429 lines
9.5 KiB
Go

package sqlite // name the package as you see fit
import (
"database/sql"
"errors"
"fmt"
"os"
"strconv"
"strings"
_ "modernc.org/sqlite"
)
type (
	// Record is the type used to exchange data with the database: one row,
	// mapping column names to the values read from (or written to) them.
	Record = map[string]any

	// Database pairs an SQLite file name with the *sql.DB handle opened for it.
	Database struct {
		databaseName string  // name given to New; used by Open
		database     *sql.DB // handle assigned by Open / OpenInMemory
	}

	// Transaction carries a *sql.Tx through a chain of calls together with
	// the first error met along the way; once err is set, later chained
	// steps do nothing.
	Transaction struct {
		tx  *sql.Tx
		err error
	}

	// Action is a custom step run inside a Transaction (see Transaction.Next).
	Action func(tx *sql.Tx) error
)
// New returns a Database for the SQLite file DBName. Nothing is opened yet;
// call Open (or OpenInMemory) before using it. The error is always nil.
func New(DBName string) (*Database, error) {
	db := new(Database)
	db.databaseName = DBName
	return db, nil
}
// Close closes the underlying *sql.DB.
//
// Closing a Database that was never opened (or whose Open failed and left
// no handle) is a no-op returning nil, instead of a nil-pointer panic.
func (d *Database) Close() error {
	if d.database == nil {
		return nil
	}
	return d.database.Close()
}
// DB exposes the underlying *sql.DB so callers can run statements this
// package has no helper for. It is nil until Open or OpenInMemory assigns it.
func (d *Database) DB() *sql.DB {
	handle := d.database
	return handle
}
// Name reports the database file name the Database was created with.
func (d *Database) Name() string {
	name := d.databaseName
	return name
}
// Open opens the database file named at construction time, creating and
// initialising it first when it does not exist yet (see openSqliteDB).
// The handle is stored even when an error is returned, mirroring whatever
// openSqliteDB produced (nil on failure).
func (d *Database) Open() (err error) {
	handle, openErr := openSqliteDB(d.databaseName)
	d.database = handle
	return openErr
}
// OpenInMemory opens a private in-memory SQLite database instead of a file.
//
// Every ":memory:" connection SQLite opens is a distinct, empty database, but
// *sql.DB is a pool that may open several connections. The pool is therefore
// capped at one connection so all queries, statements and transactions see
// the same data. Consequences: work issued through DB() while a Transaction
// (or an unclosed *sql.Rows) holds that connection waits until it is
// released, and if the pool ever closes the connection (e.g. a
// ConnMaxLifetime is configured later) the in-memory data is gone.
func (d *Database) OpenInMemory() (err error) {
	d.database, err = sql.Open("sqlite", ":memory:")
	if err != nil {
		return err
	}
	d.database.SetMaxOpenConns(1)
	return nil
}
// openSqliteDB opens databasefilename with the "sqlite" driver. A file that
// does not exist yet is handed to createDB so it gets initialised; any other
// os.Stat failure is returned unchanged.
func openSqliteDB(databasefilename string) (*sql.DB, error) {
	switch _, statErr := os.Stat(databasefilename); {
	case errors.Is(statErr, os.ErrNotExist):
		return createDB(databasefilename)
	case statErr != nil:
		return nil, statErr
	default:
		return sql.Open("sqlite", databasefilename)
	}
}
// createDB opens dbfileName with the "sqlite" driver and runs the initial
// PRAGMAs: 4096-byte pages, synchronous off, foreign-key enforcement off,
// WAL journaling and user_version 1.
//
// If the PRAGMAs fail, the freshly opened *sql.DB is closed before returning
// so its pooled connection is not leaked.
//
// NOTE(review): synchronous and foreign_keys are per-connection settings in
// SQLite, so they only affect the pooled connection that ran this Exec;
// connections the pool opens later use SQLite's defaults.
func createDB(dbfileName string) (*sql.DB, error) {
	query := `
PRAGMA page_size = 4096;
PRAGMA synchronous = off;
PRAGMA foreign_keys = off;
PRAGMA journal_mode = WAL;
PRAGMA user_version = 1;
`
	db, err := sql.Open("sqlite", dbfileName)
	if err != nil {
		return nil, err
	}
	if _, err := db.Exec(query); err != nil {
		if closeErr := db.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
		return nil, err
	}
	return db, nil
}
// TableList lists the tables in the database by querying sqlite_master; each
// returned Record holds one table name under the key "name".
func (d *Database) TableList() (result []Record, err error) {
	const listTables = "select name from sqlite_master where type='table';"
	result, err = d.ReadRecords(listTables)
	return result, err
}
// ReadTable returns every row of tablename.
//
// The name is emitted as a double-quoted SQL identifier with embedded quotes
// doubled, so a hostile name cannot close the identifier and append further
// SQL. (The previous single-quoted form relied on SQLite's string-literal
// fallback and did no escaping at all.)
func (d *Database) ReadTable(tablename string) (result []Record, err error) {
	quoted := `"` + strings.ReplaceAll(tablename, `"`, `""`) + `"`
	return d.ReadRecords("select * from " + quoted + ";")
}
// ReadRecords runs query with args on the underlying database and converts
// every resulting row into a Record (see Rows2records, which also reports
// "no rows found" for an empty result). The rows are closed before returning.
func (d *Database) ReadRecords(query string, args ...any) (result []Record, err error) {
	rows, queryErr := d.DB().Query(query, args...)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	return Rows2records(rows)
}
// GetRecord fetches the row of tablename whose idfield column equals key.
//
// Table and column names are double-quoted with embedded quotes doubled,
// consistent with upsert and deleteRecord, so they cannot inject SQL; key is
// passed as a bound parameter. The error text is "no rows found" when no
// row matches (see Rows2record).
func (d *Database) GetRecord(tablename string, idfield string, key any) (result Record, err error) {
	query := fmt.Sprintf(`select * from "%s" where "%s" = ?;`,
		strings.ReplaceAll(tablename, `"`, `""`),
		strings.ReplaceAll(idfield, `"`, `""`))
	rows, err := d.DB().Query(query, key)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return Rows2record(rows)
}
// UpsertRecord inserts record into tablename, or updates the existing row
// when idfield conflicts, and returns the stored row (see upsert).
func (d *Database) UpsertRecord(tablename string, idfield string, record Record) (result Record, err error) {
	db := d.DB()
	result, err = upsert(db, tablename, idfield, record)
	return result, err
}
// DeleteRecord removes the rows of tablename whose idfield column equals id
// (see deleteRecord).
func (d *Database) DeleteRecord(tablename string, idfield string, id any) (err error) {
	err = deleteRecord(d.DB(), tablename, idfield, id)
	return err
}
// iquery is the one method upsert and deleteRecord need from their target.
// Both *sql.DB and *sql.Tx provide it, so the same helpers serve plain
// database calls and calls inside a transaction.
type iquery interface {
	Query(q string, params ...any) (*sql.Rows, error)
}
// upsert writes record into tablename through t, inserting a new row or
// updating the one that conflicts on idfield, and returns the row as stored
// (the statement ends in RETURNING *, see buildUpsertCommand).
//
// Column names and values are collected into parallel slices so the i-th
// value binds to the i-th "?N" placeholder; the map's iteration order does
// not matter because both slices follow the same order.
func upsert(t iquery, tablename string, idfield string, record Record) (result Record, err error) {
	columns := make([]string, 0, len(record))
	params := make([]any, 0, len(record))
	for col, val := range record {
		columns = append(columns, col)
		params = append(params, val)
	}

	query, buildErr := buildUpsertCommand(tablename, idfield, columns)
	if buildErr != nil {
		return nil, buildErr
	}

	rows, queryErr := t.Query(query, params...)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	return Rows2record(rows)
}
// deleteRecord deletes the rows of tablename whose idfield column equals id,
// using t (a *sql.DB or *sql.Tx).
//
// iquery only offers Query, so the DELETE comes back as a *sql.Rows. That
// Rows must be closed: an open Rows keeps its connection checked out of the
// pool (or busy inside the transaction). The original code discarded it
// without closing it. Identifiers are double-quoted with embedded quotes
// doubled; id is a bound parameter.
func deleteRecord(t iquery, tablename string, idfield string, id any) (err error) {
	query := fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" = ?;`,
		strings.ReplaceAll(tablename, `"`, `""`),
		strings.ReplaceAll(idfield, `"`, `""`))
	rows, err := t.Query(query, id)
	if err != nil {
		return err
	}
	return rows.Close()
}
// buildUpsertCommand renders an SQLite UPSERT for tablename:
//
//	INSERT INTO "t"( "f1", "f2")
//		VALUES( ?1, ?2)
//		ON CONFLICT("t"."id")
//		DO UPDATE SET "f1"= ?1,"f2"= ?2
//		RETURNING *;
//
// The placeholder ?N binds to the N-th entry of fields, in both the VALUES
// and the UPDATE SET lists, so the caller passes one argument per field in
// the same order. Every identifier is double-quoted with embedded `"` doubled,
// so names containing quotes cannot break out of the identifier.
// An empty fields slice is rejected with an error, because it cannot produce
// a valid statement.
func buildUpsertCommand(tablename string, idfield string, fields []string) (string, error) {
	if len(fields) == 0 {
		return "", errors.New("upsert: no fields to write")
	}
	quote := func(ident string) string {
		return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
	}

	var sb strings.Builder
	sb.Grow(256 + len(fields)*20) // rough preallocation

	// INSERT INTO "table"( "f1", "f2")
	sb.WriteString("INSERT INTO ")
	sb.WriteString(quote(tablename))
	sb.WriteByte('(')
	for i, f := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteByte(' ')
		sb.WriteString(quote(f))
	}

	// VALUES( ?1, ?2)
	sb.WriteString(")\n\tVALUES(")
	for i := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(" ?")
		sb.WriteString(strconv.Itoa(i + 1))
	}

	// ON CONFLICT("table"."id")
	sb.WriteString(")\n\tON CONFLICT(")
	sb.WriteString(quote(tablename))
	sb.WriteByte('.')
	sb.WriteString(quote(idfield))

	// DO UPDATE SET "f1"= ?1,"f2"= ?2
	sb.WriteString(")\n\tDO UPDATE SET ")
	for i, f := range fields {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(quote(f))
		sb.WriteString("= ?")
		sb.WriteString(strconv.Itoa(i + 1))
	}

	sb.WriteString("\n\tRETURNING *;")
	return sb.String(), nil
}
// func buildUpsertCommand(tablename string, idfield string, fields []string) (result string, err error) {
// pname := map[string]string{} // assign correct index for parameter name
// // parameter position, starts at 1 in sql! So it needs to be calculated by function pname inside template
// for i, k := range fields {
// pname[k] = strconv.Itoa(i + 1)
// }
// funcMap := template.FuncMap{
// "pname": func(fieldname string) string {
// return pname[fieldname]
// },
// }
// tableDef := struct {
// Tablename string
// KeyField string
// LastField int
// FieldNames []string
// }{
// Tablename: tablename,
// KeyField: idfield,
// LastField: len(fields) - 1,
// FieldNames: fields,
// }
// var templString = `{{$last := .LastField}}INSERT INTO "{{ .Tablename }}"({{ range $i,$el := .FieldNames }} "{{$el}}"{{if ne $i $last}},{{end}}{{end}})
// VALUES({{ range $i,$el := .FieldNames }} ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}})
// ON CONFLICT("{{ .Tablename }}"."{{.KeyField}}")
// DO UPDATE SET {{ range $i,$el := .FieldNames }}"{{$el}}"= ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}}
// RETURNING *;`
// dbTempl, err := template.New("upsertDB").Funcs(funcMap).Parse(templString)
// if err != nil {
// return result, err
// }
// var templBytes bytes.Buffer
// err = dbTempl.Execute(&templBytes, tableDef)
// if err != nil {
// return result, err
// }
// return templBytes.String(), nil
// }
// Rows2record scans rows into a single Record keyed by column name.
//
// All rows are consumed. If the result holds several rows, each one
// overwrites the previous, so the values of the last row are returned.
// Errors:
//   - the error from rows.Columns or rows.Scan;
//   - the error reported by rows.Err when iteration stopped because of a
//     failure rather than the end of the result (previously this was
//     ignored, so a failed read could surface as partial data or as
//     "no rows found");
//   - an error with the text "no rows found" when nothing was read
//     (NoRowsOk matches on that exact text).
//
// rows is not closed here; callers in this file close it themselves.
func Rows2record(rows *sql.Rows) (Record, error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	result := Record{}
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		for i, col := range columns {
			result[col] = values[i]
		}
	}
	// Next returns false both at the end of the result and on failure; only
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(result) == 0 {
		return nil, errors.New("no rows found")
	}
	return result, nil
}
// Rows2records scans every row of rows into its own Record keyed by column
// name, preserving row order.
//
// Errors:
//   - the error from rows.Columns or rows.Scan;
//   - the error reported by rows.Err when iteration stopped because of a
//     failure (previously ignored, which returned truncated results as
//     success);
//   - an error with the text "no rows found" when the result is empty
//     (NoRowsOk matches on that exact text).
//
// rows is not closed here; callers in this file close it themselves.
func Rows2records(rows *sql.Rows) ([]Record, error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	// The scan buffers can be shared across rows: each value is copied into
	// the row's Record before the next Scan overwrites the buffer slot, and
	// Scan into *any copies []byte data rather than aliasing driver memory.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	var results []Record
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		record := make(Record, len(columns))
		for i, col := range columns {
			record[col] = values[i]
		}
		results = append(results, record)
	}
	// Next returns false both at the end of the result and on failure; only
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(results) == 0 {
		return nil, errors.New("no rows found")
	}
	return results, nil
}
// Version returns the SQLite library version reported by
// "SELECT sqlite_version();".
//
// When the query fails, the query error is returned with "". If the driver
// hands the value back as something other than a string or []byte, an error
// is returned. (The previous unchecked type assertion would panic in that
// case.)
func (d *Database) Version() (string, error) {
	recs, err := d.ReadRecords("SELECT sqlite_version();")
	if err != nil || len(recs) != 1 {
		return "", err
	}
	switch v := recs[0]["sqlite_version()"].(type) {
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	default:
		return "", fmt.Errorf("sqlite_version(): unexpected value type %T", v)
	}
}
// UserVersion returns the database's user_version PRAGMA (createDB sets it
// to 1 for new files).
//
// When the query fails, the query error is returned with 0. A value that is
// not an int64 yields an error instead of the panic the previous unchecked
// type assertion would cause.
func (d *Database) UserVersion() (int64, error) {
	recs, err := d.ReadRecords("PRAGMA user_version;")
	if err != nil || len(recs) != 1 {
		return 0, err
	}
	raw := recs[0]["user_version"]
	v, ok := raw.(int64)
	if !ok {
		return 0, fmt.Errorf("user_version: unexpected value type %T", raw)
	}
	return v, nil
}
// Begin starts a transaction on the underlying database and wraps it in a
// Transaction. A failure to begin is not returned here. It is stored as the
// Transaction's error, so the chained methods (Next, GetRecord, ...) skip
// their work.
func (d *Database) Begin() *Transaction {
	t := &Transaction{}
	t.tx, t.err = d.database.Begin()
	return t
}
// Next runs action against the transaction's *sql.Tx unless an earlier step
// already failed. The action's result (nil or not) becomes the transaction's
// error. It returns t so calls can be chained.
func (t *Transaction) Next(action Action) *Transaction {
	if t.err == nil {
		t.err = action(t.tx)
	}
	return t
}
// End finishes the transaction. If any step failed, it rolls back and returns
// that error, joined with the rollback error when the rollback also fails.
// Otherwise it commits and returns the commit result.
//
// If Begin itself failed there is no *sql.Tx. The stored error is returned
// without attempting a rollback, which previously dereferenced a nil *sql.Tx
// and panicked. A zero-value Transaction that was never started returns an
// error instead of panicking in Commit.
func (t *Transaction) End() error {
	if t.err != nil {
		if t.tx != nil {
			if rbErr := t.tx.Rollback(); rbErr != nil {
				t.err = errors.Join(t.err, rbErr)
			}
		}
		return t.err
	}
	if t.tx == nil {
		t.err = errors.New("transaction was not started")
		return t.err
	}
	t.err = t.tx.Commit()
	return t.err
}
// GetRecord looks up, inside the transaction, the row of tablename whose
// idfield column equals key. On success the contents of output are replaced
// with that row. On failure (including "no rows found") the error is stored
// in the transaction and output is left untouched. It is a no-op if an
// earlier step failed.
//
// Table and column names are double-quoted with embedded quotes doubled,
// consistent with upsert and deleteRecord; key is a bound parameter.
// A nil output is allowed: the lookup still runs (and can still fail the
// transaction), but the row is discarded instead of panicking on a write to
// a nil map.
func (t *Transaction) GetRecord(tablename string, idfield string, key any, output Record) *Transaction {
	if t.err != nil {
		return t
	}
	query := fmt.Sprintf(`select * from "%s" where "%s" = ?;`,
		strings.ReplaceAll(tablename, `"`, `""`),
		strings.ReplaceAll(idfield, `"`, `""`))
	res, err := t.tx.Query(query, key)
	if err != nil {
		t.err = err
		return t
	}
	defer res.Close()
	result, err := Rows2record(res)
	if err != nil {
		t.err = err
		return t
	}
	if output != nil {
		// Replace the caller's map contents in place so the caller's map value sees the row.
		for k := range output {
			delete(output, k)
		}
		for k, v := range result {
			output[k] = v
		}
	}
	return t
}
// UpsertRecord inserts or updates record in tablename (conflicting on
// idfield) inside the transaction. On success the contents of output are
// replaced with the row as stored. On failure the error is stored in the
// transaction. It is a no-op if an earlier step failed.
//
// A nil output is allowed: the upsert still runs, but the returned row is
// discarded instead of panicking on a write to a nil map.
func (t *Transaction) UpsertRecord(tablename string, idfield string, record Record, output Record) *Transaction {
	if t.err != nil {
		return t
	}
	result, err := upsert(t.tx, tablename, idfield, record)
	if err != nil {
		t.err = err
		return t
	}
	if output != nil {
		// Replace the caller's map contents in place so the caller's map value sees the row.
		for k := range output {
			delete(output, k)
		}
		for k, v := range result {
			output[k] = v
		}
	}
	return t
}
// DeleteRecord deletes, inside the transaction, the rows of tablename whose
// idfield column equals id. A failure is stored as the transaction's error.
// It is a no-op if an earlier step failed.
func (t *Transaction) DeleteRecord(tablename string, idfield string, id any) *Transaction {
	if t.err == nil {
		if delErr := deleteRecord(t.tx, tablename, idfield, id); delErr != nil {
			t.err = delErr
		}
	}
	return t
}
// Value fetches rec[field] as a T. ok is true only when the field is present
// and its value has dynamic type T (a nil value never matches). Otherwise
// the zero T is returned with ok false.
func Value[T any](rec Record, field string) (value T, ok bool) {
	raw, present := rec[field]
	if !present {
		return value, false
	}
	value, ok = raw.(T)
	return value, ok
}
// NoRowsOk wraps a ([]Record, error) result and suppresses the "no rows
// found" error produced by Rows2records, so an empty result is not treated
// as a failure. Any other error is passed through. recs is returned
// unchanged in every case.
func NoRowsOk(recs []Record, err error) ([]Record, error) {
	if err == nil || err.Error() == "no rows found" {
		return recs, nil
	}
	return recs, err
}