Skip to content

Commit 36a0335

Browse files
committed
encoding: add support for compressor options in Compress methods
1 parent 2b8b708 commit 36a0335

8 files changed

Lines changed: 50 additions & 21 deletions

File tree

encoding/compressor_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ type wrapCompressor struct {
4545
compressInvokes int32
4646
}
4747

48-
func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
48+
func (wc *wrapCompressor) Compress(w io.Writer, opts ...any) (io.WriteCloser, error) {
4949
atomic.AddInt32(&wc.compressInvokes, 1)
50-
return wc.Compressor.Compress(w)
50+
return wc.Compressor.Compress(w, opts...)
5151
}
5252

5353
func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
@@ -186,7 +186,7 @@ type fakeCompressor struct {
186186
decompressedMessageSize int
187187
}
188188

189-
func (f *fakeCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
189+
func (f *fakeCompressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
190190
return nopWriteCloser{w}, nil
191191
}
192192

encoding/encoding.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ func init() {
6161
type Compressor interface {
6262
// Compress writes the data written to wc to w after compressing it. If an
6363
// error occurs while initializing the compressor, that error is returned
64-
// instead.
65-
Compress(w io.Writer) (io.WriteCloser, error)
64+
// instead. opts passes caller-provided context to the compressor (e.g.
65+
// dictionary IDs for trained compression formats). Unknown options must
66+
// be silently ignored.
67+
Compress(w io.Writer, opts ...any) (io.WriteCloser, error)
6668
// Decompress reads data from r, decompresses it, and provides the
6769
// uncompressed data via the returned io.Reader. If an error occurs while
6870
// initializing the decompressor, that error is returned instead.

encoding/gzip/gzip.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func SetLevel(level int) error {
7070
return nil
7171
}
7272

73-
func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) {
73+
func (c *compressor) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
7474
z := c.poolCompressor.Get().(*writer)
7575
z.Writer.Reset(w)
7676
return z, nil

internal/transport/server_stream.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ type ServerStream struct {
5050
headerSent atomic.Bool // atomically set when the headers are sent out.
5151

5252
headerWireLength int
53+
54+
sendCompressOptions []any
5355
}
5456

5557
// Read reads an n byte message from the input stream.
@@ -108,6 +110,17 @@ func (s *ServerStream) ContentSubtype() string {
108110
return s.contentSubtype
109111
}
110112

113+
// SetSendCompressOptions sets the compressor options to be passed to the
114+
// compressor's Compress method.
115+
func (s *ServerStream) SetSendCompressOptions(opts []any) {
116+
s.sendCompressOptions = opts
117+
}
118+
119+
// SendCompressOptions returns the compressor options set for the stream.
120+
func (s *ServerStream) SendCompressOptions() []any {
121+
return s.sendCompressOptions
122+
}
123+
111124
// SetSendCompress sets the compression algorithm to the stream.
112125
func (s *ServerStream) SetSendCompress(name string) error {
113126
if s.isHeaderSent() || s.getState() == streamDone {

rpc_util.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ type callInfo struct {
168168
onFinish []func(err error)
169169
authority string
170170
acceptedResponseCompressors []string
171+
compressorOptions []any
171172
}
172173

173174
func acceptedCompressorAllows(allowed []string, name string) bool {
@@ -490,14 +491,16 @@ func (o PerRPCCredsCallOption) after(*callInfo, *csAttempt) {}
490491

491492
// UseCompressor returns a CallOption which sets the compressor used when
492493
// sending the request. If WithCompressor is also set, UseCompressor has
493-
// higher priority.
494+
// higher priority. The optional compressorOptions are forwarded to the
495+
// compressor's Compress method, allowing callers to pass additional context
496+
// such as dictionary IDs for trained compression formats.
494497
//
495498
// # Experimental
496499
//
497500
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
498501
// later release.
499-
func UseCompressor(name string) CallOption {
500-
return CompressorCallOption{CompressorType: name}
502+
func UseCompressor(name string, compressorOptions ...any) CallOption {
503+
return CompressorCallOption{CompressorType: name, CompressorOptions: compressorOptions}
501504
}
502505

503506
// CompressorCallOption is a CallOption that indicates the compressor to use.
@@ -507,11 +510,13 @@ func UseCompressor(name string) CallOption {
507510
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
508511
// later release.
509512
type CompressorCallOption struct {
510-
CompressorType string
513+
CompressorType string
514+
CompressorOptions []any
511515
}
512516

513517
func (o CompressorCallOption) before(c *callInfo) error {
514518
c.compressorName = o.CompressorType
519+
c.compressorOptions = o.CompressorOptions
515520
return nil
516521
}
517522
func (o CompressorCallOption) after(*callInfo, *csAttempt) {}
@@ -817,7 +822,7 @@ func encode(c baseCodec, msg any) (mem.BufferSlice, error) {
817822
// indicating no compression was done.
818823
//
819824
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
820-
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool) (mem.BufferSlice, payloadFormat, error) {
825+
func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (mem.BufferSlice, payloadFormat, error) {
821826
if (compressor == nil && cp == nil) || in.Len() == 0 {
822827
return nil, compressionNone, nil
823828
}
@@ -828,7 +833,7 @@ func compress(in mem.BufferSlice, cp Compressor, compressor encoding.Compressor,
828833
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
829834
}
830835
if compressor != nil {
831-
z, err := compressor.Compress(w)
836+
z, err := compressor.Compress(w, compressorOptions...)
832837
if err != nil {
833838
return nil, 0, wrapErr(err)
834839
}

rpc_util_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type testCompressorForRegistry struct {
5252
name string
5353
}
5454

55-
func (c *testCompressorForRegistry) Compress(w io.Writer) (io.WriteCloser, error) {
55+
func (c *testCompressorForRegistry) Compress(w io.Writer, _ ...any) (io.WriteCloser, error) {
5656
return &testWriteCloser{w}, nil
5757
}
5858

@@ -541,7 +541,7 @@ type mockCompressor struct {
541541
ch chan<- struct{}
542542
}
543543

544-
func (m *mockCompressor) Compress(io.Writer) (io.WriteCloser, error) {
544+
func (m *mockCompressor) Compress(io.Writer, ...any) (io.WriteCloser, error) {
545545
panic("unimplemented")
546546
}
547547

server.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,11 +2146,15 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
21462146
// It is not safe to call SetSendCompressor concurrently with SendHeader and
21472147
// SendMsg.
21482148
//
2149+
// The optional compressorOptions are forwarded to the compressor's Compress
2150+
// method on each SendMsg call, allowing callers to pass additional context
2151+
// such as dictionary IDs for trained compression formats.
2152+
//
21492153
// # Experimental
21502154
//
21512155
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
21522156
// later release.
2153-
func SetSendCompressor(ctx context.Context, name string) error {
2157+
func SetSendCompressor(ctx context.Context, name string, compressorOptions ...any) error {
21542158
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.ServerStream)
21552159
if !ok || stream == nil {
21562160
return fmt.Errorf("failed to fetch the stream from the given context")
@@ -2160,6 +2164,7 @@ func SetSendCompressor(ctx context.Context, name string) error {
21602164
return fmt.Errorf("unable to set send compressor: %w", err)
21612165
}
21622166

2167+
stream.SetSendCompressOptions(compressorOptions)
21632168
return stream.SetSendCompress(name)
21642169
}
21652170

stream.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,8 @@ type clientStream struct {
633633
// nameResolutionDelay indicates if there was a delay in the name resolution.
634634
// This field is only valid on client side, it's always false on server side.
635635
nameResolutionDelay bool
636+
637+
compressorOptions []any
636638
}
637639

638640
type replayOp struct {
@@ -964,7 +966,7 @@ func (cs *clientStream) SendMsg(m any) (err error) {
964966
}
965967

966968
// load hdr, payload, data
967-
hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool)
969+
hdr, data, payload, pf, err := prepareMsg(m, cs.codec, cs.compressorV0, cs.compressorV1, cs.cc.dopts.copts.BufferPool, cs.compressorOptions...)
968970
if err != nil {
969971
return err
970972
}
@@ -1471,7 +1473,7 @@ func (as *addrConnStream) SendMsg(m any) (err error) {
14711473
}
14721474

14731475
// load hdr, payload, data
1474-
hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool)
1476+
hdr, data, payload, pf, err := prepareMsg(m, as.codec, as.sendCompressorV0, as.sendCompressorV1, as.ac.dopts.copts.BufferPool, as.callInfo.compressorOptions...)
14751477
if err != nil {
14761478
return err
14771479
}
@@ -1669,7 +1671,8 @@ type serverStream struct {
16691671
// synchronized.
16701672
serverHeaderBinlogged bool
16711673

1672-
mu sync.Mutex // protects trInfo.tr after the service handler runs.
1674+
mu sync.Mutex // protects trInfo.tr after the service handler runs.
1675+
sendCompressorOptions []any
16731676
}
16741677

16751678
func (ss *serverStream) Context() context.Context {
@@ -1748,10 +1751,11 @@ func (ss *serverStream) SendMsg(m any) (err error) {
17481751
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
17491752
ss.compressorV1 = encoding.GetCompressor(sendCompressorsName)
17501753
ss.sendCompressorName = sendCompressorsName
1754+
ss.sendCompressorOptions = ss.s.SendCompressOptions()
17511755
}
17521756

17531757
// load hdr, payload, data
1754-
hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool)
1758+
hdr, data, payload, pf, err := prepareMsg(m, ss.codec, ss.compressorV0, ss.compressorV1, ss.p.bufferPool, ss.sendCompressorOptions...)
17551759
if err != nil {
17561760
return err
17571761
}
@@ -1893,7 +1897,7 @@ func MethodFromServerStream(stream ServerStream) (string, bool) {
18931897
// compression was made and therefore whether the payload needs to be freed in
18941898
// addition to the returned data. Freeing the payload if the returned boolean is
18951899
// false can lead to undefined behavior.
1896-
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) {
1900+
func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor, pool mem.BufferPool, compressorOptions ...any) (hdr []byte, data, payload mem.BufferSlice, pf payloadFormat, err error) {
18971901
if preparedMsg, ok := m.(*PreparedMsg); ok {
18981902
return preparedMsg.hdr, preparedMsg.encodedData, preparedMsg.payload, preparedMsg.pf, nil
18991903
}
@@ -1903,7 +1907,7 @@ func prepareMsg(m any, codec baseCodec, cp Compressor, comp encoding.Compressor,
19031907
if err != nil {
19041908
return nil, nil, nil, 0, err
19051909
}
1906-
compData, pf, err := compress(data, cp, comp, pool)
1910+
compData, pf, err := compress(data, cp, comp, pool, compressorOptions...)
19071911
if err != nil {
19081912
data.Free()
19091913
return nil, nil, nil, 0, err

0 commit comments

Comments
 (0)