Skip to content

Commit 6d65379

Browse files
committed
encoding: add support for compressor options in Compress methods
1 parent 2b8b708 commit 6d65379

8 files changed

Lines changed: 225 additions & 29 deletions

File tree

encoding/compressor_test.go

Lines changed: 176 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"bytes"
2323
"context"
2424
"io"
25+
"sync"
2526
"sync/atomic"
2627
"testing"
2728

@@ -38,16 +39,21 @@ import (
3839
_ "google.golang.org/grpc/encoding/gzip"
3940
)
4041

41-
// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
42-
// Compressor method invokes.
42+
// wrapCompressor is a wrapper of encoding.Compressor which records invocation
43+
// count and the options passed to each Compress call.
4344
type wrapCompressor struct {
4445
encoding.Compressor
4546
compressInvokes int32
47+
mu sync.Mutex
48+
receivedOpts [][]any
4649
}
4750

48-
func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
51+
func (wc *wrapCompressor) Compress(w io.Writer, opts ...any) (io.WriteCloser, error) {
4952
atomic.AddInt32(&wc.compressInvokes, 1)
50-
return wc.Compressor.Compress(w)
53+
wc.mu.Lock()
54+
wc.receivedOpts = append(wc.receivedOpts, opts)
55+
wc.mu.Unlock()
56+
return wc.Compressor.Compress(w, opts...)
5157
}
5258

5359
func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
@@ -186,7 +192,7 @@ type fakeCompressor struct {
186192
decompressedMessageSize int
187193
}
188194

189-
func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
195+
func (f *fakeCompressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
190196
return nopWriteCloser{w}, nil
191197
}
192198

@@ -237,3 +243,168 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
237243
t.Errorf("Client.UnaryCall(%+v) returned status %v, want %v", req, got, want)
238244
}
239245
}
246+
247+
// TestSetSendCompressorOptionsPropagate verifies that options passed to
248+
// SetSendCompressor are forwarded to the compressor's Compress method.
249+
func (s) TestSetSendCompressorOptionsPropagate(t *testing.T) {
250+
wantOpt := "dict-id-42"
251+
for _, tc := range []struct {
252+
name string
253+
run func(*testing.T, *wrapCompressor)
254+
}{
255+
{"unary", testUnarySendCompressorOptionsPropagate},
256+
{"stream", testStreamSendCompressorOptionsPropagate},
257+
} {
258+
t.Run(tc.name, func(t *testing.T) {
259+
wc := setupGzipWrapCompressor(t)
260+
tc.run(t, wc)
261+
wc.mu.Lock()
262+
defer wc.mu.Unlock()
263+
if len(wc.receivedOpts) == 0 {
264+
t.Fatal("Compress was not called")
265+
}
266+
if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt {
267+
t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt)
268+
}
269+
})
270+
}
271+
}
272+
273+
func testUnarySendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
274+
t.Helper()
275+
ss := &stubserver.StubServer{
276+
UnaryCallF: func(ctx context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
277+
if err := grpc.SetSendCompressor(ctx, "gzip", "dict-id-42"); err != nil {
278+
return nil, err
279+
}
280+
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil
281+
},
282+
}
283+
if err := ss.Start(nil); err != nil {
284+
t.Fatalf("Error starting endpoint server: %v", err)
285+
}
286+
defer ss.Stop()
287+
288+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
289+
defer cancel()
290+
291+
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
292+
t.Fatalf("Unexpected unary call error: %v", err)
293+
}
294+
}
295+
296+
func testStreamSendCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
297+
t.Helper()
298+
ss := &stubserver.StubServer{
299+
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
300+
if _, err := stream.Recv(); err != nil {
301+
return err
302+
}
303+
if err := grpc.SetSendCompressor(stream.Context(), "gzip", "dict-id-42"); err != nil {
304+
return err
305+
}
306+
return stream.Send(&testpb.StreamingOutputCallResponse{
307+
Payload: &testpb.Payload{Body: []byte("payload")},
308+
})
309+
},
310+
}
311+
if err := ss.Start(nil); err != nil {
312+
t.Fatalf("Error starting endpoint server: %v", err)
313+
}
314+
defer ss.Stop()
315+
316+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
317+
defer cancel()
318+
319+
s, err := ss.Client.FullDuplexCall(ctx)
320+
if err != nil {
321+
t.Fatalf("Unexpected full duplex call error: %v", err)
322+
}
323+
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
324+
t.Fatalf("Unexpected send error: %v", err)
325+
}
326+
if _, err := s.Recv(); err != nil {
327+
t.Fatalf("Unexpected recv error: %v", err)
328+
}
329+
}
330+
331+
// TestUseCompressorOptionsPropagate verifies that options passed to
332+
// UseCompressor are forwarded to the compressor's Compress method.
333+
func (s) TestUseCompressorOptionsPropagate(t *testing.T) {
334+
wantOpt := "dict-id-42"
335+
for _, tc := range []struct {
336+
name string
337+
run func(*testing.T, *wrapCompressor)
338+
}{
339+
{"unary", testUnaryUseCompressorOptionsPropagate},
340+
{"stream", testStreamUseCompressorOptionsPropagate},
341+
} {
342+
t.Run(tc.name, func(t *testing.T) {
343+
wc := setupGzipWrapCompressor(t)
344+
tc.run(t, wc)
345+
wc.mu.Lock()
346+
defer wc.mu.Unlock()
347+
if len(wc.receivedOpts) == 0 {
348+
t.Fatal("Compress was not called")
349+
}
350+
if got := wc.receivedOpts[0]; len(got) == 0 || got[0] != wantOpt {
351+
t.Fatalf("Compress received opts %v, want [%q]", got, wantOpt)
352+
}
353+
})
354+
}
355+
}
356+
357+
func testUnaryUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
358+
t.Helper()
359+
ss := &stubserver.StubServer{
360+
UnaryCallF: func(_ context.Context, _ *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
361+
return &testpb.SimpleResponse{Payload: &testpb.Payload{Body: []byte("payload")}}, nil
362+
},
363+
}
364+
if err := ss.Start(nil); err != nil {
365+
t.Fatalf("Error starting endpoint server: %v", err)
366+
}
367+
defer ss.Stop()
368+
369+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
370+
defer cancel()
371+
372+
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{Body: []byte("data")}}, grpc.UseCompressor("gzip", "dict-id-42")); err != nil {
373+
t.Fatalf("Unexpected unary call error: %v", err)
374+
}
375+
}
376+
377+
func testStreamUseCompressorOptionsPropagate(t *testing.T, _ *wrapCompressor) {
378+
t.Helper()
379+
ss := &stubserver.StubServer{
380+
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
381+
req, err := stream.Recv()
382+
if err != nil {
383+
return err
384+
}
385+
return stream.Send(&testpb.StreamingOutputCallResponse{
386+
Payload: req.GetPayload(),
387+
})
388+
},
389+
}
390+
if err := ss.Start(nil); err != nil {
391+
t.Fatalf("Error starting endpoint server: %v", err)
392+
}
393+
defer ss.Stop()
394+
395+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
396+
defer cancel()
397+
398+
s, err := ss.Client.FullDuplexCall(ctx, grpc.UseCompressor("gzip", "dict-id-42"))
399+
if err != nil {
400+
t.Fatalf("Unexpected full duplex call error: %v", err)
401+
}
402+
if err := s.Send(&testpb.StreamingOutputCallRequest{
403+
Payload: &testpb.Payload{Body: []byte("payload")},
404+
}); err != nil {
405+
t.Fatalf("Unexpected send error: %v", err)
406+
}
407+
if _, err := s.Recv(); err != nil {
408+
t.Fatalf("Unexpected recv error: %v", err)
409+
}
410+
}

encoding/encoding.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ func init() {
6161
type Compressor interface {
6262
// Compress writes the data written to wc to w after compressing it. If an
6363
// error occurs while initializing the compressor, that error is returned
64-
// instead.
65-
Compress(w io.Writer) (io.WriteCloser, error)
64+
// instead. opts passes caller-provided context to the compressor (e.g.
65+
// dictionary IDs for trained compression formats). Unknown options must
66+
// be silently ignored.
67+
Compress(w io.Writer, opts ...any) (io.WriteCloser, error)
6668
// Decompress reads data from r, decompresses it, and provides the
6769
// uncompressed data via the returned io.Reader. If an error occurs while
6870
// initializing the decompressor, that error is returned instead.

encoding/gzip/gzip.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func SetLevel(level int) error {
7070
return nil
7171
}
7272

73-
func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
73+
func (c *compressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
7474
z := c.poolCompressor.Get().(*writer)
7575
z.Writer.Reset(w)
7676
return z, nil

internal/transport/server_stream.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ type ServerStream struct {
5050
headerSent atomic.Bool // atomically set when the headers are sent out.
5151

5252
headerWireLength int
53+
54+
sendCompressOptions []any
5355
}
5456

5557
// Read reads an n byte message from the input stream.
@@ -108,13 +110,20 @@ func (s *ServerStream) ContentSubtype() string {
108110
return s.contentSubtype
109111
}
110112

111-
// SetSendCompress sets the compression algorithm to the stream.
112-
func (s *ServerStream) SetSendCompress(name string) error {
113+
// SendCompressOptions returns the compressor options set for the stream.
114+
func (s *ServerStream) SendCompressOptions() []any {
115+
return s.sendCompressOptions
116+
}
117+
118+
// SetSendCompress sets the compression algorithm to the stream. opts are
119+
// forwarded to the compressor's Compress method on each send.
120+
func (s *ServerStream) SetSendCompress(name string, opts ...any) error {
113121
if s.isHeaderSent() || s.getState() == streamDone {
114122
return errors.New("transport: set send compressor called after headers sent or stream done")
115123
}
116124

117125
s.sendCompress = name
126+
s.sendCompressOptions = opts
118127
return nil
119128
}
120129

rpc_util.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ type callInfo struct {
168168
onFinish []func(err error)
169169
authority string
170170
acceptedResponseCompressors []string
171+
compressorOptions []any
171172
}
172173

173174
func acceptedCompressorAllows(allowed []string, name string) bool {
@@ -490,14 +491,16 @@ func (o PerRPCCredsCallOption) after(*callInfo, *csAttempt) {}
490491

491492
// UseCompressor returns a CallOption which sets the compressor used when
492493
// sending the request. If WithCompressor is also set, UseCompressor has
493-
// higher priority.
494+
// higher priority. The optional compressorOptions are forwarded to the
495+
// compressor's Compress method, allowing callers to pass additional context
496+
// such as dictionary IDs for trained compression formats.
494497
//
495498
// # Experimental
496499
//
497500
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
498501
// later release.
499-
func UseCompressor(name string) CallOption {
500-
return CompressorCallOption{CompressorType: name}
502+
func UseCompressor(name string, compressorOptions ...any) CallOption {
503+
return CompressorCallOption{CompressorType: name, CompressorOptions: compressorOptions}
501504
}
502505

503506
// CompressorCallOption is a CallOption that indicates the compressor to use.
@@ -507,11 +510,13 @@ func UseCompressor(name string) CallOption {
507510
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
508511
// later release.
509512
type CompressorCallOption struct {
510-
CompressorType string
513+
CompressorType string
514+
CompressorOptions []any
511515
}
512516

513517
func (o CompressorCallOption) before(c *callInfo) error {
514518
c.compressorName = o.CompressorType
519+
c.compressorOptions = o.CompressorOptions
515520
return nil
516521
}
517522
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
@@ -817,7 +822,7 @@ func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
817822
// indicating no compression was done.
818823
//
819824
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
820-
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
825+
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (mem.BufferSlice, payloadFormat, error) {
821826
if (compressor == nil && cp == nil) || in.Len() == 0 {
822827
return nil, compressionNone, nil
823828
}
@@ -828,7 +833,7 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor,
828833
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
829834
}
830835
if compressor != nil {
831-
z, err := compressor.Compress(w)
836+
z, err := compressor.Compress(w, compressorOptions...)
832837
if err != nil {
833838
return nil, 0, wrapErr(err)
834839
}

rpc_util_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type testCompressorForRegistry struct {
5252
name string
5353
}
5454

55-
func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
55+
func (c *testCompressorForRegistry) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
5656
return &testWriteCloser{w}, nil
5757
}
5858

@@ -541,7 +541,7 @@ type mockCompressor struct {
541541
ch chan<- struct{}
542542
}
543543

544-
func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) {
544+
func (m *mockCompressor) Compress(io.Writer, ...any) (io.WriteCloser, error) {
545545
panic("unimplemented")
546546
}
547547

server.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,14 +1172,14 @@ func (s *Server) incrCallsFailed() {
11721172
s.channelz.ServerMetrics.CallsFailed.Add(1)
11731173
}
11741174

1175-
func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor) error {
1175+
func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStream, msg any, cp Compressor, opts *transport.WriteOptions, comp encoding.Compressor, compressorOptions ...any) error {
11761176
data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
11771177
if err != nil {
11781178
channelz.Error(logger, s.channelz, "grpc: server failed to encode response: ", err)
11791179
return err
11801180
}
11811181

1182-
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool)
1182+
compData, pf, err := compress(data, cp, comp, s.opts.bufferPool, compressorOptions...)
11831183
if err != nil {
11841184
data.Free()
11851185
channelz.Error(logger, s.channelz, "grpc: server failed to compress response: ", err)
@@ -1474,7 +1474,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt
14741474
if stream.SendCompress() != sendCompressorName {
14751475
comp = encoding.GetCompressor(stream.SendCompress())
14761476
}
1477-
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp); err != nil {
1477+
if err := s.sendResponse(ctx, stream, reply, cp, opts, comp, stream.SendCompressOptions()...); err != nil {
14781478
if err == io.EOF {
14791479
// The entire stream is done (for unary RPC only).
14801480
return err
@@ -2146,11 +2146,15 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
21462146
// It is not safe to call SetSendCompressor concurrently with SendHeader and
21472147
// SendMsg.
21482148
//
2149+
// The optional compressorOptions are forwarded to the compressor's Compress
2150+
// method on each SendMsg call, allowing callers to pass additional context
2151+
// such as dictionary IDs for trained compression formats.
2152+
//
21492153
// # Experimental
21502154
//
21512155
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
21522156
// later release.
2153-
func SetSendCompressor(ctx context.Context, name string) error {
2157+
func SetSendCompressor(ctx context.Context, name string, compressorOptions ...any) error {
21542158
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
21552159
if !ok || stream == nil {
21562160
return fmt.Errorf("failed to fetch the stream from the given context")
@@ -2160,7 +2164,7 @@ func SetSendCompressor(ctx context.Context, name string) error {
21602164
return fmt.Errorf("unable to set send compressor: %w", err)
21612165
}
21622166

2163-
return stream.SetSendCompress(name)
2167+
return stream.SetSendCompress(name, compressorOptions...)
21642168
}
21652169

21662170
// ClientSupportedCompressors returns compressor names advertised by the client

0 commit comments

Comments
 (0)