Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/prc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@ jobs:
go-apidiff:
if: github.event_name == 'pull_request'
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write # Required for commenting on PRs
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-go@v4
- uses: actions/setup-go@v5
with:
go-version: 1.19
- uses: joelanford/go-apidiff@main
go-version: 'stable'
- uses: imjasonh/apidiff[email protected]
45 changes: 29 additions & 16 deletions admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@ import (
"trpc.group/trpc-go/trpc-go/transport"
)

func init() {
// The pprof functionality supported by the admin package relies on the imported net/http/pprof package.
// However, the imported net/http/pprof package implicitly registers HTTP handlers for
// "/debug/pprof/", "/debug/pprof/cmdline", "/debug/pprof/profile", "/debug/pprof/symbol", "/debug/pprof/trace"
// in http.DefaultServeMux in its init function. This implicit behavior is too subtle and may contribute to people
// inadvertently leaving such endpoints open, and may cause security problems:https://github.com/golang/go/issues/22085
// if people use http.DefaultServeMux. So we decide to reset default serve mux to remove pprof registration.
// This requires making sure that people are not using http.DefaultServeMux before we reset it.
// In most cases, this works, which is guaranteed by the execution order of the init function.
// If you need to enable pprof on http.DefaultServeMux you need to
// register it explicitly after importing the admin package:
//
// http.DefaultServeMux.HandleFunc("/debug/pprof/", pprof.Index)
// http.DefaultServeMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
// http.DefaultServeMux.HandleFunc("/debug/pprof/profile", pprof.Profile)
// http.DefaultServeMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
// http.DefaultServeMux.HandleFunc("/debug/pprof/trace", pprof.Trace)
//
// Simply importing the net/http/pprof package anonymously will not work.
// More details see: https://git.woa.com/trpc-go/trpc-go/issues/912, and https://github.com/golang/go/issues/42834.
http.DefaultServeMux = http.NewServeMux()
}

// ServiceName is the service name of admin service.
const ServiceName = "admin"

Expand Down Expand Up @@ -122,21 +145,6 @@ func (s *Server) configRouter(r *router) *router {
for pattern, handler := range pattern2Handler {
r.add(pattern, handler)
}

// Delete the router registered with http.DefaultServeMux.
// Avoid causing security problems: https://github.com/golang/go/issues/22085.
err := unregisterHandlers(
[]string{
pprofPprof,
pprofCmdline,
pprofProfile,
pprofSymbol,
pprofTrace,
},
)
if err != nil {
log.Errorf("failed to unregister pprof handlers from http.DefaultServeMux, err: %+v", err)
}
return r
}

Expand Down Expand Up @@ -173,13 +181,18 @@ func (s *Server) Serve() error {
return err
}

log.Infof("admin service launch success, %s:%s, serving ...", ln.Addr().Network(), ln.Addr().String())

s.server = &http.Server{
Addr: ln.Addr().String(),
ReadTimeout: cfg.readTimeout,
WriteTimeout: cfg.writeTimeout,
Handler: s.router,
}
if err := s.server.Serve(ln); err != nil && err != http.ErrServerClosed {
// Restricted access to the internal/poll.ErrNetClosing type necessitates comparing a string literal.
const closeError = "use of closed network connection"
if err := s.server.Serve(ln); err != nil &&
err != http.ErrServerClosed && !strings.Contains(err.Error(), closeError) {
return err
}
return nil
Expand Down
133 changes: 73 additions & 60 deletions admin/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"
"net"
"net/http"
"net/http/pprof"
"os"
"reflect"
"strings"
Expand Down Expand Up @@ -574,10 +575,10 @@ func TestOptionsConfig(t *testing.T) {

func httpRequest(method string, url string, body string) ([]byte, error) {
request, err := http.NewRequest(method, url, strings.NewReader(body))
request.Header.Set("content-type", "application/x-www-form-urlencoded")
if err != nil {
return nil, err
}
request.Header.Set("content-type", "application/x-www-form-urlencoded")

response, err := http.DefaultClient.Do(request)
if err != nil {
Expand All @@ -599,72 +600,84 @@ func panicHandle(w http.ResponseWriter, r *http.Request) {
panic("panic error handle")
}

func TestUnregisterHandlers(t *testing.T) {
_ = newDefaultAdminServer()
mux, err := extractServeMuxData()
require.Nil(t, err)
require.Len(t, mux.m, 0)
require.Len(t, mux.es, 0)
require.False(t, mux.hosts)

http.HandleFunc("/usercmd", userCmd)
http.HandleFunc("/errout", errOutput)
http.HandleFunc("/panicHandle", panicHandle)
http.HandleFunc("www.qq.com/", userCmd)
http.HandleFunc("anything/", userCmd)
func Test_init(t *testing.T) {
t.Run("reset default serve mux to remove pprof registration at admin init func", func(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
require.Nil(t, err)
go func() {
server := &http.Server{
Handler: nil,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}

if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
t.Logf("http serving: %v", err)
}
}()
time.Sleep(200 * time.Millisecond)

r, err := http.Get(fmt.Sprintf("http://%s/debug/pprof/", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusNotFound, r.StatusCode)

l := mustListenTCP(t)
go func() {
if err := http.Serve(l, nil); err != nil {
t.Log(err)
}
}()
time.Sleep(200 * time.Millisecond)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/cmdline", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusNotFound, r.StatusCode)

mux, err = extractServeMuxData()
require.Nil(t, err)
require.Equal(t, 5, len(mux.m))
require.Equal(t, 2, len(mux.es))
require.Equal(t, true, mux.hosts)

err = unregisterHandlers(
[]string{
"/usercmd",
"/errout",
"/panicHandle",
"www.qq.com/",
"anything/",
},
)
require.Nil(t, err)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/profile", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusNotFound, r.StatusCode)

mux, err = extractServeMuxData()
require.Nil(t, err)
require.Len(t, mux.m, 0)
require.Len(t, mux.es, 0)
require.False(t, mux.hosts)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/symbol", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusNotFound, r.StatusCode)

resp1, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
require.Nil(t, err)
defer resp1.Body.Close()
require.Equal(t, http.StatusNotFound, resp1.StatusCode)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/trace", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusNotFound, r.StatusCode)
})
t.Run("register pprof handler explicitly after importing the admin package", func(t *testing.T) {
http.DefaultServeMux.HandleFunc("/debug/pprof/", pprof.Index)
http.DefaultServeMux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
http.DefaultServeMux.HandleFunc("/debug/pprof/profile", pprof.Profile)
http.DefaultServeMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
http.DefaultServeMux.HandleFunc("/debug/pprof/trace", pprof.Trace)
t.Cleanup(func() {
http.DefaultServeMux = http.NewServeMux()
})
l, err := net.Listen("tcp", "127.0.0.1:0")
require.Nil(t, err)
go func() {
server := &http.Server{
Handler: nil,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
t.Logf("http serving: %v", err)
}
}()
time.Sleep(200 * time.Millisecond)

r, err := http.Get(fmt.Sprintf("http://%s/debug/pprof/", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusOK, r.StatusCode)

http.HandleFunc("/usercmd", userCmd)
http.HandleFunc("/errout", errOutput)
http.HandleFunc("/panicHandle", panicHandle)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/cmdline", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusOK, r.StatusCode)

mux, err = extractServeMuxData()
require.Nil(t, err)
require.Len(t, mux.m, 3)
require.Len(t, mux.es, 0)
require.False(t, mux.hosts)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/symbol", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusOK, r.StatusCode)

resp2, err := http.Get(fmt.Sprintf("http://%v/usercmd", l.Addr()))
require.Nil(t, err)
defer resp2.Body.Close()
respBody, err := io.ReadAll(resp2.Body)
require.Nil(t, err)
require.Equal(t, []byte("usercmd"), respBody)
r, err = http.Get(fmt.Sprintf("http://%s/debug/pprof/trace", l.Addr().String()))
require.Nil(t, err)
require.Equal(t, http.StatusOK, r.StatusCode)
})
}
func mustListenTCP(t *testing.T) *net.TCPListener {
l, err := net.Listen("tcp", testAddress)
Expand Down
92 changes: 0 additions & 92 deletions admin/mux.go

This file was deleted.

Loading