Skip to content

Commit 1c8f186

Browse files
committed
Fix large reads
1 parent 76dfb62 commit 1c8f186

2 files changed

Lines changed: 115 additions & 12 deletions

File tree

disk_buffer_reader.go

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Maybe I need to set the read size? n, err := io.Reader.Read(something); n will not be more than 512 unless I set something (ran into this in archive)
2+
// Only happens when Stop() is run. Why??
3+
// Make sure read time is similar between standard reader and dbr.
14
package diskbufferreader
25

36
import (
@@ -68,20 +71,18 @@ func (dbr *DiskBufferReader) Read(out []byte) (int, error) {
6871
// Update the number of bytes read.
6972
dbr.bytesRead += int64(m)
7073

71-
// Go back to the beginning of the tmp file so reads start from the beginning.
72-
dbr.tmpFile.Seek(dbr.index, io.SeekStart)
7374
}
7475

7576
// Read from the multireader of the tmp file and the reader.
77+
if dbr.index <= dbr.bytesRead {
78+
dbr.tmpFile.Seek(dbr.index, io.SeekStart)
79+
}
7680
mr := io.MultiReader(dbr.tmpFile, dbr.reader)
7781
bytesRead := 0
7882
outBuffer := bytes.NewBuffer([]byte{})
79-
outMulti := make([]byte, len(out))
8083
var outErr error
81-
if dbr.index <= dbr.bytesRead {
82-
dbr.tmpFile.Seek(dbr.index, io.SeekStart)
83-
}
8484
for {
85+
outMulti := make([]byte, len(out)-bytesRead)
8586
n, err := mr.Read(outMulti)
8687
if err != nil {
8788
if !errors.Is(err, io.EOF) {
@@ -102,12 +103,8 @@ func (dbr *DiskBufferReader) Read(out []byte) (int, error) {
102103
break
103104
}
104105
}
105-
outStart := dbr.index
106-
if int64(bytesRead) < dbr.index {
107-
outStart = int64(bytesRead)
108-
}
109-
copy(out, outBuffer.Bytes()[outStart:])
110-
dbr.index = int64(bytesRead)
106+
copy(out, outBuffer.Bytes())
107+
dbr.index += int64(bytesRead)
111108
return bytesRead, outErr
112109
}
113110

disk_buffer_reader_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package diskbufferreader
33
import (
44
"bytes"
55
"errors"
6+
"io"
7+
"net/http"
68
"testing"
79
)
810

@@ -74,3 +76,107 @@ func TestDiskBufferReader(t *testing.T) {
7476
}
7577
}
7678
}
79+
80+
func TestReadAll(t *testing.T) {
81+
tests := map[string]struct {
82+
content string
83+
record bool
84+
reset bool
85+
}{
86+
"RecordOnNoReset": {
87+
"OneTwoThree",
88+
true,
89+
false,
90+
},
91+
"RecordOffNoReset": {
92+
"OneTwoThree",
93+
false,
94+
false,
95+
},
96+
"RecordOnReset": {
97+
"OneTwoThree",
98+
true,
99+
true,
100+
},
101+
"RecordOffReset": {
102+
"OneTwoThree",
103+
false,
104+
true,
105+
},
106+
}
107+
108+
for testName, testCase := range tests {
109+
110+
readBytes := []byte(testCase.content)
111+
bytesReader := bytes.NewBuffer(readBytes)
112+
tmpReader := bytes.NewBuffer(readBytes)
113+
dbr, err := New(tmpReader)
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
defer dbr.Close()
118+
119+
if testCase.reset {
120+
chunk := make([]byte, 3)
121+
dbr.Read(chunk)
122+
dbr.Reset()
123+
}
124+
125+
if !testCase.record {
126+
dbr.Stop()
127+
}
128+
129+
baseBytes, baseErr := io.ReadAll(bytesReader)
130+
testBytes, testErr := io.ReadAll(dbr)
131+
132+
if string(testBytes) != string(baseBytes) {
133+
t.Fatalf("%s: Unexpected read result. Got: %v, expected: %v", testName, testBytes, baseBytes)
134+
}
135+
136+
if !errors.Is(testErr, baseErr) {
137+
t.Fatalf("%s: Unexpected error. Got: %s, expected: %s", testName, testErr, baseErr)
138+
}
139+
}
140+
}
141+
142+
func TestReadAllLarge(t *testing.T) {
143+
resp, err := http.Get("https://raw.githubusercontent.com/bill-rich/bad-secrets/master/FifteenMB.gz")
144+
if err != nil {
145+
t.Fatal(err)
146+
}
147+
defer resp.Body.Close()
148+
149+
readBytes, err := io.ReadAll(resp.Body)
150+
if err != nil {
151+
t.Fatal(err)
152+
}
153+
154+
bytesReader := bytes.NewBuffer(readBytes)
155+
tmpReader := bytes.NewBuffer(readBytes)
156+
dbr, err := New(tmpReader)
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
defer dbr.Close()
161+
162+
chunk := make([]byte, 3)
163+
dbr.Read(chunk)
164+
dbr.Reset()
165+
166+
dbr.Stop()
167+
168+
baseBytes, baseErr := io.ReadAll(bytesReader)
169+
testBytes, testErr := io.ReadAll(dbr)
170+
171+
if len(testBytes) != len(baseBytes) {
172+
t.Fatalf("Wrong number of bytes read. Got: %d, expected: %d", len(testBytes), len(baseBytes))
173+
}
174+
175+
if string(testBytes) != string(baseBytes) {
176+
t.Fatalf("Unexpected read result. Got: %v..%v, expected: %v..%v", testBytes[:1024], testBytes[len(testBytes)-16:], baseBytes[:1024], baseBytes[len(baseBytes)-16:])
177+
}
178+
179+
if !errors.Is(testErr, baseErr) {
180+
t.Fatalf("Unexpected error. Got: %s, expected: %s", testErr, baseErr)
181+
}
182+
}

0 commit comments

Comments
 (0)