Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions internal/langserver/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ package langserver
import (
"fmt"
"net/http"

"github.com/x1unix/go-playground/pkg/goplay"
)

// ErrSnippetTooLarge is snippet max size limit error
var ErrSnippetTooLarge = Errorf(
http.StatusRequestEntityTooLarge,
"code snippet too large (max %d bytes)",
goplay.MaxSnippetSize,
)

// HTTPError is HTTP response error
Expand Down
10 changes: 0 additions & 10 deletions internal/langserver/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package langserver

import (
"errors"
"github.com/x1unix/go-playground/pkg/goplay"
"net/http"
"syscall"
)
Expand All @@ -29,15 +28,6 @@ func WrapHandler(h HandlerFunc, guards ...GuardFn) http.HandlerFunc {
}
}

// ValidateContentLength validates Go code snippet size
func ValidateContentLength(r *http.Request) error {
if err := goplay.ValidateContentLength(int(r.ContentLength)); err != nil {
return NewHTTPError(http.StatusRequestEntityTooLarge, err)
}

return nil
}

func handleError(err error, w http.ResponseWriter) {
if err == nil {
return
Expand Down
11 changes: 0 additions & 11 deletions internal/langserver/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package langserver
import (
"encoding/json"
"go.uber.org/zap"
"io"
"net/http"
"strconv"
)
Expand Down Expand Up @@ -60,13 +59,3 @@ func shouldFormatCode(r *http.Request) (bool, error) {

return boolVal, nil
}

func getPayloadFromRequest(r *http.Request) ([]byte, error) {
src, err := io.ReadAll(r.Body)
if err != nil {
return nil, Errorf(http.StatusBadGateway, "failed to read request: %s", err)
}

r.Body.Close()
return src, nil
}
55 changes: 41 additions & 14 deletions internal/langserver/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package langserver

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -70,13 +71,13 @@ func (s *Service) Mount(r *mux.Router) {
r.Path("/suggest").
HandlerFunc(WrapHandler(s.HandleGetSuggestion))
r.Path("/run").Methods(http.MethodPost).
HandlerFunc(WrapHandler(s.HandleRunCode, ValidateContentLength))
HandlerFunc(WrapHandler(s.HandleRunCode))
r.Path("/compile").Methods(http.MethodPost).
HandlerFunc(WrapHandler(s.HandleCompile, ValidateContentLength))
HandlerFunc(WrapHandler(s.HandleCompile))
r.Path("/format").Methods(http.MethodPost).
HandlerFunc(WrapHandler(s.HandleFormatCode, ValidateContentLength))
HandlerFunc(WrapHandler(s.HandleFormatCode))
r.Path("/share").Methods(http.MethodPost).
HandlerFunc(WrapHandler(s.HandleShare, ValidateContentLength))
HandlerFunc(WrapHandler(s.HandleShare))
r.Path("/snippet/{id}").Methods(http.MethodGet).
HandlerFunc(WrapHandler(s.HandleGetSnippet))
r.Path("/backends/info").Methods(http.MethodGet).
Expand Down Expand Up @@ -159,7 +160,7 @@ func (s *Service) HandleGetSuggestion(w http.ResponseWriter, r *http.Request) er

// HandleFormatCode handles goimports action
func (s *Service) HandleFormatCode(w http.ResponseWriter, r *http.Request) error {
src, err := getPayloadFromRequest(r)
src, err := s.getPayloadFromRequest(r)
if err != nil {
return err
}
Expand All @@ -186,8 +187,8 @@ func (s *Service) HandleShare(w http.ResponseWriter, r *http.Request) error {
shareID, err := s.client.Share(r.Context(), r.Body)
defer r.Body.Close()
if err != nil {
if errors.Is(err, goplay.ErrSnippetTooLarge) {
return NewHTTPError(http.StatusRequestEntityTooLarge, err)
if isContentLengthError(err) {
return ErrSnippetTooLarge
}

s.log.Error("failed to share code: ", err)
Expand Down Expand Up @@ -225,7 +226,7 @@ func (s *Service) HandleGetSnippet(w http.ResponseWriter, r *http.Request) error
// HandleRunCode handles code run
func (s *Service) HandleRunCode(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
src, err := getPayloadFromRequest(r)
src, err := s.getPayloadFromRequest(r)
if err != nil {
return err
}
Expand Down Expand Up @@ -304,17 +305,15 @@ func (s *Service) HandleArtifactRequest(w http.ResponseWriter, r *http.Request)
w.Header().Set("Content-Length", contentLength)
w.Header().Set(rawContentLengthHeader, contentLength)

n, err := io.Copy(w, data)
defer data.Close()
if err != nil {
if _, err := io.Copy(w, data); err != nil {
s.log.Errorw("failed to send artifact",
"artifactID", artifactId,
"err", err,
)
return err
}

w.Header().Set("Content-Length", strconv.FormatInt(n, 10))
return nil
}

Expand All @@ -329,7 +328,7 @@ func (s *Service) HandleCompile(w http.ResponseWriter, r *http.Request) error {
return NewHTTPError(http.StatusTooManyRequests, err)
}

src, err := getPayloadFromRequest(r)
src, err := s.getPayloadFromRequest(r)
if err != nil {
return err
}
Expand Down Expand Up @@ -391,8 +390,8 @@ func backendFromRequest(r *http.Request) (goplay.Backend, error) {
func (s *Service) goImportsCode(ctx context.Context, src []byte, backend goplay.Backend) ([]byte, bool, error) {
resp, err := s.client.GoImports(ctx, src, backend)
if err != nil {
if errors.Is(err, goplay.ErrSnippetTooLarge) {
return nil, false, NewHTTPError(http.StatusRequestEntityTooLarge, err)
if isContentLengthError(err) {
return nil, false, ErrSnippetTooLarge
}

s.log.Error(err)
Expand All @@ -406,3 +405,31 @@ func (s *Service) goImportsCode(ctx context.Context, src []byte, backend goplay.
changed := resp.Body != string(src)
return []byte(resp.Body), changed, nil
}

func (s *Service) getPayloadFromRequest(r *http.Request) ([]byte, error) {
// see: https://github.com/golang/playground/blob/master/share.go#L69
var buff bytes.Buffer
buff.Grow(goplay.MaxSnippetSize)

defer r.Body.Close()
_, err := io.Copy(&buff, io.LimitReader(r.Body, goplay.MaxSnippetSize+1))
if err != nil {
return nil, Errorf(http.StatusBadGateway, "failed to read request: %w", err)
}

if buff.Len() > goplay.MaxSnippetSize {
return nil, ErrSnippetTooLarge
}

return buff.Bytes(), nil
}

func isContentLengthError(err error) bool {
if httpErr, ok := goplay.IsHTTPError(err); ok {
if httpErr.StatusCode == http.StatusRequestEntityTooLarge {
return true
}
}

return false
}
14 changes: 4 additions & 10 deletions pkg/goplay/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ const (
DefaultUserAgent = "goplay.tools/1.0 (http://goplay.tools/)"
DefaultPlaygroundURL = "https://go.dev/_"

// maxSnippetSize value taken from
// MaxSnippetSize value taken from
// https://github.com/golang/playground/blob/master/app/goplay/share.go
maxSnippetSize = 64 * 1024
MaxSnippetSize = 64 * 1024
)

// ErrSnippetTooLarge is snippet max size limit error
var ErrSnippetTooLarge = fmt.Errorf("code snippet too large (max %d bytes)", maxSnippetSize)

// Client is Go Playground API client
type Client struct {
client http.Client
Expand Down Expand Up @@ -89,14 +86,11 @@ func (c *Client) doRequest(ctx context.Context, method, url, contentType string,
return nil, NewHTTPError(response)
}

bodyBytes := &bytes.Buffer{}
_, err = io.Copy(bodyBytes, io.LimitReader(response.Body, maxSnippetSize+1))
bodyBytes := bytes.Buffer{}
_, err = io.Copy(&bodyBytes, response.Body)
if err != nil {
return nil, err
}
if err = ValidateContentLength(bodyBytes.Len()); err != nil {
return nil, err
}

return bodyBytes.Bytes(), nil
}
Expand Down
8 changes: 0 additions & 8 deletions pkg/goplay/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,6 @@ import (
"net/url"
)

// ValidateContentLength validates snippet size
func ValidateContentLength(itemLen int) error {
if itemLen > maxSnippetSize {
return ErrSnippetTooLarge
}
return nil
}

// GetSnippet returns snippet from Go playground
func (c *Client) GetSnippet(ctx context.Context, snippetID string) (*Snippet, error) {
fileName := snippetID + ".go"
Expand Down
19 changes: 0 additions & 19 deletions pkg/goplay/methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"

Expand All @@ -16,24 +15,6 @@ import (
"github.com/x1unix/go-playground/pkg/testutil"
)

func TestValidateContentLength(t *testing.T) {
cases := map[int]bool{
maxSnippetSize: false,
maxSnippetSize + 10: true,
10: false,
}
for i, c := range cases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
err := ValidateContentLength(i)
if !c {
require.NoError(t, err)
return
}
require.Error(t, err)
})
}
}

func TestClient_Compile(t *testing.T) {
cases := map[string]struct {
expect *CompileResponse
Expand Down