From fdb22a8dee4f8d8fbad960ccbdbd47736d61e08a Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Fri, 23 Feb 2024 12:24:40 +0200 Subject: [PATCH] Fixed uncontrolled buffer growth in restore command --- internal/db/postgres/restorers/table.go | 37 ++++++++++++++++--------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/internal/db/postgres/restorers/table.go b/internal/db/postgres/restorers/table.go index cc3db25c..9641f360 100644 --- a/internal/db/postgres/restorers/table.go +++ b/internal/db/postgres/restorers/table.go @@ -44,6 +44,10 @@ func NewTableRestorer(entry *toc.Entry, st storages.Storager) *TableRestorer { } func (td *TableRestorer) Execute(ctx context.Context, tx pgx.Tx) error { + // TODO: Refactor this logic + // 1. Decompose the Execute method into separate functions + // 2. Add tests + // 3. Get rid of the anonymous functions below return func() error { if td.Entry.FileName == nil { @@ -63,12 +67,10 @@ func (td *TableRestorer) Execute(ctx context.Context, tx pgx.Tx) error { log.Debug().Str("copyStmt", *td.Entry.CopyStmt).Msgf("performing pgcopy statement") frontend := tx.Conn().PgConn().Frontend() - frontend.Send(&pgproto3.Query{ - String: *td.Entry.CopyStmt, - }) - if err = frontend.Flush(); err != nil { - return err + err = sendMessage(frontend, &pgproto3.Query{String: *td.Entry.CopyStmt}) + if err != nil { + return fmt.Errorf("error sending Query message: %w", err) } // Prepare for streaming the pgcopy data @@ -108,19 +110,19 @@ func (td *TableRestorer) Execute(ctx context.Context, tx pgx.Tx) error { n, err = gz.Read(buf) if err != nil { if errors.Is(err, io.EOF) { - frontend.Send(&pgproto3.CopyDone{}) + completionErr := sendMessage(frontend, &pgproto3.CopyDone{}) + if completionErr != nil { + return fmt.Errorf("error sending CopyDone message: %w", err) + } break } return fmt.Errorf("error readimg from table dump: %w", err) } - frontend.Send(&pgproto3.CopyData{ - Data: buf[:n], - }) - } - - if err = frontend.Flush(); err != nil { - return err + err = sendMessage(frontend, &pgproto3.CopyData{Data: buf[:n]}) + if err != nil { + return fmt.Errorf("error sending DopyData message: %w", err) + } } // Perform post streaming handling @@ -152,3 +154,12 @@ func (td *TableRestorer) Execute(ctx context.Context, tx pgx.Tx) error { func (td *TableRestorer) DebugInfo() string { return fmt.Sprintf("table %s.%s", *td.Entry.Namespace, *td.Entry.Tag) } + +// sendMessage - send a message to the PostgreSQL backend and flush a buffer +func sendMessage(frontend *pgproto3.Frontend, msg pgproto3.FrontendMessage) error { + frontend.Send(msg) + if err := frontend.Flush(); err != nil { + return fmt.Errorf("error flushing pgx frontend buffer: %w", err) + } + return nil +}