Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 feat: Add support for iterator methods to Fiber client #3228

Merged
merged 15 commits into from
Dec 10, 2024
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ coverage:
format:
go run mvdan.cc/gofumpt@latest -w -l .

## markdown: 🎨 Find markdown format issues (Requires markdownlint-cli)
## markdown: 🎨 Find markdown format issues (Requires markdownlint-cli2)
.PHONY: markdown
markdown:
markdownlint-cli2 "**/*.md" "#vendor"
Expand Down
25 changes: 6 additions & 19 deletions client/hooks.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package client

import (
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -241,8 +240,8 @@ func parserRequestBodyFile(req *Request) error {
return fmt.Errorf("write formdata error: %w", err)
}

// add file
b := make([]byte, 512)
// add files
fileBuf := make([]byte, 1<<20) // Allocate 1MB buffer
for i, v := range req.files {
if v.name == "" && v.path == "" {
return ErrFileNoName
Expand Down Expand Up @@ -273,24 +272,12 @@ func parserRequestBodyFile(req *Request) error {
return fmt.Errorf("create file error: %w", err)
}

for {
n, err := v.reader.Read(b)
if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("read file error: %w", err)
}

if errors.Is(err, io.EOF) {
break
}

_, err = w.Write(b[:n])
if err != nil {
return fmt.Errorf("write file error: %w", err)
}
// Copy the file from reader to multipart writer
if _, err := io.CopyBuffer(w, v.reader, fileBuf); err != nil {
return fmt.Errorf("failed to copy file data: %w", err)
}

err = v.reader.Close()
if err != nil {
if err := v.reader.Close(); err != nil {
return fmt.Errorf("close file error: %w", err)
}
}
Expand Down
133 changes: 133 additions & 0 deletions client/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"context"
"errors"
"io"
"iter"
"path/filepath"
"reflect"
"slices"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -129,6 +131,31 @@ func (r *Request) Header(key string) []string {
return r.header.PeekMultiple(key)
}

// Headers returns all headers in the request using an iterator.
// You can use maps.Collect() to collect all headers into a map.
//
// The returned value is valid until the request object is released.
// Any future calls to Headers method will return the modified value. Do not store references to returned value. Make copies instead.
func (r *Request) Headers() iter.Seq2[string, []string] {
return func(yield func(string, []string) bool) {
peekKeys := r.header.PeekKeys()
keys := make([][]byte, len(peekKeys))
copy(keys, peekKeys) // It is necessary to have immutable byte slice.

for _, key := range keys {
vals := r.header.PeekAll(utils.UnsafeString(key))
valsStr := make([]string, len(vals))
for i, v := range vals {
valsStr[i] = utils.UnsafeString(v)
}

if !yield(utils.UnsafeString(key), valsStr) {
return
}
}
}
}

// AddHeader method adds a single header field and its value in the request instance.
func (r *Request) AddHeader(key, val string) *Request {
r.header.Add(key, val)
Expand Down Expand Up @@ -168,6 +195,33 @@ func (r *Request) Param(key string) []string {
return res
}

// Params returns all params in the request using an iterator.
// You can use maps.Collect() to collect all params into a map.
//
// The returned value is valid until the request object is released.
// Any future calls to Params method will return the modified value. Do not store references to returned value. Make copies instead.
func (r *Request) Params() iter.Seq2[string, []string] {
return func(yield func(string, []string) bool) {
keys := r.params.Keys()

for _, key := range keys {
if key == "" {
continue
}

vals := r.params.PeekMulti(key)
valsStr := make([]string, len(vals))
for i, v := range vals {
valsStr[i] = utils.UnsafeString(v)
}

if !yield(key, valsStr) {
return
}
}
}
}

// AddParam method adds a single param field and its value in the request instance.
func (r *Request) AddParam(key, val string) *Request {
r.params.Add(key, val)
Expand Down Expand Up @@ -254,6 +308,18 @@ func (r *Request) Cookie(key string) string {
return ""
}

// Cookies returns all cookies in the cookies using an iterator.
// You can use maps.Collect() to collect all cookies into a map.
func (r *Request) Cookies() iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
r.cookies.VisitAll(func(key, val string) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VisitAll in its current form does not support early returns. If the caller of the iterator stops iterating early (e.g. by break in for range), yield would be called again after it returned false, and this causes a runtime panic.

Please update VisitAll to support early returns, and add a test to verify that stopping the iterator early works.

if !yield(key, val) {
return
}
})
}
}

// SetCookie method sets a single cookie field and its value in the request instance.
// It will override cookie which set in client instance.
func (r *Request) SetCookie(key, val string) *Request {
Expand Down Expand Up @@ -291,6 +357,18 @@ func (r *Request) PathParam(key string) string {
return ""
}

// PathParams returns all path params in request instance.
// You can use maps.Collect() to collect all cookies into a map.
func (r *Request) PathParams() iter.Seq2[string, string] {
return func(yield func(string, string) bool) {
r.path.VisitAll(func(key, val string) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the above.

Thanks for the report. Will fix it

if !yield(key, val) {
return
}
})
}
}

// SetPathParam method sets a single path param field and its value in the request instance.
// It will override path param which set in client instance.
func (r *Request) SetPathParam(key, val string) *Request {
Expand Down Expand Up @@ -376,6 +454,33 @@ func (r *Request) FormData(key string) []string {
return res
}

// AllFormData method returns all form datas in request instance.
// You can use maps.Collect() to collect all cookies into a map.
//
// The returned value is valid until the request object is released.
// Any future calls to FormDatas method will return the modified value. Do not store references to returned value. Make copies instead.
func (r *Request) AllFormData() iter.Seq2[string, []string] {
return func(yield func(string, []string) bool) {
keys := r.formData.Keys()

for _, key := range keys {
if key == "" {
continue
}

vals := r.formData.PeekMulti(key)
valsStr := make([]string, len(vals))
for i, v := range vals {
valsStr[i] = utils.UnsafeString(v)
}

if !yield(key, valsStr) {
return
}
}
}
}

// AddFormData method adds a single form data field and its value in the request instance.
func (r *Request) AddFormData(key, val string) *Request {
r.formData.AddData(key, val)
Expand Down Expand Up @@ -435,6 +540,14 @@ func (r *Request) File(name string) *File {
return nil
}

// Files method returns all files in request instance.
//
// The returned value is valid until the request object is released.
// Any future calls to Files method will return the modified value. Do not store references to returned value. Make copies instead.
func (r *Request) Files() []*File {
return r.files
}

// FileByPath returns file ptr store in request obj by path.
func (r *Request) FileByPath(path string) *File {
for _, v := range r.files {
Expand Down Expand Up @@ -617,6 +730,16 @@ type QueryParam struct {
*fasthttp.Args
}

// Keys method returns all keys in the query params.
func (p *QueryParam) Keys() []string {
keys := make([]string, 0, p.Len())
p.VisitAll(func(key, _ []byte) {
keys = append(keys, utils.UnsafeString(key))
})

return slices.Compact(keys)
}

// AddParams receive a map and add each value to param.
func (p *QueryParam) AddParams(r map[string][]string) {
for k, v := range r {
Expand Down Expand Up @@ -747,6 +870,16 @@ type FormData struct {
*fasthttp.Args
}

// Keys method returns all keys in the form data.
func (f *FormData) Keys() []string {
keys := make([]string, 0, f.Len())
f.VisitAll(func(key, _ []byte) {
keys = append(keys, utils.UnsafeString(key))
})

return slices.Compact(keys)
}

// AddData method is a wrapper of Args's Add method.
func (f *FormData) AddData(key, val string) {
f.Add(key, val)
Expand Down
Loading
Loading