@@ -22,6 +22,7 @@ import (
2222 "bytes"
2323 "context"
2424 "io"
25+ "sync"
2526 "sync/atomic"
2627 "testing"
2728
@@ -38,16 +39,21 @@ import (
3839 _ "google.golang.org/grpc/encoding/gzip"
3940)
4041
41- // wrapCompressor is a wrapper of encoding.Compressor which maintains count of
42- // Compressor method invokes .
42+ // wrapCompressor is a wrapper of encoding.Compressor which records invocation
43+ // count and the options passed to each Compress call .
4344type wrapCompressor struct {
4445 encoding.Compressor
4546 compressInvokes int32
47+ mu sync.Mutex
48+ receivedOpts [][]any
4649}
4750
48- func (wc * wrapCompressor ) Compress (w io.Writer ) (io.WriteCloser , error ) {
51+ func (wc * wrapCompressor ) Compress (w io.Writer , opts ... any ) (io.WriteCloser , error ) {
4952 atomic .AddInt32 (& wc .compressInvokes , 1 )
50- return wc .Compressor .Compress (w )
53+ wc .mu .Lock ()
54+ wc .receivedOpts = append (wc .receivedOpts , opts )
55+ wc .mu .Unlock ()
56+ return wc .Compressor .Compress (w , opts ... )
5157}
5258
5359func setupGzipWrapCompressor (t * testing.T ) * wrapCompressor {
@@ -186,7 +192,7 @@ type fakeCompressor struct {
186192 decompressedMessageSize int
187193}
188194
189- func (f * fakeCompressor ) Compress (w io.Writer ) (io.WriteCloser , error ) {
195+ func (f * fakeCompressor ) Compress (w io.Writer , _ ... any ) (io.WriteCloser , error ) {
190196 return nopWriteCloser {w }, nil
191197}
192198
@@ -237,3 +243,168 @@ func (s) TestDecompressionExceedsMaxMessageSize(t *testing.T) {
237243 t .Errorf ("Client.UnaryCall(%+v) returned status %v, want %v" , req , got , want )
238244 }
239245}
246+
247+ // TestSetSendCompressorOptionsPropagate verifies that options passed to
248+ // SetSendCompressor are forwarded to the compressor's Compress method.
249+ func (s ) TestSetSendCompressorOptionsPropagate (t * testing.T ) {
250+ wantOpt := "dict-id-42"
251+ for _ , tc := range []struct {
252+ name string
253+ run func (* testing.T , * wrapCompressor )
254+ }{
255+ {"unary" , testUnarySendCompressorOptionsPropagate },
256+ {"stream" , testStreamSendCompressorOptionsPropagate },
257+ } {
258+ t .Run (tc .name , func (t * testing.T ) {
259+ wc := setupGzipWrapCompressor (t )
260+ tc .run (t , wc )
261+ wc .mu .Lock ()
262+ defer wc .mu .Unlock ()
263+ if len (wc .receivedOpts ) == 0 {
264+ t .Fatal ("Compress was not called" )
265+ }
266+ if got := wc .receivedOpts [0 ]; len (got ) == 0 || got [0 ] != wantOpt {
267+ t .Fatalf ("Compress received opts %v, want [%q]" , got , wantOpt )
268+ }
269+ })
270+ }
271+ }
272+
273+ func testUnarySendCompressorOptionsPropagate (t * testing.T , _ * wrapCompressor ) {
274+ t .Helper ()
275+ ss := & stubserver.StubServer {
276+ UnaryCallF : func (ctx context.Context , _ * testpb.SimpleRequest ) (* testpb.SimpleResponse , error ) {
277+ if err := grpc .SetSendCompressor (ctx , "gzip" , "dict-id-42" ); err != nil {
278+ return nil , err
279+ }
280+ return & testpb.SimpleResponse {Payload : & testpb.Payload {Body : []byte ("payload" )}}, nil
281+ },
282+ }
283+ if err := ss .Start (nil ); err != nil {
284+ t .Fatalf ("Error starting endpoint server: %v" , err )
285+ }
286+ defer ss .Stop ()
287+
288+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
289+ defer cancel ()
290+
291+ if _ , err := ss .Client .UnaryCall (ctx , & testpb.SimpleRequest {}); err != nil {
292+ t .Fatalf ("Unexpected unary call error: %v" , err )
293+ }
294+ }
295+
296+ func testStreamSendCompressorOptionsPropagate (t * testing.T , _ * wrapCompressor ) {
297+ t .Helper ()
298+ ss := & stubserver.StubServer {
299+ FullDuplexCallF : func (stream testgrpc.TestService_FullDuplexCallServer ) error {
300+ if _ , err := stream .Recv (); err != nil {
301+ return err
302+ }
303+ if err := grpc .SetSendCompressor (stream .Context (), "gzip" , "dict-id-42" ); err != nil {
304+ return err
305+ }
306+ return stream .Send (& testpb.StreamingOutputCallResponse {
307+ Payload : & testpb.Payload {Body : []byte ("payload" )},
308+ })
309+ },
310+ }
311+ if err := ss .Start (nil ); err != nil {
312+ t .Fatalf ("Error starting endpoint server: %v" , err )
313+ }
314+ defer ss .Stop ()
315+
316+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
317+ defer cancel ()
318+
319+ s , err := ss .Client .FullDuplexCall (ctx )
320+ if err != nil {
321+ t .Fatalf ("Unexpected full duplex call error: %v" , err )
322+ }
323+ if err := s .Send (& testpb.StreamingOutputCallRequest {}); err != nil {
324+ t .Fatalf ("Unexpected send error: %v" , err )
325+ }
326+ if _ , err := s .Recv (); err != nil {
327+ t .Fatalf ("Unexpected recv error: %v" , err )
328+ }
329+ }
330+
331+ // TestUseCompressorOptionsPropagate verifies that options passed to
332+ // UseCompressor are forwarded to the compressor's Compress method.
333+ func (s ) TestUseCompressorOptionsPropagate (t * testing.T ) {
334+ wantOpt := "dict-id-42"
335+ for _ , tc := range []struct {
336+ name string
337+ run func (* testing.T , * wrapCompressor )
338+ }{
339+ {"unary" , testUnaryUseCompressorOptionsPropagate },
340+ {"stream" , testStreamUseCompressorOptionsPropagate },
341+ } {
342+ t .Run (tc .name , func (t * testing.T ) {
343+ wc := setupGzipWrapCompressor (t )
344+ tc .run (t , wc )
345+ wc .mu .Lock ()
346+ defer wc .mu .Unlock ()
347+ if len (wc .receivedOpts ) == 0 {
348+ t .Fatal ("Compress was not called" )
349+ }
350+ if got := wc .receivedOpts [0 ]; len (got ) == 0 || got [0 ] != wantOpt {
351+ t .Fatalf ("Compress received opts %v, want [%q]" , got , wantOpt )
352+ }
353+ })
354+ }
355+ }
356+
357+ func testUnaryUseCompressorOptionsPropagate (t * testing.T , _ * wrapCompressor ) {
358+ t .Helper ()
359+ ss := & stubserver.StubServer {
360+ UnaryCallF : func (_ context.Context , _ * testpb.SimpleRequest ) (* testpb.SimpleResponse , error ) {
361+ return & testpb.SimpleResponse {Payload : & testpb.Payload {Body : []byte ("payload" )}}, nil
362+ },
363+ }
364+ if err := ss .Start (nil ); err != nil {
365+ t .Fatalf ("Error starting endpoint server: %v" , err )
366+ }
367+ defer ss .Stop ()
368+
369+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
370+ defer cancel ()
371+
372+ if _ , err := ss .Client .UnaryCall (ctx , & testpb.SimpleRequest {Payload : & testpb.Payload {Body : []byte ("data" )}}, grpc .UseCompressor ("gzip" , "dict-id-42" )); err != nil {
373+ t .Fatalf ("Unexpected unary call error: %v" , err )
374+ }
375+ }
376+
377+ func testStreamUseCompressorOptionsPropagate (t * testing.T , _ * wrapCompressor ) {
378+ t .Helper ()
379+ ss := & stubserver.StubServer {
380+ FullDuplexCallF : func (stream testgrpc.TestService_FullDuplexCallServer ) error {
381+ req , err := stream .Recv ()
382+ if err != nil {
383+ return err
384+ }
385+ return stream .Send (& testpb.StreamingOutputCallResponse {
386+ Payload : req .GetPayload (),
387+ })
388+ },
389+ }
390+ if err := ss .Start (nil ); err != nil {
391+ t .Fatalf ("Error starting endpoint server: %v" , err )
392+ }
393+ defer ss .Stop ()
394+
395+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
396+ defer cancel ()
397+
398+ s , err := ss .Client .FullDuplexCall (ctx , grpc .UseCompressor ("gzip" , "dict-id-42" ))
399+ if err != nil {
400+ t .Fatalf ("Unexpected full duplex call error: %v" , err )
401+ }
402+ if err := s .Send (& testpb.StreamingOutputCallRequest {
403+ Payload : & testpb.Payload {Body : []byte ("payload" )},
404+ }); err != nil {
405+ t .Fatalf ("Unexpected send error: %v" , err )
406+ }
407+ if _ , err := s .Recv (); err != nil {
408+ t .Fatalf ("Unexpected recv error: %v" , err )
409+ }
410+ }
0 commit comments