From 991950bdd4cf0dc563c44ccf6d8d7ecc602f68f3 Mon Sep 17 00:00:00 2001 From: thomashamburg Date: Mon, 22 Sep 2025 10:43:52 +0200 Subject: [PATCH] getting started --- mvwa/sqlite/database.go | 428 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 428 insertions(+) create mode 100644 mvwa/sqlite/database.go diff --git a/mvwa/sqlite/database.go b/mvwa/sqlite/database.go new file mode 100644 index 0000000..f3c4b83 --- /dev/null +++ b/mvwa/sqlite/database.go @@ -0,0 +1,428 @@ +package sqlite // name the package as you see fit + +import ( + "database/sql" + "errors" + "fmt" + "os" + "strconv" + "strings" + + _ "modernc.org/sqlite" +) + +// This is the data type to exchange data with the database +type Record = map[string]any + +type Database struct { + databaseName string + database *sql.DB +} + +type Transaction struct { + tx *sql.Tx + err error +} + +type Action func(tx *sql.Tx) error + +func New(DBName string) (*Database, error) { + return &Database{databaseName: DBName}, nil +} + +func (d *Database) Close() error { + return d.database.Close() +} + +// provides access to the internal database object +func (d *Database) DB() *sql.DB { + return d.database +} + +func (d *Database) Name() string { + return d.databaseName +} + +func (d *Database) Open() (err error) { + d.database, err = openSqliteDB(d.databaseName) + return err +} + +func (d *Database) OpenInMemory() (err error) { + d.database, err = sql.Open("sqlite", ":memory:") + return err +} + +func openSqliteDB(databasefilename string) (*sql.DB, error) { + + _, err := os.Stat(databasefilename) + if errors.Is(err, os.ErrNotExist) { + return createDB(databasefilename) + } + if err != nil { + return nil, err + } + return sql.Open("sqlite", databasefilename) + +} + +func createDB(dbfileName string) (*sql.DB, error) { + + query := ` + PRAGMA page_size = 4096; + PRAGMA synchronous = off; + PRAGMA foreign_keys = off; + PRAGMA journal_mode = WAL; + PRAGMA user_version = 1; + ` + db, err := sql.Open("sqlite", dbfileName) + if err != nil { + + return nil, err + } + _, err = db.Exec(query) + if err != nil { + return nil, err + } + return db, nil +} + +func (d *Database) TableList() (result []Record, err error) { + return d.ReadRecords("select name from sqlite_master where type='table';") +} + +func (d *Database) ReadTable(tablename string) (result []Record, err error) { + + return d.ReadRecords(fmt.Sprintf("select * from '%s';", tablename)) +} + +func (d *Database) ReadRecords(query string, args ...any) (result []Record, err error) { + + rows, err := d.DB().Query(query, args...) + if err != nil { + return result, err + } + defer rows.Close() + return Rows2records(rows) + +} + +func (d *Database) GetRecord(tablename string, idfield string, key any) (result Record, err error) { + + query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield) + res, err := d.DB().Query(query, key) + if err != nil { + return result, err + } + defer res.Close() + return Rows2record(res) + +} + +func (d *Database) UpsertRecord(tablename string, idfield string, record Record) (result Record, err error) { + + return upsert(d.DB(), tablename, idfield, record) + +} + +func (d *Database) DeleteRecord(tablename string, idfield string, id any) (err error) { + + return deleteRecord(d.DB(), tablename, idfield, id) + +} + +// *sql.DB and *sql.Tx both have a method named 'Query', +// this way they can both be passed into upsert and deleteRecord function +type iquery interface { + Query(query string, args ...any) (*sql.Rows, error) +} + +func upsert(t iquery, tablename string, idfield string, record Record) (result Record, err error) { + + fields := []string{} + data := []any{} + for k, v := range record { + fields = append(fields, k) + data = append(data, v) + } + query, err := buildUpsertCommand(tablename, idfield, fields) + if err != nil { + return result, err + } + res, err := t.Query(query, data...) // res contains the full record - see SQLite: RETURNING * + if err != nil { + return result, err + } + defer res.Close() + return Rows2record(res) + +} + +func deleteRecord(t iquery, tablename string, idfield string, id any) (err error) { + + query := fmt.Sprintf("DELETE FROM \"%s\" WHERE \"%s\" = ?;", tablename, idfield) + _, err = t.Query(query, id) + return err + +} + +func buildUpsertCommand(tablename string, idfield string, fields []string) (string, error) { + var sb strings.Builder + sb.Grow(256 + len(fields)*20) // rough preallocation + + // INSERT INTO + sb.WriteString(`INSERT INTO "`) + sb.WriteString(tablename) + sb.WriteString(`"(`) + for i, f := range fields { + sb.WriteString(` "`) + sb.WriteString(f) + sb.WriteByte('"') + if i < len(fields)-1 { + sb.WriteByte(',') + } + } + sb.WriteString(")\n\tVALUES(") + + // VALUES + for i := 0; i < len(fields); i++ { + sb.WriteString(" ?") + sb.Write(strconv.AppendInt(nil, int64(i+1), 10)) + if i < len(fields)-1 { + sb.WriteByte(',') + } + } + sb.WriteString(")\n\tON CONFLICT(\"") + sb.WriteString(tablename) + sb.WriteString(`"."`) + sb.WriteString(idfield) + sb.WriteString("\")\n\tDO UPDATE SET ") + + // UPDATE SET + for i, f := range fields { + sb.WriteByte('"') + sb.WriteString(f) + sb.WriteString(`"= ?`) + sb.Write(strconv.AppendInt(nil, int64(i+1), 10)) + if i < len(fields)-1 { + sb.WriteByte(',') + } + } + sb.WriteString("\n\tRETURNING *;") + + return sb.String(), nil +} + +// func buildUpsertCommand(tablename string, idfield string, fields []string) (result string, err error) { + +// pname := map[string]string{} // assign correct index for parameter name +// // parameter position, starts at 1 in sql! So it needs to be calculated by function pname inside template + +// for i, k := range fields { +// pname[k] = strconv.Itoa(i + 1) +// } +// funcMap := template.FuncMap{ +// "pname": func(fieldname string) string { +// return pname[fieldname] +// }, +// } +// tableDef := struct { +// Tablename string +// KeyField string +// LastField int +// FieldNames []string +// }{ +// Tablename: tablename, +// KeyField: idfield, +// LastField: len(fields) - 1, +// FieldNames: fields, +// } +// var templString = `{{$last := .LastField}}INSERT INTO "{{ .Tablename }}"({{ range $i,$el := .FieldNames }} "{{$el}}"{{if ne $i $last}},{{end}}{{end}}) +// VALUES({{ range $i,$el := .FieldNames }} ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}}) +// ON CONFLICT("{{ .Tablename }}"."{{.KeyField}}") +// DO UPDATE SET {{ range $i,$el := .FieldNames }}"{{$el}}"= ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}} +// RETURNING *;` + +// dbTempl, err := template.New("upsertDB").Funcs(funcMap).Parse(templString) +// if err != nil { +// return result, err +// } +// var templBytes bytes.Buffer +// err = dbTempl.Execute(&templBytes, tableDef) +// if err != nil { +// return result, err +// } +// return templBytes.String(), nil +// } + +func Rows2record(rows *sql.Rows) (Record, error) { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + values := make([]any, len(columns)) + valuePtrs := make([]any, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + result := Record{} + for rows.Next() { + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + for i, col := range columns { + result[col] = values[i] + } + } + if len(result) == 0 { + return nil, errors.New("no rows found") + } + return result, nil +} + +func Rows2records(rows *sql.Rows) ([]Record, error) { + columns, err := rows.Columns() + if err != nil { + return nil, err + } + recLength := len(columns) + results := []Record{} + for rows.Next() { + values := make([]any, recLength) + valuePtrs := make([]any, recLength) + for i := range values { + valuePtrs[i] = &values[i] + } + record := Record{} + if err := rows.Scan(valuePtrs...); err != nil { + return nil, err + } + for i, col := range columns { + record[col] = values[i] + } + results = append(results, record) + } + if len(results) == 0 { + return nil, errors.New("no rows found") + } + return results, nil +} + +func (d *Database) Version() (string, error) { + result := "" + sqliteversion, err := d.ReadRecords("SELECT sqlite_version();") + if len(sqliteversion) == 1 { + result = sqliteversion[0]["sqlite_version()"].(string) + } + return result, err +} + +func (d *Database) UserVersion() (int64, error) { + + var result int64 + userversion, err := d.ReadRecords("PRAGMA user_version;") + if len(userversion) == 1 { + result = userversion[0]["user_version"].(int64) + } + return result, err +} + +func (d *Database) Begin() *Transaction { + tx, err := d.database.Begin() + return &Transaction{tx, err} +} + +func (t *Transaction) Next(action Action) *Transaction { + if t.err != nil { + return t + } + t.err = action(t.tx) + return t +} + +func (t *Transaction) End() error { + + if t.err != nil { + err := t.tx.Rollback() + if err != nil { + t.err = errors.Join(t.err, err) + } + return t.err + } + t.err = t.tx.Commit() + return t.err +} + +func (t *Transaction) GetRecord(tablename string, idfield string, key any, output Record) *Transaction { + + if t.err != nil { + return t + } + query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield) + res, err := t.tx.Query(query, key) + if err != nil { + t.err = err + return t + } + defer res.Close() + result, err := Rows2record(res) + if err != nil { + t.err = err + return t + } + for k := range output { + delete(output, k) + } + for k, v := range result { + output[k] = v + } + return t +} + +func (t *Transaction) UpsertRecord(tablename string, idfield string, record Record, output Record) *Transaction { + + if t.err != nil { + return t + } + result, err := upsert(t.tx, tablename, idfield, record) + if err != nil { + t.err = err + return t + } + for k := range output { + delete(output, k) + } + for k, v := range result { + output[k] = v + } + return t +} + +func (t *Transaction) DeleteRecord(tablename string, idfield string, id any) *Transaction { + + if t.err != nil { + return t + } + err := deleteRecord(t.tx, tablename, idfield, id) + if err != nil { + t.err = err + } + return t +} + +// returns a value of the provided type, if the field exist and if it can be cast into the provided type parameter +func Value[T any](rec Record, field string) (value T, ok bool) { + var v any + if v, ok = rec[field]; ok { + value, ok = v.(T) + } + return +} + +// don't report an error if there are simply just 'no rows found' +func NoRowsOk(recs []Record, err error) ([]Record, error) { + if err != nil && err.Error() != "no rows found" { + return recs, err + } + return recs, nil +}