2025-11-11 06:30:02 +01:00


package sqlite // name the package as you see fit, it is intended to be vendored
import (
	"bytes"
	"context"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"strconv"
	"text/template"

	_ "modernc.org/sqlite"
)
/*
Package sqlite provides a simplified wrapper around the modernc.org/sqlite driver.
It aims to provide a convenient, developer-friendly interface for common database
operations, prioritizing ease of use with a map-based data exchange format (Record).

Key Concepts:
- Database Instance: A single `Database` struct instance manages the connection to
  a specific database file or an in-memory database.
- Lifecycle: Use `New()` to create an instance, `Open()` or `OpenInMemory()`
  to establish the connection, and `defer Close()` to release resources.
- Record Type: `type Record = map[string]any` is the primary type for exchanging
  data with the database. Column names become map keys.
- Underlying DB Access: The `DB()` method provides access to the raw `*sql.DB`
  object for operations not covered by the wrapper.

Features:
All methods take a `context.Context` as their first argument (`ctx` below).
- Reading Data:
  - `ReadTable(ctx, tablename)`: Reads all rows and columns from the specified table.
  - `ReadRecords(ctx, query, args...)`: Executes a custom SQL SELECT query
    with parameterized arguments and returns multiple records.
  - `GetRecord(ctx, tablename, idfield, key)`: Retrieves a single
    record from a table based on a unique identifier.
  - A read that matches no rows returns `sql.ErrNoRows`. This also applies to
    `ReadTable` and `ReadRecords`; see `NoRowsOk` under Helper Functions.
- Writing Data:
  - `UpsertRecord(ctx, tablename, idfield, record)`: Inserts a new
    record or updates an existing one based on the value of the `idfield`.
    Uses SQLite's `ON CONFLICT` clause.
  - Supports partial updates: Only include fields you want to insert/update in the `Record`.
  - Returns the full resulting record (including auto-generated IDs) using `RETURNING *`.
- Deleting Data:
  - `DeleteRecord(ctx, tablename, idfield, id)`: Deletes a single
    record from a table based on its identifier.
- Metadata:
  - `TableList(ctx)`: Lists all tables in the database.
  - `Version(ctx)`: Gets the SQLite library version.
  - `UserVersion(ctx)`: Gets the database's user_version PRAGMA.

Transaction Handling:
- `BeginTx(ctx, opts)`: Starts a new database transaction, returning a `*Transaction`
  object. A failure to begin is stored in that object and surfaces through `Err()` and `End()`.
- Chaining: Transaction methods (`GetRecord`, `UpsertRecord`, `DeleteRecord`, `Next`)
  return the `*Transaction` object, allowing operations to be chained.
- Results: `GetRecord` and `UpsertRecord` write the resulting row into the
  caller-supplied `output` Record, replacing its previous contents.
- Error Propagation: If any operation within a transaction chain fails, the error
  is stored in the `Transaction` object (`tx.Err()`), and subsequent chained
  operations become no-ops.
- `Next(ctx, action)`: Allows executing custom logic within the transaction
  by providing an `Action`, a `func(ctx context.Context, tx *sql.Tx) error` that
  receives the raw `*sql.Tx`.
- `End()`: Finalizes the transaction. If `tx.Err()` is non-nil, it performs a
  ROLLBACK; otherwise, it performs a COMMIT. Returns the accumulated error.

Helper Functions:
- `Value[T any](rec Record, field string) (T, bool)`: A generic helper to safely
  extract and type-assert a value from a `Record` map.
- `NoRowsOk([]Record, error)`: A helper to wrap calls that might return
  `sql.ErrNoRows` and treat that specific error as a non-error case, returning
  an empty, non-nil slice and a nil error.
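
Prerequisites:
- For `UpsertRecord` to function correctly, the target table must have a unique
  index (a `PRIMARY KEY` or `UNIQUE` constraint) on the specified `idfield`.
- It is highly recommended that the `idfield` is an `INTEGER PRIMARY KEY AUTOINCREMENT`
  to leverage SQLite's built-in ID generation and efficient lookups. (A plain
  `INTEGER PRIMARY KEY` also generates IDs; `AUTOINCREMENT` additionally prevents
  the reuse of IDs from deleted rows.)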

Shortcomings and Important Considerations:
- SQL Injection Risk:
  - Identifiers: Table names, field names, and record keys (used as field names)
    are validated to contain only alphanumeric characters and underscores. They are
    also quoted by the library. This significantly mitigates SQL injection risks
    through identifiers. However, the caller MUST still ensure that these identifiers
    refer to the *intended* database objects.
  - Query Structure: For `ReadRecords` and `Transaction.Next` actions, if the raw
    SQL query string itself is constructed from untrusted user input, it remains a
    potential SQL injection vector. Parameterization is used by this library (and
    `database/sql`) only for *values*, not for the query structure or identifiers
    within a user-provided query string.
- Simplicity over Edge Cases: This is a simplified layer. More complex scenarios
  or advanced SQLite features might require using the underlying `*sql.DB` object
  via the `DB()` method.
- Room for Improvement: As a fresh implementation, there is potential for
  further optimization and refinement.

Implementation Details:
- Uses the `modernc.org/sqlite` driver.
- SQL commands for `UpsertRecord` are dynamically generated using Go's `text/template`.
- Internal interfaces (`iqueryContext`, `iExecContext`) are used to allow functions
  like `upsert` and `deleteRecord` to work seamlessly with both `*sql.DB` and `*sql.Tx`.
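
To make the generated upsert statement concrete, here is a simplified sketch of
its shape. `upsertSQL` is a hypothetical stand-in for `buildUpsertCommand`: it
uses string joins instead of `text/template`, so whitespace differs from what
the package emits, and it assumes all identifiers were validated beforehand.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// upsertSQL is a hypothetical, simplified stand-in for buildUpsertCommand.
// It assumes table, idfield, and every entry of fields were validated
// beforehand and that fields is not empty.
func upsertSQL(table, idfield string, fields []string) string {
	cols := make([]string, len(fields))
	params := make([]string, len(fields))
	sets := make([]string, len(fields))
	for i, f := range fields {
		p := "?" + strconv.Itoa(i+1) // numbered parameters start at 1
		cols[i] = `"` + f + `"`
		params[i] = p
		sets[i] = `"` + f + `" = ` + p
	}
	return fmt.Sprintf(
		`INSERT INTO "%s"(%s) VALUES(%s) ON CONFLICT("%s"."%s") DO UPDATE SET %s RETURNING *;`,
		table,
		strings.Join(cols, ", "),
		strings.Join(params, ", "),
		table, idfield,
		strings.Join(sets, ", "),
	)
}

func main() {
	fmt.Println(upsertSQL("users", "id", []string{"id", "name"}))
	// INSERT INTO "users"("id", "name") VALUES(?1, ?2) ON CONFLICT("users"."id") DO UPDATE SET "id" = ?1, "name" = ?2 RETURNING *;
}
```

Because the numbered parameters (`?1`, `?2`, ...) are reused in the `VALUES`
and `SET` clauses, each value is bound only once even though it appears twice
in the statement.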

Unit Tests:
- The package includes unit tests (`database_test.go`, `transaction_test.go`, `helpers_test.go`)
  covering core functionality and transaction handling.
*/
// ErrInvalidIdentifier is returned when a table or column name contains disallowed characters.
var ErrInvalidIdentifier = errors.New("invalid identifier: contains disallowed characters")
// Record is the data type used to exchange data with the database.
// Column names are the map keys.
type Record = map[string]any
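// Database manages the connection to a single SQLite database file or
// in-memory database.
type Database struct {
	databaseName string
	database     *sql.DB
}

// New returns a Database for the given file name. No connection is made
// until Open or OpenInMemory is called.
func New(DBName string) *Database {
	return &Database{databaseName: DBName}
}

// Close releases the underlying connection pool. It is safe to call on a
// Database that was never opened.
func (d *Database) Close() error {
	if d.database == nil {
		return nil // never opened, nothing to release
	}
	return d.database.Close()
}

// DB provides access to the underlying *sql.DB. It is nil before Open or
// OpenInMemory has succeeded.
func (d *Database) DB() *sql.DB {
	return d.database
}

// Name returns the database file name passed to New.
func (d *Database) Name() string {
	return d.databaseName
}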
// basePragmas returns a string of common PRAGMA settings for SQLite.
// It excludes user_version, which is typically managed by schema migrations.
func basePragmas() string {
return `
PRAGMA page_size = 4096;
PRAGMA synchronous = NORMAL;
PRAGMA foreign_keys = ON;
PRAGMA journal_mode = WAL;
`
}
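// Open connects to the database file passed to New, creating and initializing
// the file if it does not exist yet.
func (d *Database) Open(ctx context.Context) (err error) {
	d.database, err = openSqliteDB(ctx, d.databaseName)
	return err
}

// OpenInMemory connects to a fresh in-memory database.
func (d *Database) OpenInMemory(ctx context.Context) error {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		return err
	}
	// Every new connection to ":memory:" gets its own empty database, so the
	// pool is limited to one connection to make all calls see the same data.
	db.SetMaxOpenConns(1)
	// Apply base PRAGMAs for consistency with file-based databases.
	if _, err = db.ExecContext(ctx, basePragmas()); err != nil {
		db.Close() // best effort; the PRAGMA error is the one worth reporting
		return err
	}
	d.database = db
	return nil
}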
func openSqliteDB(ctx context.Context, databasefilename string) (*sql.DB, error) {
_, err := os.Stat(databasefilename)
if errors.Is(err, os.ErrNotExist) {
return createDB(ctx, databasefilename)
}
if err != nil {
return nil, err
}
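// An existing file is opened as-is: basePragmas is not re-applied, so
// per-connection settings such as foreign_keys keep their defaults.
return sql.Open("sqlite", databasefilename)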
}
func createDB(ctx context.Context, dbfileName string) (*sql.DB, error) {
// Apply base pragmas and set initial user_version for new database files.
query := basePragmas() + "PRAGMA user_version = 1;\n"
db, err := sql.Open("sqlite", dbfileName)
if err != nil {
return nil, err
}
_, err = db.ExecContext(ctx, query)
if err != nil {
db.Close() // Best effort to close if ExecContext fails
os.Remove(dbfileName) // Best effort to remove partially created file
return nil, err
}
return db, nil
}
func (d *Database) TableList(ctx context.Context) (result []Record, err error) {
return d.ReadRecords(ctx, "select name from sqlite_master where type='table';")
}
func (d *Database) ReadTable(ctx context.Context, tablename string) (result []Record, err error) {
if !isValidIdentifier(tablename) {
return nil, fmt.Errorf("ReadTable: %w: table name '%s'", ErrInvalidIdentifier, tablename)
}
return d.ReadRecords(ctx, fmt.Sprintf("select * from \"%s\";", tablename)) // Use double quotes for identifiers
}
func (d *Database) ReadRecords(ctx context.Context, query string, args ...any) (result []Record, err error) {
// Note: For ReadRecords, the query string itself is provided by the caller.
// The library cannot validate the structure of this query beyond what the driver does.
// The SQL injection caveat for arbitrary query strings remains critical here.
rows, err := d.DB().QueryContext(ctx, query, args...)
if err != nil {
return result, err
}
defer rows.Close()
return Rows2records(rows)
}
func (d *Database) GetRecord(ctx context.Context, tablename string, idfield string, key any) (result Record, err error) {
if !isValidIdentifier(tablename) {
return nil, fmt.Errorf("GetRecord: %w: table name '%s'", ErrInvalidIdentifier, tablename)
}
if !isValidIdentifier(idfield) {
return nil, fmt.Errorf("GetRecord: %w: id field '%s'", ErrInvalidIdentifier, idfield)
}
query := fmt.Sprintf("select * from \"%s\" where \"%s\" = ?;", tablename, idfield) // Quote identifiers
res, err := d.DB().QueryContext(ctx, query, key)
if err != nil {
return result, err
}
defer res.Close()
return Rows2record(res)
}
func (d *Database) UpsertRecord(ctx context.Context, tablename string, idfield string, record Record) (result Record, err error) {
if !isValidIdentifier(tablename) {
return nil, fmt.Errorf("UpsertRecord: %w: table name '%s'", ErrInvalidIdentifier, tablename)
}
if !isValidIdentifier(idfield) {
return nil, fmt.Errorf("UpsertRecord: %w: id field '%s'", ErrInvalidIdentifier, idfield)
}
return upsert(ctx, d.DB(), tablename, idfield, record)
}
func (d *Database) DeleteRecord(ctx context.Context, tablename string, idfield string, id any) (err error) {
// Validation for tablename and idfield will be done by deleteRecord internal helper
// to ensure consistency for both Database and Transaction calls.
return deleteRecord(ctx, d.DB(), tablename, idfield, id)
}
// iqueryContext is satisfied by both *sql.DB and *sql.Tx through their
// QueryContext method, so either can be passed to upsert.
type iqueryContext interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// iExecContext is satisfied by both *sql.DB and *sql.Tx through their
// ExecContext method, so either can be passed to deleteRecord.
type iExecContext interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
func upsert(ctx context.Context, q iqueryContext, tablename string, idfield string, record Record) (result Record, err error) {
// tablename and idfield are assumed to be validated by the public-facing methods (Database.UpsertRecord, Transaction.UpsertRecord)
fields := []string{}
data := []any{}
for k, v := range record {
if !isValidIdentifier(k) {
return nil, fmt.Errorf("upsert: %w: field name '%s'", ErrInvalidIdentifier, k)
}
fields = append(fields, k)
data = append(data, v)
}
// The SET clause is built from the fields of the record only. When idfield is
// absent from the record, the statement normally acts as a plain INSERT (the
// key is left to SQLite), so include idfield to update an existing row.
if len(fields) == 0 {
return nil, errors.New("UpsertRecord: input record cannot be empty")
}
query, err := buildUpsertCommand(tablename, idfield, fields)
if err != nil {
return result, err
}
res, err := q.QueryContext(ctx, query, data...) // res contains the full record - see SQLite: RETURNING *
if err != nil {
return result, err
}
defer res.Close()
return Rows2record(res)
}
func deleteRecord(ctx context.Context, e iExecContext, tablename string, idfield string, id any) (err error) {
if !isValidIdentifier(tablename) {
return fmt.Errorf("deleteRecord: %w: table name '%s'", ErrInvalidIdentifier, tablename)
}
if !isValidIdentifier(idfield) {
return fmt.Errorf("deleteRecord: %w: id field '%s'", ErrInvalidIdentifier, idfield)
}
query := fmt.Sprintf("DELETE FROM \"%s\" WHERE \"%s\" = ?;", tablename, idfield)
_, err = e.ExecContext(ctx, query, id)
// Exec never reports sql.ErrNoRows: deleting zero rows yields a nil error.
// The sql.Result is discarded here, so a caller that needs to know whether a
// row was actually removed has to run its own statement and check RowsAffected().
return err
}
func buildUpsertCommand(tablename string, idfield string, fields []string) (result string, err error) {
// Assumes tablename, idfield, and all elements in fields are already validated
// by the calling function (e.g., upsert).
// And that fields is not empty.
pname := map[string]string{} // assign correct index for parameter name
// parameter position, starts at 1 in sql! So it needs to be calculated by function pname inside template
for i, k := range fields {
pname[k] = strconv.Itoa(i + 1)
}
funcMap := template.FuncMap{
"pname": func(fieldname string) string {
return pname[fieldname]
},
}
tableDef := struct {
Tablename string
KeyField string
LastField int
FieldNames []string
}{
Tablename: tablename,
KeyField: idfield,
LastField: len(fields) - 1,
FieldNames: fields,
}
var templString = `{{$last := .LastField}}INSERT INTO "{{ .Tablename }}"({{ range $i,$el := .FieldNames }} "{{$el}}"{{if ne $i $last}},{{end}}{{end}})
VALUES({{ range $i,$el := .FieldNames }} ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}})
ON CONFLICT("{{ .Tablename }}"."{{.KeyField}}")
DO UPDATE SET {{ range $i,$el := .FieldNames }}"{{$el}}"= ?{{pname $el}}{{if ne $i $last}},{{end}}{{end}}
RETURNING *;`
dbTempl, err := template.New("upsertDB").Funcs(funcMap).Parse(templString)
if err != nil {
return result, err
}
var templBytes bytes.Buffer
err = dbTempl.Execute(&templBytes, tableDef)
if err != nil {
return result, err
}
return templBytes.String(), nil
}
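// Rows2record scans the next row of rows into a Record.
// It returns sql.ErrNoRows if there is no row to read.
func Rows2record(rows *sql.Rows) (Record, error) {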
columns, err := rows.Columns()
if err != nil {
return nil, err
}
values := make([]any, len(columns))
valuePtrs := make([]any, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
result := Record{}
if !rows.Next() {
if err := rows.Err(); err != nil { // Check for errors during iteration attempt
return nil, err
}
return nil, sql.ErrNoRows // Standard error for no rows
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, err
}
for i, col := range columns {
result[col] = values[i]
}
// Check for errors encountered during iteration (e.g., if Next() was called multiple times).
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
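// Rows2records scans all remaining rows into a slice of Records.
// It returns sql.ErrNoRows if the result set is empty.
func Rows2records(rows *sql.Rows) ([]Record, error) {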
columns, err := rows.Columns()
if err != nil {
return nil, err
}
recLength := len(columns)
results := []Record{}
for rows.Next() {
valuePtrs := make([]any, recLength)
values := make([]any, recLength)
for i := range values {
valuePtrs[i] = &values[i]
}
record := Record{}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, err
}
for i, col := range columns {
record[col] = values[i]
}
results = append(results, record)
}
// Check for errors encountered during iteration.
if err := rows.Err(); err != nil {
return nil, err
}
if len(results) == 0 {
// An empty result set is reported as sql.ErrNoRows. Callers that prefer an
// empty slice and a nil error can wrap the call in NoRowsOk.
return nil, sql.ErrNoRows
}
return results, nil
}
func (d *Database) Version(ctx context.Context) (string, error) {
var version string
err := d.DB().QueryRowContext(ctx, "SELECT sqlite_version();").Scan(&version)
return version, err
}
func (d *Database) UserVersion(ctx context.Context) (int64, error) {
var result int64
// PRAGMA user_version; returns a single row with a single column named "user_version".
// QueryRow().Scan() is appropriate here.
err := d.DB().QueryRowContext(ctx, "PRAGMA user_version;").Scan(&result)
return result, err
}
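// BeginTx starts a transaction. A failure to begin is stored in the returned
// Transaction and reported by Err and End.
func (d *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) *Transaction {
	tx, err := d.database.BeginTx(ctx, opts)
	return &Transaction{tx: tx, err: err}
}

// Transaction wraps a *sql.Tx together with the first error of the chain.
type Transaction struct {
	tx  *sql.Tx
	err error
}

// Err returns the current error state of the transaction.
func (t *Transaction) Err() error {
	return t.err
}

// Action is custom logic that runs inside a transaction via Next.
type Action func(ctx context.Context, tx *sql.Tx) error

// Next runs action with the raw *sql.Tx unless the chain has already failed.
func (t *Transaction) Next(ctx context.Context, action Action) *Transaction {
	if t.err != nil {
		return t
	}
	t.err = action(ctx, t.tx)
	return t
}

// End finalizes the transaction: ROLLBACK if an error was recorded, COMMIT
// otherwise. It returns the accumulated error.
func (t *Transaction) End() error {
	if t.tx == nil { // BeginTx failed, so there is nothing to finalize
		return t.err
	}
	if t.err != nil {
		// Rollback does not take a context.
		if err := t.tx.Rollback(); err != nil {
			t.err = errors.Join(t.err, err)
		}
		return t.err
	}
	t.err = t.tx.Commit()
	return t.err
}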
// GetRecord reads the row identified by idfield = key and copies it into
// output, replacing output's previous contents.
func (t *Transaction) GetRecord(ctx context.Context, tablename string, idfield string, key any, output Record) *Transaction {
	// Check the chain state first so an earlier error is never overwritten.
	if t.err != nil {
		return t
	}
	if !isValidIdentifier(tablename) {
		t.err = fmt.Errorf("Transaction.GetRecord: %w: table name '%s'", ErrInvalidIdentifier, tablename)
		return t
	}
	if !isValidIdentifier(idfield) {
		t.err = fmt.Errorf("Transaction.GetRecord: %w: id field '%s'", ErrInvalidIdentifier, idfield)
		return t
	}
	query := fmt.Sprintf("select * from \"%s\" where \"%s\" = ?;", tablename, idfield) // Quote identifiers
	res, err := t.tx.QueryContext(ctx, query, key)
	if err != nil {
		t.err = err
		return t
	}
	defer res.Close()
	result, err := Rows2record(res)
	if err != nil {
		t.err = err
		return t
	}
	if output == nil {
		return t // writing to a nil map would panic; nothing to copy into
	}
	for k := range output {
		delete(output, k)
	}
	for k, v := range result {
		output[k] = v
	}
	return t
}
// UpsertRecord inserts or updates record and copies the resulting row into
// output, replacing output's previous contents.
func (t *Transaction) UpsertRecord(ctx context.Context, tablename string, idfield string, record Record, output Record) *Transaction {
	// Check the chain state first so an earlier error is never overwritten.
	if t.err != nil {
		return t
	}
	if !isValidIdentifier(tablename) {
		t.err = fmt.Errorf("Transaction.UpsertRecord: %w: table name '%s'", ErrInvalidIdentifier, tablename)
		return t
	}
	if !isValidIdentifier(idfield) {
		t.err = fmt.Errorf("Transaction.UpsertRecord: %w: id field '%s'", ErrInvalidIdentifier, idfield)
		return t
	}
	result, err := upsert(ctx, t.tx, tablename, idfield, record)
	if err != nil {
		t.err = err
		return t
	}
	if output == nil {
		return t // writing to a nil map would panic; the caller ignores the result
	}
	for k := range output {
		delete(output, k)
	}
	for k, v := range result {
		output[k] = v
	}
	return t
}
// DeleteRecord deletes the row identified by idfield = id.
// Identifier validation is done by the shared deleteRecord helper.
func (t *Transaction) DeleteRecord(ctx context.Context, tablename string, idfield string, id any) *Transaction {
	if t.err != nil {
		return t
	}
	// t.tx satisfies iExecContext.
	t.err = deleteRecord(ctx, t.tx, tablename, idfield, id)
	return t
}
// Value returns rec[field] as type T. ok is false if the field does not exist
// or its value cannot be type-asserted to T.
func Value[T any](rec Record, field string) (value T, ok bool) {
var v any
// No validation for 'field' here as it's used to access a map key from an existing Record,
// not to construct SQL.
if v, ok = rec[field]; ok {
value, ok = v.(T)
}
return
}
// NoRowsOk converts sql.ErrNoRows into an empty result with a nil error.
// All other errors are passed through unchanged.
func NoRowsOk(recs []Record, err error) ([]Record, error) {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// Return an empty, non-nil slice and nil error to indicate "no rows found, but that's okay".
// This makes it safer for callers to immediately use len() or range over the result.
return []Record{}, nil
}
return recs, err
}
return recs, nil
}
// isValidIdentifier checks if the given string is a safe identifier.
// Allows alphanumeric characters and underscores. Must not be empty.
func isValidIdentifier(identifier string) bool {
if len(identifier) == 0 {
return false
}
for _, r := range identifier {
if !((r >= 'a' && r <= 'z') ||
(r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') ||
r == '_') {
return false
}
}
return true
}