diff --git a/cbor/cbor.go b/cbor/cbor.go index f3cdb2a..69ed614 100644 --- a/cbor/cbor.go +++ b/cbor/cbor.go @@ -1243,15 +1243,10 @@ func (e *Encoder) encodeTextOrBinary(rv reflect.Value) error { // Unaddressable arrays cannot be made into slices, so we must create a // slice and copy contents into it - slice := reflect.MakeSlice( - reflect.SliceOf(rv.Type().Elem()), - rv.Len(), - rv.Len(), - ) - if n := reflect.Copy(slice, rv); n != rv.Len() { + b = make([]byte, rv.Len()) + if n := reflect.Copy(reflect.ValueOf(b), rv); n != rv.Len() { panic("array contents were not fully copied into a slice for encoding") } - b = slice.Bytes() } info := u64Bytes(uint64(len(b))) diff --git a/examples/go.mod b/examples/go.mod index 108a3eb..9c26344 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -17,6 +17,7 @@ require ( github.com/fido-device-onboard/go-fdo/tpm v0.0.0-00010101000000-000000000000 github.com/google/go-tpm v0.9.2-0.20240920144513-364d5f2f78b9 github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba + github.com/syumai/workers v0.26.3 hermannm.dev/devlog v0.4.1 ) diff --git a/examples/go.sum b/examples/go.sum index 7cad18f..e79be3f 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -47,6 +47,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/syumai/workers v0.26.3 h1:AF+IBaRccbR4JIj2kNJLJblruPFMD/pAbzkopejGcP8= +github.com/syumai/workers v0.26.3/go.mod h1:ZnqmdiHNBrbxOLrZ/HJ5jzHy6af9cmiNZk10R9NrIEA= github.com/tetratelabs/wazero v1.8.1 h1:NrcgVbWfkWvVc4UtT4LRLDf91PsOzDzefMdwhLfA550= github.com/tetratelabs/wazero v1.8.1/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= diff --git a/examples/wasm/main.go b/examples/wasm/main.go new file mode 100644 index 0000000..7c4293b --- /dev/null +++ b/examples/wasm/main.go @@ -0,0 +1,147 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +// Package main implements a Rendezvous Server which can be compiled with +// TinyGo and run on Cloudflare Workers within the free tier (under reasonable +// load). +// +// Create a Cloudflare Worker with the name "rv" and start a new repo with the +// following configuration: +// +// wrangler.toml: +// +// name = "rv" +// main = "./build/worker.mjs" +// +// [build] +// command = """ +// go run github.com/syumai/workers/cmd/workers-assets-gen@v0.26.3 -mode tinygo && +// tinygo build -o ./build/app.wasm -target wasm -no-debug ./main.go +// """ +// +// [observability] +// enabled = true +// +// [[ d1_databases ]] +// binding = "RendezvousDB" +// database_name = "rv" +// database_id = "COPY_YOUR_UUID_HERE" +// +// Upon creating a D1 database instance with name "rv" and updating +// wrangler.toml, execute the following schema.sql file using: +// +// $ wrangler d1 execute rv --remote --file=./schema.sql +// +// schema.sql: +// +// PRAGMA foreign_keys = ON; +// CREATE TABLE IF NOT EXISTS sessions +// ( id BLOB PRIMARY KEY +// , protocol INTEGER NOT NULL +// ); +// CREATE TABLE IF NOT EXISTS to0_sessions +// ( session BLOB UNIQUE NOT NULL +// , nonce BLOB +// , FOREIGN KEY(session) REFERENCES sessions(id) ON DELETE CASCADE +// ); +// CREATE TABLE IF NOT EXISTS to1_sessions +// ( session BLOB UNIQUE NOT NULL +// , nonce BLOB +// , alg INTEGER +// , FOREIGN KEY(session) REFERENCES sessions(id) ON DELETE CASCADE +// ); +// CREATE TABLE IF NOT EXISTS rv_blobs +// ( guid BLOB PRIMARY KEY +// , rv BLOB NOT NULL +// , voucher BLOB NOT NULL +// , exp INTEGER NOT NULL +// ); +// CREATE TABLE IF NOT EXISTS trusted_emails +// ( email TEXT PRIMARY KEY +// ); +// CREATE TABLE IF NOT EXISTS trusted_owners +// ( pkix BLOB PRIMARY KEY +// , email TEXT NOT NULL +// , FOREIGN KEY(email) REFERENCES trusted_emails(email) ON DELETE CASCADE +// ); +// +// Add users by executing SQL with "wrangler d1 execute rv --remote" to add to +// the trusted_emails table. Then, add owner keys for users in the +// trusted_owners table. +// +// Setting up Cloudflare cron jobs to remove expired rendezvous blobs is left +// as an exercise for the implementer. +package main + +import ( + "context" + "crypto/x509" + "database/sql" + "errors" + "fmt" + "log/slog" + "net/http" + "os" + + "github.com/syumai/workers" + _ "github.com/syumai/workers/cloudflare/d1" + + "github.com/fido-device-onboard/go-fdo" + fdo_http "github.com/fido-device-onboard/go-fdo/http" + "github.com/fido-device-onboard/go-fdo/sqlite" +) + +const oneWeekInSeconds uint32 = 7 * 24 * 60 * 60 + +func main() { + db, err := sql.Open("d1", "RendezvousDB") + if err != nil { + slog.Error("d1 connect", "error", err) + os.Exit(1) + } + state := sqlite.New(db) + + // If building with Go instead of TinyGo, use: + // (*http.ServeMux).Handle("POST /fdo/101/{msg}", ...) + handler := http.NewServeMux() + handler.Handle("/fdo/101/", &fdo_http.Handler{ + TO0Responder: &fdo.TO0Server{ + Session: state, + RVBlobs: state, + AcceptVoucher: func(ctx context.Context, ov fdo.Voucher) (accept bool, err error) { + owner, err := ov.OwnerPublicKey() + if err != nil { + return false, fmt.Errorf("error getting voucher owner key: %w", err) + } + der, err := x509.MarshalPKIXPublicKey(owner) + if err != nil { + return false, fmt.Errorf("error marshaling voucher owner key: %w", err) + } + return trustedOwner(ctx, db, der) + }, + NegotiateTTL: func(requestedSeconds uint32, ov fdo.Voucher) (waitSeconds uint32) { + return min(requestedSeconds, oneWeekInSeconds) + }, + }, + TO1Responder: &fdo.TO1Server{ + Session: state, + RVBlobs: state, + }, + }) + workers.Serve(handler) +} + +func trustedOwner(ctx context.Context, db *sql.DB, pkixKey []byte) (bool, error) { + var email string + row := db.QueryRowContext(ctx, `SELECT email FROM trusted_owners WHERE pkix = ?`, pkixKey) + if err := row.Scan(&email); errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + slog.Info("accepting voucher", "user", email) + + return true, nil +} diff --git a/http/debug.go b/http/debug.go index f1e6b79..1227a7a 100644 --- a/http/debug.go +++ b/http/debug.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "log/slog" + "github.com/fido-device-onboard/go-fdo/cbor" "github.com/fido-device-onboard/go-fdo/cbor/cdn" ) @@ -22,3 +23,18 @@ func tryDebugNotation(b []byte) string { } return d } + +func debugUnencryptedMessage(msgType uint8, msg any) { + if debugEnabled() { + return + } + body, _ := cbor.Marshal(msg) + slog.Debug("unencrypted request", "msg", msgType, "body", tryDebugNotation(body)) +} + +func debugDecryptedMessage(msgType uint8, decrypted []byte) { + if debugEnabled() { + return + } + slog.Debug("decrypted response", "msg", msgType, "body", tryDebugNotation(decrypted)) +} diff --git a/http/handler.go b/http/handler.go index 067aeba..8a10455 100644 --- a/http/handler.go +++ b/http/handler.go @@ -12,8 +12,6 @@ import ( "io" "log/slog" "net/http" - "net/http/httptest" - "net/http/httputil" "strconv" "strings" "time" @@ -42,12 +40,10 @@ type Handler struct { func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Parse message type from request URL - typ, err := strconv.ParseUint(r.PathValue("msg"), 10, 8) - if err != nil { - writeErr(w, 0, fmt.Errorf("invalid message type")) + msgType, ok := msgTypeFromPath(w, r) + if !ok { return } - msgType := uint8(typ) proto := protocol.Of(msgType) // Parse request headers @@ -108,39 +104,14 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if debugEnabled() { - h.debugRequest(ctx, w, r, msgType, resp) + debugRequest(w, r, func(w http.ResponseWriter, r *http.Request) { + h.handleRequest(ctx, w, r, msgType, resp) + }) return } h.handleRequest(ctx, w, r, msgType, resp) } -func (h Handler) debugRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, msgType uint8, resp protocol.Responder) { - // Dump request - debugReq, _ := httputil.DumpRequest(r, false) - var saveBody bytes.Buffer - if _, err := saveBody.ReadFrom(r.Body); err == nil { - r.Body = io.NopCloser(&saveBody) - } - slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), - "body", tryDebugNotation(saveBody.Bytes())) - - // Dump response - rr := httptest.NewRecorder() - h.handleRequest(ctx, rr, r, msgType, resp) - debugResp, _ := httputil.DumpResponse(rr.Result(), false) - slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), - "body", tryDebugNotation(rr.Body.Bytes())) - - // Copy recorded response into response writer - for key, values := range rr.Header() { - for _, value := range values { - w.Header().Add(key, value) - } - } - w.WriteHeader(rr.Code) - _, _ = w.Write(rr.Body.Bytes()) -} - func (h Handler) handleRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, msgType uint8, resp protocol.Responder) { // Validate content length maxSize := h.MaxContentLength diff --git a/http/transport.go b/http/transport.go index 8ff15ab..9b8105f 100644 --- a/http/transport.go +++ b/http/transport.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" - "net/http/httputil" "net/url" "path" "strconv" @@ -51,8 +49,6 @@ type Transport struct { } // Send sends a single message and receives a single response message. -// -//nolint:gocyclo func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.Session) (respType uint8, _ io.ReadCloser, _ error) { // Initialize default values if t.Client == nil { @@ -64,10 +60,7 @@ func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.S // Encrypt if a key exchange session is provided if sess != nil { - if debugEnabled() { - body, _ := cbor.Marshal(msg) - slog.Debug("unencrypted request", "msg", msgType, "body", tryDebugNotation(body)) - } + debugUnencryptedMessage(msgType, msg) var err error msg, err = sess.Encrypt(rand.Reader, msg) if err != nil { @@ -105,24 +98,12 @@ func (t *Transport) Send(ctx context.Context, msgType uint8, msg any, sess kex.S } // Perform HTTP request - if debugEnabled() { - debugReq, _ := httputil.DumpRequestOut(req, false) - slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), - "body", tryDebugNotation(body.Bytes())) - } + debugRequestOut(req, body) resp, err := t.Client.Do(req) if err != nil { return 0, nil, fmt.Errorf("error making HTTP request for message %d: %w", msgType, err) } - if debugEnabled() { - debugResp, _ := httputil.DumpResponse(resp, false) - var saveBody bytes.Buffer - if _, err := saveBody.ReadFrom(resp.Body); err == nil { - resp.Body = io.NopCloser(&saveBody) - } - slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), - "body", tryDebugNotation(saveBody.Bytes())) - } + debugResponse(resp) return t.handleResponse(resp, sess) } @@ -189,10 +170,7 @@ func (t *Transport) handleResponse(resp *http.Response, sess kex.Session) (msgTy if err != nil { return 0, nil, fmt.Errorf("error decrypting message %d: %w", msgType, err) } - - if debugEnabled() { - slog.Debug("decrypted response", "msg", msgType, "body", tryDebugNotation(decrypted)) - } + debugDecryptedMessage(msgType, decrypted) content = io.NopCloser(bytes.NewBuffer(decrypted)) } diff --git a/http/util.go b/http/util.go new file mode 100644 index 0000000..3270968 --- /dev/null +++ b/http/util.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build !tinygo + +package http + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "net/http/httputil" + "strconv" +) + +func msgTypeFromPath(w http.ResponseWriter, r *http.Request) (uint8, bool) { + typ, err := strconv.ParseUint(r.PathValue("msg"), 10, 8) + if err != nil { + writeErr(w, 0, fmt.Errorf("invalid message type")) + return 0, false + } + return uint8(typ), true +} + +func debugRequest(w http.ResponseWriter, r *http.Request, handler http.HandlerFunc) { + // Dump request + debugReq, _ := httputil.DumpRequest(r, false) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(r.Body); err == nil { + r.Body = io.NopCloser(&saveBody) + } + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(saveBody.Bytes())) + + // Dump response + rr := httptest.NewRecorder() + handler(rr, r) + debugResp, _ := httputil.DumpResponse(rr.Result(), false) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(rr.Body.Bytes())) + + // Copy recorded response into response writer + for key, values := range rr.Header() { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(rr.Code) + _, _ = w.Write(rr.Body.Bytes()) +} + +func debugRequestOut(req *http.Request, body *bytes.Buffer) { + if !debugEnabled() { + return + } + debugReq, _ := httputil.DumpRequestOut(req, false) + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(body.Bytes())) +} + +func debugResponse(resp *http.Response) { + if !debugEnabled() { + return + } + debugResp, _ := httputil.DumpResponse(resp, false) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(resp.Body); err == nil { + resp.Body = io.NopCloser(&saveBody) + } + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(saveBody.Bytes())) +} diff --git a/http/util_tinygo.go b/http/util_tinygo.go new file mode 100644 index 0000000..5a77e2a --- /dev/null +++ b/http/util_tinygo.go @@ -0,0 +1,223 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build tinygo + +package http + +import ( + "bytes" + "fmt" + "io" + "log/slog" + "net/http" + "strconv" + "strings" +) + +func msgTypeFromPath(w http.ResponseWriter, r *http.Request) (uint8, bool) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return 0, false + } + path := strings.TrimPrefix(r.URL.Path, "/fdo/101/msg/") + if strings.Contains(path, "/") { + w.WriteHeader(http.StatusNotFound) + return 0, false + } + typ, err := strconv.ParseUint(path, 10, 8) + if err != nil { + writeErr(w, 0, fmt.Errorf("invalid message type")) + return 0, false + } + return uint8(typ), true +} + +func debugRequest(w http.ResponseWriter, r *http.Request, handler http.HandlerFunc) { + // Dump request + debugReq, _ := dumpRequest(r) + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(r.Body); err == nil { + r.Body = io.NopCloser(&saveBody) + } + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(saveBody.Bytes())) + + // Dump response + rr := new(responseRecorder) + handler(rr, r) + resp := rr.Result() + debugResp, _ := dumpResponse(resp) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(rr.body.Bytes())) + + // Copy recorded response into response writer + for key, values := range rr.Header() { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(rr.body.Bytes()) +} + +func debugRequestOut(req *http.Request, body *bytes.Buffer) { + if !debugEnabled() { + return + } + + // Unlike httputil.DumpRequestOut, this does not use an actual HTTP + // transport to ensure that the output has all relevant headers updated and + // canonicalized. Improvements are welcome. + debugReq, _ := dumpRequest(req) + slog.Debug("request", "dump", string(bytes.TrimSpace(debugReq)), + "body", tryDebugNotation(body.Bytes())) +} + +func debugResponse(resp *http.Response) { + if !debugEnabled() { + return + } + + var saveBody bytes.Buffer + if _, err := saveBody.ReadFrom(resp.Body); err == nil { + _ = resp.Body.Close() + resp.Body = io.NopCloser(&saveBody) + } + debugResp, _ := dumpResponse(resp) + slog.Debug("response", "dump", string(bytes.TrimSpace(debugResp)), + "body", tryDebugNotation(saveBody.Bytes())) +} + +func dumpRequest(req *http.Request) ([]byte, error) { + var out bytes.Buffer + + fmt.Fprintf(&out, "%s %s HTTP/%d.%d\r\n", req.Method, req.RequestURI, req.ProtoMajor, req.ProtoMinor) + + absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://") + if !absRequestURI { + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&out, "Host: %s\r\n", host) + } + } + + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&out, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + + if err := req.Header.WriteSubset(&out, map[string]bool{ + "Transfer-Encoding": true, + "Trailer": true, + }); err != nil { + return nil, err + } + + io.WriteString(&out, "\r\n") + + return out.Bytes(), nil +} + +var errNoBody = fmt.Errorf("no body") + +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +func dumpResponse(resp *http.Response) ([]byte, error) { + var out bytes.Buffer + savecl := resp.ContentLength + if resp.ContentLength == 0 { + resp.Body = io.NopCloser(strings.NewReader("")) + } else { + resp.Body = failureToReadBody{} + } + err := resp.Write(&out) + resp.ContentLength = savecl + if err != nil && err != errNoBody { + return nil, err + } + return out.Bytes(), nil +} + +type responseRecorder struct { + body *bytes.Buffer + code int + + header http.Header + headerAtFirstWrite http.Header + wroteHeader bool + + result *http.Response +} + +func (rr *responseRecorder) Header() http.Header { + if rr.header == nil { + rr.header = make(http.Header) + } + return rr.header +} + +func (rr *responseRecorder) Write(p []byte) (int, error) { + if !rr.wroteHeader { + m := rr.Header() + if _, hasType := m["Content-Type"]; !hasType && m.Get("Transfer-Encoding") == "" { + m.Set("Content-Type", http.DetectContentType(p)) + } + rr.WriteHeader(200) + } + if rr.body == nil { + rr.body = new(bytes.Buffer) + } + return rr.body.Write(p) +} + +func (rr *responseRecorder) WriteHeader(statusCode int) { + if rr.wroteHeader { + return + } + + rr.code = statusCode + rr.wroteHeader = true + rr.headerAtFirstWrite = rr.Header().Clone() +} + +func (rr *responseRecorder) Result() *http.Response { + if rr.result != nil { + return rr.result + } + if rr.code == 0 { + rr.code = 200 + } + if rr.headerAtFirstWrite == nil { + rr.headerAtFirstWrite = rr.Header().Clone() + } + + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rr.code, + Header: rr.headerAtFirstWrite, + } + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + res.Body = io.NopCloser(bytes.NewReader(rr.body.Bytes())) + res.ContentLength = func(length string) int64 { + n, err := strconv.ParseUint(strings.TrimSpace(length), 10, 63) + if err != nil { + return -1 + } + return int64(n) + }(res.Header.Get("Content-Length")) + // Trailers are not used in FDO + + rr.result = res + return res +} diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index 168a34b..f45d6a1 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -19,15 +19,10 @@ import ( "fmt" "io" "maps" - "path/filepath" "slices" "strings" "time" - "github.com/ncruces/go-sqlite3/driver" // Load database/sql driver - _ "github.com/ncruces/go-sqlite3/embed" // Load sqlite WASM binary - _ "github.com/ncruces/go-sqlite3/vfs/xts" // Encryption VFS - "github.com/fido-device-onboard/go-fdo" "github.com/fido-device-onboard/go-fdo/cbor" "github.com/fido-device-onboard/go-fdo/cose" @@ -44,25 +39,6 @@ type DB struct { db *sql.DB } -// Open creates or opens a SQLite database file using a single non-pooled -// connection. If a password is specified, then the xts VFS will be used -// with a text key. -func Open(filename, password string) (*DB, error) { - var query string - if password != "" { - query += fmt.Sprintf("?vfs=xts&_pragma=textkey(%q)&_pragma=temp_store(memory)", password) - } - connector, err := (&driver.SQLite{}).OpenConnector("file:" + filepath.Clean(filename) + query) - if err != nil { - return nil, fmt.Errorf("error creating sqlite connector: %w", err) - } - db := sql.OpenDB(connector) - if err := Init(db); err != nil { - return nil, err - } - return New(db), nil -} - // New creates a DB. The expected tables must already be created and pragmas // must already be set, including foreign_keys=ON. func New(db *sql.DB) *DB { return &DB{db: db} } diff --git a/sqlite/wasm.go b/sqlite/wasm.go new file mode 100644 index 0000000..24ac00a --- /dev/null +++ b/sqlite/wasm.go @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache 2.0 + +//go:build !tinygo + +// Open is not implemented for tinygo, because it requires embedding a WASM +// runtime in the binary. + +package sqlite + +import ( + "database/sql" + "fmt" + "path/filepath" + + "github.com/ncruces/go-sqlite3/driver" // Load database/sql driver + _ "github.com/ncruces/go-sqlite3/embed" // Load sqlite WASM binary + _ "github.com/ncruces/go-sqlite3/vfs/xts" // Encryption VFS +) + +// Open creates or opens a SQLite database file using a single non-pooled +// connection. If a password is specified, then the xts VFS will be used +// with a text key. +func Open(filename, password string) (*DB, error) { + var query string + if password != "" { + query += fmt.Sprintf("?vfs=xts&_pragma=textkey(%q)&_pragma=temp_store(memory)", password) + } + connector, err := (&driver.SQLite{}).OpenConnector("file:" + filepath.Clean(filename) + query) + if err != nil { + return nil, fmt.Errorf("error creating sqlite connector: %w", err) + } + db := sql.OpenDB(connector) + if err := Init(db); err != nil { + return nil, err + } + return New(db), nil +}