package sqlite // name the package as you see fit import ( "database/sql" "errors" "fmt" "os" "strconv" "strings" _ "modernc.org/sqlite" ) // This is the data type to exchange data with the database type Record = map[string]any type Database struct { databaseName string database *sql.DB } type Transaction struct { tx *sql.Tx err error } type Action func(tx *sql.Tx) error func New(DBName string) (*Database, error) { return &Database{databaseName: DBName}, nil } func (d *Database) Close() error { return d.database.Close() } // provides access to the internal database object func (d *Database) DB() *sql.DB { return d.database } func (d *Database) Name() string { return d.databaseName } func (d *Database) Open() (err error) { d.database, err = openSqliteDB(d.databaseName) return err } func (d *Database) OpenInMemory() (err error) { d.database, err = sql.Open("sqlite", ":memory:") return err } func openSqliteDB(databasefilename string) (*sql.DB, error) { _, err := os.Stat(databasefilename) if errors.Is(err, os.ErrNotExist) { return createDB(databasefilename) } if err != nil { return nil, err } return sql.Open("sqlite", databasefilename) } func createDB(dbfileName string) (*sql.DB, error) { query := ` PRAGMA page_size = 4096; PRAGMA synchronous = off; PRAGMA foreign_keys = off; PRAGMA journal_mode = WAL; PRAGMA user_version = 1; ` db, err := sql.Open("sqlite", dbfileName) if err != nil { return nil, err } _, err = db.Exec(query) if err != nil { return nil, err } return db, nil } func (d *Database) TableList() (result []Record, err error) { return d.ReadRecords("select name from sqlite_master where type='table';") } func (d *Database) ReadTable(tablename string) (result []Record, err error) { return d.ReadRecords(fmt.Sprintf("select * from '%s';", tablename)) } func (d *Database) ReadRecords(query string, args ...any) (result []Record, err error) { rows, err := d.DB().Query(query, args...) if err != nil { return result, err } defer rows.Close() return Rows2records(rows) } func (d *Database) GetRecord(tablename string, idfield string, key any) (result Record, err error) { query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield) res, err := d.DB().Query(query, key) if err != nil { return result, err } defer res.Close() return Rows2record(res) } func (d *Database) UpsertRecord(tablename string, idfield string, record Record) (result Record, err error) { return upsert(d.DB(), tablename, idfield, record) } func (d *Database) DeleteRecord(tablename string, idfield string, id any) (err error) { return deleteRecord(d.DB(), tablename, idfield, id) } // *sql.DB and *sql.Tx both have a method named 'Query', // this way they can both be passed into upsert and deleteRecord function type iquery interface { Query(query string, args ...any) (*sql.Rows, error) } func upsert(t iquery, tablename string, idfield string, record Record) (result Record, err error) { fields := []string{} data := []any{} for k, v := range record { fields = append(fields, k) data = append(data, v) } query, err := buildUpsertCommand(tablename, idfield, fields) if err != nil { return result, err } res, err := t.Query(query, data...) // res contains the full record - see SQLite: RETURNING * if err != nil { return result, err } defer res.Close() return Rows2record(res) } func deleteRecord(t iquery, tablename string, idfield string, id any) (err error) { query := fmt.Sprintf("DELETE FROM \"%s\" WHERE \"%s\" = ?;", tablename, idfield) _, err = t.Query(query, id) return err } func buildUpsertCommand(tablename string, idfield string, fields []string) (string, error) { var sb strings.Builder sb.Grow(256 + len(fields)*20) // rough preallocation // INSERT INTO sb.WriteString(`INSERT INTO "`) sb.WriteString(tablename) sb.WriteString(`"(`) for i, f := range fields { sb.WriteString(` "`) sb.WriteString(f) sb.WriteByte('"') if i < len(fields)-1 { sb.WriteByte(',') } } sb.WriteString(")\n\tVALUES(") // VALUES for i := 0; i < len(fields); i++ { sb.WriteString(" ?") sb.Write(strconv.AppendInt(nil, int64(i+1), 10)) if i < len(fields)-1 { sb.WriteByte(',') } } sb.WriteString(")\n\tON CONFLICT(\"") sb.WriteString(tablename) sb.WriteString(`"."`) sb.WriteString(idfield) sb.WriteString("\")\n\tDO UPDATE SET ") // UPDATE SET for i, f := range fields { sb.WriteByte('"') sb.WriteString(f) sb.WriteString(`"= ?`) sb.Write(strconv.AppendInt(nil, int64(i+1), 10)) if i < len(fields)-1 { sb.WriteByte(',') } } sb.WriteString("\n\tRETURNING *;") return sb.String(), nil } func Rows2record(rows *sql.Rows) (Record, error) { columns, err := rows.Columns() if err != nil { return nil, err } values := make([]any, len(columns)) valuePtrs := make([]any, len(columns)) for i := range values { valuePtrs[i] = &values[i] } result := Record{} for rows.Next() { if err := rows.Scan(valuePtrs...); err != nil { return nil, err } for i, col := range columns { result[col] = values[i] } } if len(result) == 0 { return nil, errors.New("no rows found") } return result, nil } func Rows2records(rows *sql.Rows) ([]Record, error) { columns, err := rows.Columns() if err != nil { return nil, err } recLength := len(columns) results := []Record{} for rows.Next() { values := make([]any, recLength) valuePtrs := make([]any, recLength) for i := range values { valuePtrs[i] = &values[i] } record := Record{} if err := rows.Scan(valuePtrs...); err != nil { return nil, err } for i, col := range columns { record[col] = values[i] } results = append(results, record) } if len(results) == 0 { return nil, errors.New("no rows found") } return results, nil } func (d *Database) Version() (string, error) { result := "" sqliteversion, err := d.ReadRecords("SELECT sqlite_version();") if len(sqliteversion) == 1 { result = sqliteversion[0]["sqlite_version()"].(string) } return result, err } func (d *Database) UserVersion() (int64, error) { var result int64 userversion, err := d.ReadRecords("PRAGMA user_version;") if len(userversion) == 1 { result = userversion[0]["user_version"].(int64) } return result, err } func (d *Database) Begin() *Transaction { tx, err := d.database.Begin() return &Transaction{tx, err} } func (t *Transaction) Next(action Action) *Transaction { if t.err != nil { return t } t.err = action(t.tx) return t } func (t *Transaction) End() error { if t.err != nil { err := t.tx.Rollback() if err != nil { t.err = errors.Join(t.err, err) } return t.err } t.err = t.tx.Commit() return t.err } func (t *Transaction) GetRecord(tablename string, idfield string, key any, output Record) *Transaction { if t.err != nil { return t } query := fmt.Sprintf("select * from %s where %s = ?;", tablename, idfield) res, err := t.tx.Query(query, key) if err != nil { t.err = err return t } defer res.Close() result, err := Rows2record(res) if err != nil { t.err = err return t } for k := range output { delete(output, k) } for k, v := range result { output[k] = v } return t } func (t *Transaction) UpsertRecord(tablename string, idfield string, record Record, output Record) *Transaction { if t.err != nil { return t } result, err := upsert(t.tx, tablename, idfield, record) if err != nil { t.err = err return t } for k := range output { delete(output, k) } for k, v := range result { output[k] = v } return t } func (t *Transaction) DeleteRecord(tablename string, idfield string, id any) *Transaction { if t.err != nil { return t } err := deleteRecord(t.tx, tablename, idfield, id) if err != nil { t.err = err } return t } // returns a value of the provided type, if the field exist and if it can be cast into the provided type parameter func Value[T any](rec Record, field string) (value T, ok bool) { var v any if v, ok = rec[field]; ok { value, ok = v.(T) } return } // don't report an error if there are simply just 'no rows found' func NoRowsOk(recs []Record, err error) ([]Record, error) { if err != nil && err.Error() != "no rows found" { return recs, err } return recs, nil }