Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ jobs:
strategy:
matrix:
go:
- 1.18.x
- 1.17.x
- 1.16.x
- 1.26.x
- 1.25.x
os:
- ubuntu-latest
name: ${{ matrix.os }}/go${{ matrix.go }}
Expand Down
10 changes: 10 additions & 0 deletions capture_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ func (m *Metrics) CaptureMetrics(w http.ResponseWriter, fn func(http.ResponseWri
}
},

WriteString: func(next WriteStringFunc) WriteStringFunc {
return func(s string) (int, error) {
n, err := next(s)

m.Written += int64(n)
headerWritten = true
return n, err
}
},

ReadFrom: func(next ReadFromFunc) ReadFromFunc {
return func(src io.Reader) (int64, error) {
n, err := next(src)
Expand Down
18 changes: 0 additions & 18 deletions bench_test.go → capture_metrics_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,6 @@ func BenchmarkCaptureMetricsTwice(b *testing.B) {
benchmark(b, 2)
}

func BenchmarkWrap(b *testing.B) {
b.StopTimer()
doneCh := make(chan struct{}, 1)
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.StartTimer()
for i := 0; i < b.N; i++ {
Wrap(w, Hooks{})
}
doneCh <- struct{}{}
})
s := httptest.NewServer(h)
defer s.Close()
if _, err := http.Get(s.URL); err != nil {
b.Fatal(err)
}
<-doneCh
}

func benchmark(b *testing.B, wrappings int) {
dummyH := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h := dummyH
Expand Down
10 changes: 9 additions & 1 deletion capture_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ func TestCaptureMetrics(t *testing.T) {
WantWritten: 17,
WantCode: http.StatusOK,
},
{
Name: "string writer",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "write string")
}),
WantWritten: int64(len("write string")),
WantCode: http.StatusOK,
},
{
Name: "empty panic",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -104,7 +112,7 @@ func TestCaptureMetrics(t *testing.T) {
}
if err == nil {
defer res.Body.Close()
}
}
m := <-ch
if m.Code != test.WantCode {
t.Errorf("test %d: got=%d want=%d", i, m.Code, test.WantCode)
Expand Down
197 changes: 127 additions & 70 deletions codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@ type Build struct {
}

func (b *Build) MustBuild() {
prefix := "wrap_generated_"
b.Implementation().MustWriteFile(prefix + b.Suffix + ".go")
b.Tests().MustWriteFile(prefix + b.Suffix + "_test.go")
prefix := "wrap_generated"
if b.Suffix != "" {
prefix += "_" + b.Suffix
}

b.Implementation().MustWriteFile(prefix + ".go")
b.Tests().MustWriteFile(prefix + "_test.go")
}

func (b *Build) writeHeader(g *Generator) {
g.Printf(`
// +build %s
// Code generated by "httpsnoop/codegen"; DO NOT EDIT.
if b.Tags != "" {
g.Printf(`// +build %s
`, b.Tags)
}
g.buf.WriteString(`// Code generated by "httpsnoop/codegen"; DO NOT EDIT.

package httpsnoop

`, b.Tags)
`)
}

func (b *Build) Implementation() *Generator {
Expand Down Expand Up @@ -89,65 +95,111 @@ type Hooks struct {
// hooks can be used.
`, strings.Join(docList, "\n"))
g.Printf("func Wrap(w http.ResponseWriter, hooks Hooks) http.ResponseWriter {\n")
g.Printf("rw := &rw{w: w, h: hooks}\n")
g.Printf("state := &rwState{w: w}\n")

// Precompute hook chains once per Wrap call and
// build a uint8 combo index so the switch compiles to a jump table.
g.Printf("var combo uint8\n")
for _, fn := range ifaces[0].Funcs {
g.Printf("if hooks.%s != nil {\n", fn.Name)
g.Printf("state.%s = hooks.%s(w.%s)\n", fieldName(fn.Name), fn.Name, fn.Name)
g.Printf("}\n")
}

for i, iface := range subIfaces {
g.Printf("_, i%d := w.(%s)\n", i, iface.Name)
g.Printf("if t%[1]d, i%[1]d := w.(%s); i%[1]d {\n", i, iface.Name)
bit := len(subIfaces) - i - 1
g.Printf("combo |= 1<<%d\n", bit)
for _, fn := range iface.Funcs {
g.Printf("if hooks.%s != nil {\n", fn.Name)
g.Printf("state.%s = hooks.%s(t%d.%s)\n", fieldName(fn.Name), fn.Name, i, fn.Name)
g.Printf("}\n")
}
g.Printf("}\n")
}
g.Printf("switch {\n")

g.Printf("switch combo {\n")
combinations := 1 << uint(len(subIfaces))
for i := 0; i < combinations; i++ {
conditions := make([]string, len(subIfaces))
fields := make([]string, 0, len(subIfaces))
fields = append(fields, "Unwrapper", "http.ResponseWriter")
for j, iface := range subIfaces {
ok := i&(1<<uint(len(subIfaces)-j-1)) > 0
if !ok {
conditions[j] = "!"
} else {
fields = append(fields, iface.Name)
}
conditions[j] += fmt.Sprintf("i%d", j)
}
values := make([]string, len(fields))
for i := range fields {
values[i] = "rw"
}
g.Printf("// combination %d/%d\n", i+1, combinations)
g.Printf("case %s:\n", strings.Join(conditions, "&&"))
fieldsS, valuesS := strings.Join(fields, "\n"), strings.Join(values, ",")
g.Printf("return struct{\n%s\n}{%s}\n", fieldsS, valuesS)
for c := 0; c < combinations; c++ {
g.Printf("case %d: return (*rw%d)(state)\n", c, c)
}
g.Printf("}\n")
g.Printf("panic(\"unreachable\")")
g.Printf("}\n")
g.Printf("}\n\n")

// rw struct
g.Printf(`
type rw struct {
w http.ResponseWriter
h Hooks
}

func (w *rw) Unwrap() http.ResponseWriter {
return w.w
}
// rwState holds the underlying writer plus the precomputed hooks.
// All variant types are type-definitions over rwState, so a single *rwState
// allocation can be reinterpreted as any variant via pointer conversion.
g.Printf("type rwState struct {\n")
g.Printf("w http.ResponseWriter\n")
for _, iface := range ifaces {
for _, fn := range iface.Funcs {
g.Printf("%s %s\n", fieldName(fn.Name), fn.Type())
}
}
g.Printf("}\n\n")

`)
// do<Name> helpers on *rwState
// These actual dispatch logic, defined once and called by the variant types.
for _, iface := range ifaces {
for _, fn := range iface.Funcs {
g.Printf("func (w *rw) %s(%s) (%s) {\n", fn.Name, fn.Args, fn.Returns)
g.Printf("f := w.w.(%s).%s\n", iface.Name, fn.Name)
g.Printf("if w.h.%s != nil {\n", fn.Name)
g.Printf("f = w.h.%s(f)\n", fn.Name)
g.Printf("}\n")
g.Printf("func (r *rwState) do%s(%s) (%s) {\n", fn.Name, fn.Args, fn.Returns)
g.Printf("if r.%s != nil {\n", fieldName(fn.Name))
if fn.Returns != "" {
g.Printf("return ")
g.Printf("return r.%s(%s)\n", fieldName(fn.Name), fn.Args.Names())
} else {
g.Printf("r.%s(%s)\n", fieldName(fn.Name), fn.Args.Names())
g.Printf("return\n")
}
g.Printf("f(%s)\n", fn.Args.Names())
g.Printf("}\n")
g.Printf("\n")

receiver := "r.w"
if iface.Name != "http.ResponseWriter" {
receiver = fmt.Sprintf("r.w.(%s)", iface.Name)
}
if fn.Returns != "" {
g.Printf("return %s.%s(%s)\n", receiver, fn.Name, fn.Args.Names())
} else {
g.Printf("%s.%s(%s)\n", receiver, fn.Name, fn.Args.Names())
}
g.Printf("}\n\n")
}
}

// Variant types, each is a type with the same memory layout as rwState,
// but exposing exactly the method set required by its combination of interfaces.
// This allows (*rwN)(state) to be a zero-cost pointer conversion.
emitVariantMethod := func(c int, fn *InterfaceFunc) {
g.Printf("func (w *rw%d) %s(%s) (%s) {\n", c, fn.Name, fn.Args, fn.Returns)
if fn.Returns != "" {
g.Printf("return (*rwState)(w).do%s(%s)\n", fn.Name, fn.Args.Names())
} else {
g.Printf("(*rwState)(w).do%s(%s)\n", fn.Name, fn.Args.Names())
}
g.Printf("}\n")
}
for c := 0; c < combinations; c++ {
supported := []string{"http.ResponseWriter"}
for j, iface := range subIfaces {
if c&(1<<uint(len(subIfaces)-j-1)) > 0 {
supported = append(supported, iface.Name)
}
}
g.Printf("// combination %d/%d: %s\n", c+1, combinations, strings.Join(supported, ", "))
g.Printf("type rw%d rwState\n", c)
g.Printf("func (w *rw%d) Unwrap() http.ResponseWriter { return w.w }\n", c)
for _, fn := range ifaces[0].Funcs {
emitVariantMethod(c, fn)
}
for j, iface := range subIfaces {
if c&(1<<uint(len(subIfaces)-j-1)) > 0 {
for _, fn := range iface.Funcs {
emitVariantMethod(c, fn)
}
}
}
g.Printf("\n")
}
g.Printf(`
type Unwrapper interface {
Unwrap() http.ResponseWriter
Expand All @@ -159,9 +211,8 @@ func Unwrap(w http.ResponseWriter) http.ResponseWriter {
if rw, ok := w.(Unwrapper); ok {
// recurse until rw.Unwrap() returns a non-Unwrapper
return Unwrap(rw.Unwrap())
} else {
return w
}
return w
}
`)
return &g
Expand Down Expand Up @@ -263,12 +314,19 @@ func (fn *InterfaceFunc) Type() string {
return fn.Name + "Func"
}

func fieldName(s string) string {
if s == "" {
return s
}
return strings.ToLower(s[:1]) + s[1:]
}

type Generator struct {
buf bytes.Buffer
}

func (g *Generator) Printf(s string, args ...interface{}) {
fmt.Fprintf(&g.buf, s, args...)
_, _ = fmt.Fprintf(&g.buf, s, args...)
}

func (g *Generator) WriteFile(name string) error {
Expand Down Expand Up @@ -342,33 +400,32 @@ func main() {
{"EnableFullDuplex", nil, "error"},
},
},
{
Name: "http.Pusher",
Funcs: []*InterfaceFunc{
{"Push", FuncArgs{
{"target", "string"},
{"opts", "*http.PushOptions"},
}, "error"},
},
}, {
Name: "io.StringWriter",
Funcs: []*InterfaceFunc{
{"WriteString", FuncArgs{{"s", "string"}}, "int, error"},
},
},
}
builds := []Build{
{
Suffix: "lt_1.8",
Tags: "!go1.8",
Interfaces: ifaces,
},
{
Suffix: "gteq_1.8",
Tags: "go1.8",
Interfaces: append(ifaces, &Interface{
Name: "http.Pusher",
Funcs: []*InterfaceFunc{
{"Push", FuncArgs{
{"target", "string"},
{"opts", "*http.PushOptions"},
}, "error"},
},
}),
},
}
for _, build := range builds {
build.MustBuild()
}
}

func fatalf(s string, args ...interface{}) {
fmt.Fprintf(os.Stderr, s+"\n", args...)
_, _ = fmt.Fprintf(os.Stderr, s+"\n", args...)
os.Exit(1)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/felixge/httpsnoop

go 1.13
go 1.25
Loading