Skip to content

Commit 7eb9d25

Browse files
authored
Support Square Bracket Notation in Multipart Form data (#3268)
* Feature Request: Support Square Bracket Notation in Multipart Form Data #3224 * Feature Request: Support Square Bracket Notation in Multipart Form Data #3224
1 parent 47be681 commit 7eb9d25

File tree

5 files changed

+144
-92
lines changed

5 files changed

+144
-92
lines changed

ctx.go

Lines changed: 29 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -406,28 +406,30 @@ func (c *Ctx) BodyParser(out interface{}) error {
406406
k := c.app.getString(key)
407407
v := c.app.getString(val)
408408

409-
if strings.Contains(k, "[") {
410-
k, err = parseParamSquareBrackets(k)
411-
}
412-
413-
if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, bodyTag) {
414-
values := strings.Split(v, ",")
415-
for i := 0; i < len(values); i++ {
416-
data[k] = append(data[k], values[i])
417-
}
418-
} else {
419-
data[k] = append(data[k], v)
420-
}
409+
err = formatParserData(out, data, bodyTag, k, v, c.app.config.EnableSplittingOnParsers, true)
421410
})
422411

412+
if err != nil {
413+
return err
414+
}
415+
423416
return c.parseToStruct(bodyTag, out, data)
424417
}
425418
if strings.HasPrefix(ctype, MIMEMultipartForm) {
426-
data, err := c.fasthttp.MultipartForm()
419+
multipartForm, err := c.fasthttp.MultipartForm()
427420
if err != nil {
428421
return err
429422
}
430-
return c.parseToStruct(bodyTag, out, data.Value)
423+
424+
data := make(map[string][]string)
425+
for key, values := range multipartForm.Value {
426+
err = formatParserData(out, data, bodyTag, key, values, c.app.config.EnableSplittingOnParsers, true)
427+
if err != nil {
428+
return err
429+
}
430+
}
431+
432+
return c.parseToStruct(bodyTag, out, data)
431433
}
432434
if strings.HasPrefix(ctype, MIMETextXML) || strings.HasPrefix(ctype, MIMEApplicationXML) {
433435
if err := xml.Unmarshal(c.Body(), out); err != nil {
@@ -531,18 +533,7 @@ func (c *Ctx) CookieParser(out interface{}) error {
531533
k := c.app.getString(key)
532534
v := c.app.getString(val)
533535

534-
if strings.Contains(k, "[") {
535-
k, err = parseParamSquareBrackets(k)
536-
}
537-
538-
if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, cookieTag) {
539-
values := strings.Split(v, ",")
540-
for i := 0; i < len(values); i++ {
541-
data[k] = append(data[k], values[i])
542-
}
543-
} else {
544-
data[k] = append(data[k], v)
545-
}
536+
err = formatParserData(out, data, cookieTag, k, v, c.app.config.EnableSplittingOnParsers, true)
546537
})
547538
if err != nil {
548539
return err
@@ -1283,18 +1274,7 @@ func (c *Ctx) QueryParser(out interface{}) error {
12831274
k := c.app.getString(key)
12841275
v := c.app.getString(val)
12851276

1286-
if strings.Contains(k, "[") {
1287-
k, err = parseParamSquareBrackets(k)
1288-
}
1289-
1290-
if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, queryTag) {
1291-
values := strings.Split(v, ",")
1292-
for i := 0; i < len(values); i++ {
1293-
data[k] = append(data[k], values[i])
1294-
}
1295-
} else {
1296-
data[k] = append(data[k], v)
1297-
}
1277+
err = formatParserData(out, data, queryTag, k, v, c.app.config.EnableSplittingOnParsers, true)
12981278
})
12991279

13001280
if err != nil {
@@ -1304,61 +1284,26 @@ func (c *Ctx) QueryParser(out interface{}) error {
13041284
return c.parseToStruct(queryTag, out, data)
13051285
}
13061286

1307-
func parseParamSquareBrackets(k string) (string, error) {
1308-
bb := bytebufferpool.Get()
1309-
defer bytebufferpool.Put(bb)
1310-
1311-
kbytes := []byte(k)
1312-
openBracketsCount := 0
1313-
1314-
for i, b := range kbytes {
1315-
if b == '[' {
1316-
openBracketsCount++
1317-
if i+1 < len(kbytes) && kbytes[i+1] != ']' {
1318-
if err := bb.WriteByte('.'); err != nil {
1319-
return "", fmt.Errorf("failed to write: %w", err)
1320-
}
1321-
}
1322-
continue
1323-
}
1324-
1325-
if b == ']' {
1326-
openBracketsCount--
1327-
if openBracketsCount < 0 {
1328-
return "", errors.New("unmatched brackets")
1329-
}
1330-
continue
1331-
}
1332-
1333-
if err := bb.WriteByte(b); err != nil {
1334-
return "", fmt.Errorf("failed to write: %w", err)
1335-
}
1336-
}
1337-
1338-
if openBracketsCount > 0 {
1339-
return "", errors.New("unmatched brackets")
1340-
}
1341-
1342-
return bb.String(), nil
1343-
}
1344-
13451287
// ReqHeaderParser binds the request header strings to a struct.
13461288
func (c *Ctx) ReqHeaderParser(out interface{}) error {
13471289
data := make(map[string][]string)
1290+
var err error
1291+
13481292
c.fasthttp.Request.Header.VisitAll(func(key, val []byte) {
1293+
if err != nil {
1294+
return
1295+
}
1296+
13491297
k := c.app.getString(key)
13501298
v := c.app.getString(val)
13511299

1352-
if c.app.config.EnableSplittingOnParsers && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k, reqHeaderTag) {
1353-
values := strings.Split(v, ",")
1354-
for i := 0; i < len(values); i++ {
1355-
data[k] = append(data[k], values[i])
1356-
}
1357-
} else {
1358-
data[k] = append(data[k], v)
1359-
}
1300+
err = formatParserData(out, data, reqHeaderTag, k, v, c.app.config.EnableSplittingOnParsers, false)
13601301
})
13611302

1303+
if err != nil {
1304+
return err
1305+
}
1306+
13621307
return c.parseToStruct(reqHeaderTag, out, data)
13631308
}
13641309

ctx_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,48 @@ func Test_Ctx_BodyParser(t *testing.T) {
610610
utils.AssertEqual(t, 2, len(cq.Data))
611611
utils.AssertEqual(t, "john", cq.Data[0].Name)
612612
utils.AssertEqual(t, "doe", cq.Data[1].Name)
613+
614+
t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
615+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
616+
c.Request().Reset()
617+
618+
buf := &bytes.Buffer{}
619+
writer := multipart.NewWriter(buf)
620+
utils.AssertEqual(t, nil, writer.WriteField("data.0.name", "john"))
621+
utils.AssertEqual(t, nil, writer.WriteField("data.1.name", "doe"))
622+
utils.AssertEqual(t, nil, writer.Close())
623+
624+
c.Request().Header.SetContentType(writer.FormDataContentType())
625+
c.Request().SetBody(buf.Bytes())
626+
c.Request().Header.SetContentLength(len(c.Body()))
627+
628+
cq := new(CollectionQuery)
629+
utils.AssertEqual(t, nil, c.BodyParser(cq))
630+
utils.AssertEqual(t, len(cq.Data), 2)
631+
utils.AssertEqual(t, "john", cq.Data[0].Name)
632+
utils.AssertEqual(t, "doe", cq.Data[1].Name)
633+
})
634+
635+
t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
636+
c := app.AcquireCtx(&fasthttp.RequestCtx{})
637+
c.Request().Reset()
638+
639+
buf := &bytes.Buffer{}
640+
writer := multipart.NewWriter(buf)
641+
utils.AssertEqual(t, nil, writer.WriteField("data[0][name]", "john"))
642+
utils.AssertEqual(t, nil, writer.WriteField("data[1][name]", "doe"))
643+
utils.AssertEqual(t, nil, writer.Close())
644+
645+
c.Request().Header.SetContentType(writer.FormDataContentType())
646+
c.Request().SetBody(buf.Bytes())
647+
c.Request().Header.SetContentLength(len(c.Body()))
648+
649+
cq := new(CollectionQuery)
650+
utils.AssertEqual(t, nil, c.BodyParser(cq))
651+
utils.AssertEqual(t, len(cq.Data), 2)
652+
utils.AssertEqual(t, "john", cq.Data[0].Name)
653+
utils.AssertEqual(t, "doe", cq.Data[1].Name)
654+
})
613655
}
614656

615657
func Test_Ctx_ParamParser(t *testing.T) {

go.mod

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,4 @@ require (
1919
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
2020
github.com/rivo/uniseg v0.2.0 // indirect
2121
github.com/valyala/tcplisten v1.0.0 // indirect
22-
golang.org/x/mod v0.18.0 // indirect
23-
golang.org/x/tools v0.22.0 // indirect
2422
)

go.sum

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1Gsh
1515
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
1616
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
1717
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
18-
github.com/tinylib/msgp v1.1.3 h1:3giwAkmtaEDLSV0MdO1lDLuPgklgPzmk8H9+So2BVfA=
19-
github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
2018
github.com/tinylib/msgp v1.2.5 h1:WeQg1whrXRFiZusidTQqzETkRpGjFjcIhW6uqWH09po=
2119
github.com/tinylib/msgp v1.2.5/go.mod h1:ykjzy2wzgrlvpDCRc4LA8UXy6D8bzMSuAF3WD57Gok0=
2220
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -25,11 +23,7 @@ github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1S
2523
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
2624
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
2725
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
28-
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
29-
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
3026
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
3127
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
3228
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
3329
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
34-
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
35-
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=

helpers.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package fiber
77
import (
88
"bytes"
99
"crypto/tls"
10+
"errors"
1011
"fmt"
1112
"hash/crc32"
1213
"io"
@@ -1151,3 +1152,75 @@ func IndexRune(str string, needle int32) bool {
11511152
}
11521153
return false
11531154
}
1155+
1156+
func parseParamSquareBrackets(k string) (string, error) {
1157+
bb := bytebufferpool.Get()
1158+
defer bytebufferpool.Put(bb)
1159+
1160+
kbytes := []byte(k)
1161+
openBracketsCount := 0
1162+
1163+
for i, b := range kbytes {
1164+
if b == '[' {
1165+
openBracketsCount++
1166+
if i+1 < len(kbytes) && kbytes[i+1] != ']' {
1167+
if err := bb.WriteByte('.'); err != nil {
1168+
return "", fmt.Errorf("failed to write: %w", err)
1169+
}
1170+
}
1171+
continue
1172+
}
1173+
1174+
if b == ']' {
1175+
openBracketsCount--
1176+
if openBracketsCount < 0 {
1177+
return "", errors.New("unmatched brackets")
1178+
}
1179+
continue
1180+
}
1181+
1182+
if err := bb.WriteByte(b); err != nil {
1183+
return "", fmt.Errorf("failed to write: %w", err)
1184+
}
1185+
}
1186+
1187+
if openBracketsCount > 0 {
1188+
return "", errors.New("unmatched brackets")
1189+
}
1190+
1191+
return bb.String(), nil
1192+
}
1193+
1194+
func formatParserData(out interface{}, data map[string][]string, aliasTag, key string, value interface{}, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
1195+
var err error
1196+
if supportBracketNotation && strings.Contains(key, "[") {
1197+
key, err = parseParamSquareBrackets(key)
1198+
if err != nil {
1199+
return err
1200+
}
1201+
}
1202+
1203+
switch v := value.(type) {
1204+
case string:
1205+
assignBindData(out, data, aliasTag, key, v, enableSplitting)
1206+
case []string:
1207+
for _, val := range v {
1208+
assignBindData(out, data, aliasTag, key, val, enableSplitting)
1209+
}
1210+
default:
1211+
return fmt.Errorf("unsupported value type: %T", value)
1212+
}
1213+
1214+
return err
1215+
}
1216+
1217+
func assignBindData(out interface{}, data map[string][]string, aliasTag, key, value string, enableSplitting bool) { //nolint:revive // it's okay
1218+
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key, aliasTag) {
1219+
values := strings.Split(value, ",")
1220+
for i := 0; i < len(values); i++ {
1221+
data[key] = append(data[key], values[i])
1222+
}
1223+
} else {
1224+
data[key] = append(data[key], value)
1225+
}
1226+
}

0 commit comments

Comments
 (0)