-
Notifications
You must be signed in to change notification settings - Fork 221
Expand file tree
/
Copy pathresponse_processor.go
More file actions
248 lines (219 loc) · 8.06 KB
/
response_processor.go
File metadata and controls
248 lines (219 loc) · 8.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0
// Package transparent provides a transparent HTTP proxy implementation
// that forwards requests to a destination without modifying them.
package transparent
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"mime"
"net/http"
"strings"
"github.com/stacklok/toolhive/pkg/transport/types"
)
const (
	// maxJSONRPCResponseBytes is the largest upstream JSON-RPC response body,
	// in bytes, that the proxy will hold in memory for structural validation.
	// It matches the streamable-HTTP body limits used elsewhere in the
	// codebase (pkg/vmcp/client, pkg/vmcp/session/internal/backend).
	maxJSONRPCResponseBytes = 100 * 1024 * 1024 // 100 MiB

	// jsonRPCInvalidUpstreamCode is the JSON-RPC error code sent to clients
	// when the proxy rejects a malformed upstream response. JSON-RPC 2.0 sets
	// aside -32000..-32099 for implementation-defined server errors, whereas
	// -32603 is reserved for internal JSON-RPC implementation errors and does
	// not fit a policy-level rejection.
	jsonRPCInvalidUpstreamCode = -32000
)
// ResponseProcessor is implemented by transport-specific hooks that inspect
// and, where the transport requires it, modify HTTP responses returned through
// the proxy.
type ResponseProcessor interface {
	// ShouldProcess reports whether this processor wants to handle resp.
	ShouldProcess(resp *http.Response) bool

	// ProcessResponse applies the transport-specific changes to resp.
	// A non-nil error means processing failed.
	ProcessResponse(resp *http.Response) error
}
// NoOpResponseProcessor is the processor used for every transport other than
// SSE. Despite the name it is not entirely passive. ProcessResponse checks that
// qualifying upstream JSON-RPC responses are well-formed. Every other response
// passes through unchanged.
type NoOpResponseProcessor struct{}
// ProcessResponse validates the body of an upstream JSON-RPC response. If the
// body is malformed or larger than maxJSONRPCResponseBytes, it is replaced with
// a synthetic 502 JSON-RPC error.
//
// Responses that shouldValidateJSONRPCResponse does not select are left
// untouched. For selected responses, the body is buffered in memory and the
// upstream body is closed. resp.Body is then swapped for an in-memory reader,
// so the reverse proxy can still forward it. The only error returned is a
// failure to read the upstream body; validation failures are reported to the
// client in-band instead.
func (*NoOpResponseProcessor) ProcessResponse(resp *http.Response) error {
	if !shouldValidateJSONRPCResponse(resp) {
		return nil
	}

	// Read one byte past the cap so an oversize body can be detected without
	// buffering more than maxJSONRPCResponseBytes+1 bytes.
	body, readErr := io.ReadAll(io.LimitReader(resp.Body, maxJSONRPCResponseBytes+1))
	// Close the upstream body on every path, including read failures, so an
	// error return never leaves the upstream connection dangling. There is
	// nothing useful to do with a close error at this point.
	_ = resp.Body.Close()
	if readErr != nil {
		return fmt.Errorf("failed to read upstream response body: %w", readErr)
	}

	if len(body) > maxJSONRPCResponseBytes {
		writeInvalidUpstreamJSONRPCResponse(resp, fmt.Errorf(
			"upstream JSON-RPC response exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes))
		return nil
	}

	if err := validateJSONRPCResponse(body); err != nil {
		writeInvalidUpstreamJSONRPCResponse(resp, err)
		return nil
	}

	// The reverse proxy still needs a readable body after validation. The
	// buffered body has a known length, so describe it only by Content-Length
	// and drop any upstream transfer coding (e.g. chunked).
	resp.Body = io.NopCloser(bytes.NewReader(body))
	resp.ContentLength = int64(len(body))
	resp.TransferEncoding = nil
	resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
	return nil
}
// ShouldProcess reports false for every response, including nil. The no-op
// processor never opts in through this hook.
func (*NoOpResponseProcessor) ShouldProcess(*http.Response) bool {
	return false
}
// shouldValidateJSONRPCResponse reports whether resp is an upstream response
// that should be buffered and structurally validated as JSON-RPC.
//
// All of the following must hold:
//   - resp, its Body and its originating Request are non-nil.
//   - The request was a POST and the response status is 200 OK.
//   - Every Content-Encoding field value is empty or "identity".
//   - The request carries an MCP header (see requestLooksLikeMCP).
//   - The Content-Type media type is application/json or application/json-rpc.
func shouldValidateJSONRPCResponse(resp *http.Response) bool {
	if resp == nil || resp.Body == nil || resp.Request == nil {
		return false
	}
	if resp.Request.Method != http.MethodPost || resp.StatusCode != http.StatusOK {
		return false
	}
	// Content-Encoding semantics (RFC 9110): media-type rules apply after decoding.
	// Validating a still-encoded body would misclassify legitimate gzip JSON-RPC
	// frames as invalid, so skip rather than introduce decompression here.
	// Content-Encoding may be split across several field lines, so every line
	// must be identity. Checking only the first would let "identity" followed
	// by "gzip" through to validation.
	for _, encoding := range resp.Header.Values("Content-Encoding") {
		if !hasIdentityContentEncoding(encoding) {
			return false
		}
	}
	if !requestLooksLikeMCP(resp.Request) {
		// Narrow validation to traffic that carries an MCP streamable-HTTP signal,
		// so non-MCP application/json POSTs flowing through the catch-all are not
		// rewritten. Backward-compat clients omitting MCP-Protocol-Version on the
		// initial initialize will pass through unchanged.
		return false
	}
	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		return false
	}
	return mediaType == "application/json" || mediaType == "application/json-rpc"
}
// hasIdentityContentEncoding reports whether a single Content-Encoding field
// value means no content coding was applied. That is the case when the value
// is empty or "identity", compared case-insensitively after trimming
// surrounding whitespace.
func hasIdentityContentEncoding(value string) bool {
	switch strings.TrimSpace(strings.ToLower(value)) {
	case "", "identity":
		return true
	default:
		return false
	}
}
// requestLooksLikeMCP reports whether req carries an MCP streamable-HTTP
// marker: a non-empty MCP-Protocol-Version header or a non-empty
// Mcp-Session-Id header. A nil request never qualifies.
func requestLooksLikeMCP(req *http.Request) bool {
	if req == nil {
		return false
	}
	for _, name := range [...]string{"MCP-Protocol-Version", "Mcp-Session-Id"} {
		if req.Header.Get(name) != "" {
			return true
		}
	}
	return false
}
// validateJSONRPCResponse checks that body holds exactly one JSON value. That
// value must be either a single JSON-RPC 2.0 response object or a non-empty
// array (batch) of them. It returns a descriptive error for the first problem
// found, or nil if the body is acceptable.
func validateJSONRPCResponse(body []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(body))

	var payload any
	if err := decoder.Decode(&payload); err != nil {
		return fmt.Errorf("invalid JSON body: %w", err)
	}
	// Anything other than JSON whitespace after the first value means the body
	// is not a single JSON document.
	if decoder.More() || decoder.Decode(&struct{}{}) != io.EOF {
		return fmt.Errorf("JSON-RPC response must contain a single JSON value")
	}

	if single, isObject := payload.(map[string]any); isObject {
		return validateJSONRPCResponseObject(single)
	}

	batch, isArray := payload.([]any)
	if !isArray {
		return fmt.Errorf("JSON-RPC response must be an object or array")
	}
	if len(batch) == 0 {
		return fmt.Errorf("JSON-RPC batch response must not be empty")
	}
	for idx := range batch {
		entry, isObject := batch[idx].(map[string]any)
		if !isObject {
			return fmt.Errorf("JSON-RPC batch item %d must be an object", idx)
		}
		if err := validateJSONRPCResponseObject(entry); err != nil {
			return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", idx, err)
		}
	}
	return nil
}
// validateJSONRPCResponseObject checks one decoded JSON-RPC 2.0 response
// object against these rules:
//   - "jsonrpc" must be the string "2.0".
//   - "id" must be present and accepted by isValidJSONRPCID.
//   - Exactly one of "result" or "error" must be present.
//   - If "error" is present, it must be an object accepted by isValidJSONRPCError.
//
// It returns nil when every rule holds.
func validateJSONRPCResponseObject(obj map[string]any) error {
	if version, _ := obj["jsonrpc"].(string); version != "2.0" {
		return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`)
	}

	id, hasID := obj["id"]
	if !hasID {
		return fmt.Errorf("JSON-RPC response must include id")
	}
	if !isValidJSONRPCID(id) {
		return fmt.Errorf("JSON-RPC response id must be string, number, or null")
	}

	// Presence is what matters here: "result": null still counts as a result.
	_, hasResult := obj["result"]
	errField, hasError := obj["error"]
	if hasResult == hasError {
		return fmt.Errorf("JSON-RPC response must include exactly one of result or error")
	}
	if !hasError {
		return nil
	}

	errObj, isObject := errField.(map[string]any)
	if !isObject || !isValidJSONRPCError(errObj) {
		return fmt.Errorf("JSON-RPC error response must include error.code and error.message")
	}
	return nil
}
// isValidJSONRPCID reports whether id is an allowed JSON-RPC id. Allowed ids
// are JSON null (nil), a string, or a number (float64). Those are the Go types
// encoding/json produces when decoding into an untyped value, which is how
// validateJSONRPCResponse obtains the id.
func isValidJSONRPCID(id any) bool {
	if id == nil {
		return true
	}
	_, isString := id.(string)
	_, isNumber := id.(float64)
	return isString || isNumber
}
// isValidJSONRPCError reports whether a decoded JSON-RPC error object has an
// integral numeric "code" and a string "message". Any other members, such as
// "data", are ignored.
func isValidJSONRPCError(errObj map[string]any) bool {
	code, isNumber := errObj["code"].(float64)
	if !isNumber {
		return false
	}
	// JSON-RPC 2.0 requires error.code to be an integer; encoding/json hands
	// every number back as float64, so reject values with a fractional part.
	if code != math.Trunc(code) {
		return false
	}
	if _, isString := errObj["message"].(string); !isString {
		return false
	}
	return true
}
// writeInvalidUpstreamJSONRPCResponse rewrites resp in place into a
// proxy-generated 502 Bad Gateway. The new body is a JSON-RPC 2.0 error with
// code jsonRPCInvalidUpstreamCode and a null id. When validationErr is
// non-nil, its message is placed in the error's "data" member.
//
// It does not close the previous resp.Body. ProcessResponse has already closed
// it before calling this function.
func writeInvalidUpstreamJSONRPCResponse(resp *http.Response, validationErr error) {
	errObj := map[string]any{
		"code":    jsonRPCInvalidUpstreamCode,
		"message": "Invalid upstream JSON-RPC response",
	}
	// Guard against a nil error so a misuse cannot panic the proxy.
	if validationErr != nil {
		errObj["data"] = validationErr.Error()
	}
	body, err := json.Marshal(map[string]any{
		"jsonrpc": "2.0",
		"error":   errObj,
		"id":      nil,
	})
	if err != nil {
		// Marshalling plain strings and ints cannot realistically fail. The
		// fallback is still built from the same constant, so both paths always
		// report the same error code.
		body = []byte(fmt.Sprintf(
			`{"jsonrpc":"2.0","error":{"code":%d,"message":"Invalid upstream JSON-RPC response"},"id":null}`,
			jsonRPCInvalidUpstreamCode))
	}

	resp.StatusCode = http.StatusBadGateway
	resp.Status = fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))
	resp.Body = io.NopCloser(bytes.NewReader(body))
	resp.ContentLength = int64(len(body))
	// The synthetic body has a known length. Drop any upstream transfer coding
	// (e.g. chunked) so the response is described only by Content-Length.
	resp.TransferEncoding = nil
	// Replace headers wholesale so upstream session/cookie/cache metadata is not
	// smuggled into the proxy-generated error. Only carry the fields needed to
	// describe this synthetic body.
	resp.Header = http.Header{}
	resp.Header.Set("Content-Type", "application/json")
	resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
	resp.Trailer = nil
}
// createResponseProcessor returns the ResponseProcessor for transportType.
// The SSE transport gets an SSE-specific processor built from proxy,
// endpointPrefix and trustProxyHeaders. Streamable HTTP and any unrecognized
// transport type get a NoOpResponseProcessor.
func createResponseProcessor(
	transportType string,
	proxy *TransparentProxy,
	endpointPrefix string,
	trustProxyHeaders bool,
) ResponseProcessor {
	if transportType == types.TransportTypeSSE.String() {
		return NewSSEResponseProcessor(proxy, endpointPrefix, trustProxyHeaders)
	}
	// Streamable HTTP and unknown transport types share the no-op processor.
	return &NoOpResponseProcessor{}
}