package s3manager

import (
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"sync"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/awsutil"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
)

// DefaultDownloadPartSize is the default range of bytes to get at a time when
// using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 5

// DefaultDownloadConcurrency is the default number of goroutines to spin up
// when using Download().
const DefaultDownloadConcurrency = 5

type errReadingBody struct {
	err error
}

func (e *errReadingBody) Error() string {
	return fmt.Sprintf("failed to read part body: %v", e.err)
}

func (e *errReadingBody) Unwrap() error {
	return e.err
}

// The Downloader structure that calls Download(). It is safe to call Download()
// on this structure for multiple objects and across concurrent goroutines.
// It is not safe to mutate the Downloader's properties concurrently.
type Downloader struct {
	// The size (in bytes) to request from S3 for each part.
	// The minimum allowed part size is 5MB, and if this value is set to zero,
	// the DefaultDownloadPartSize value will be used.
	//
	// PartSize is ignored if the Range input parameter is provided.
	PartSize int64

	// The number of goroutines to spin up in parallel when downloading parts.
	// If this is set to zero, the DefaultDownloadConcurrency value will be used.
	//
	// Concurrency of 1 will download the parts sequentially.
	//
	// Concurrency is ignored if the Range input parameter is provided.
	Concurrency int

	// An S3 client to use when performing downloads.
	S3 s3iface.S3API

	// List of request options that will be passed down to individual API
	// operation requests made by the downloader.
	RequestOptions []request.Option

	// Defines the buffer strategy used when downloading a part.
	//
	// If a WriterReadFromProvider is given, the Download manager
	// will pass the io.WriterAt of the Download request to the provider
	// and will use the returned WriterReadFrom from the provider as the
	// destination writer when copying from the HTTP response body.
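	//
	// A minimal sketch of configuring a pooled buffer strategy, assuming the
	// pooled provider available in this package and a session sess:
	//
	//     downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
	//         d.BufferProvider = s3manager.NewPooledBufferedWriterReadFromProvider(1024 * 1024)
	//     })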
	BufferProvider WriterReadFromProvider
}

// WithDownloaderRequestOptions appends to the Downloader's API request options.
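//
// Example (a sketch; assumes a session sess, and that request.WithLogLevel
// from the aws/request package is the desired option):
//     downloader := s3manager.NewDownloader(sess,
//         s3manager.WithDownloaderRequestOptions(
//             request.WithLogLevel(aws.LogDebugWithHTTPBody),
//         ),
//     )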
func WithDownloaderRequestOptions(opts ...request.Option) func(*Downloader) {
	return func(d *Downloader) {
		d.RequestOptions = append(d.RequestOptions, opts...)
	}
}

// NewDownloader creates a new Downloader instance to download objects from
// S3 in concurrent chunks. Pass in additional functional options to customize
// the downloader behavior. Requires a client.ConfigProvider in order to create
// an S3 service client. The session.Session satisfies the client.ConfigProvider
// interface.
//
// Example:
//     // The session the S3 Downloader will use
//     sess := session.Must(session.NewSession())
//
//     // Create a downloader with the session and default options
//     downloader := s3manager.NewDownloader(sess)
//
//     // Create a downloader with the session and custom options
//     downloader := s3manager.NewDownloader(sess, func(d *s3manager.Downloader) {
//          d.PartSize = 64 * 1024 * 1024 // 64MB per part
//     })
func NewDownloader(c client.ConfigProvider, options ...func(*Downloader)) *Downloader {
	return newDownloader(s3.New(c), options...)
}

func newDownloader(client s3iface.S3API, options ...func(*Downloader)) *Downloader {
	d := &Downloader{
		S3:             client,
		PartSize:       DefaultDownloadPartSize,
		Concurrency:    DefaultDownloadConcurrency,
		BufferProvider: defaultDownloadBufferProvider(),
	}
	for _, option := range options {
		option(d)
	}

	return d
}

// NewDownloaderWithClient creates a new Downloader instance to download
// objects from S3 in concurrent chunks. Pass in additional functional
// options to customize the downloader behavior. Requires an S3 service client
// to make S3 API calls.
//
// Example:
//     // The session the S3 Downloader will use
//     sess := session.Must(session.NewSession())
//
//     // The S3 client the S3 Downloader will use
//     s3Svc := s3.New(sess)
//
//     // Create a downloader with the s3 client and default options
//     downloader := s3manager.NewDownloaderWithClient(s3Svc)
//
//     // Create a downloader with the s3 client and custom options
//     downloader := s3manager.NewDownloaderWithClient(s3Svc, func(d *s3manager.Downloader) {
//          d.PartSize = 64 * 1024 * 1024 // 64MB per part
//     })
func NewDownloaderWithClient(svc s3iface.S3API, options ...func(*Downloader)) *Downloader {
	return newDownloader(svc, options...)
}

type maxRetrier interface {
	MaxRetries() int
}

// Download downloads an object from S3 and writes the payload into w using
// concurrent GET requests. The n int64 returned is the size of the object downloaded
// in bytes.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is called from.
// Modifying the options will not impact the original Downloader instance.
//
// It is safe to call this method concurrently across goroutines.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or an in-memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// If the GetObjectInput's Range value is provided, the downloader will perform
// a single GetObject request for that object's range. This will cause the part
// size and concurrency configurations to be ignored.
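//
// Example (a sketch; downloads an object into an in-memory buffer, assuming
// a downloader created with NewDownloader):
//     buf := aws.NewWriteAtBuffer([]byte{})
//     n, err := downloader.Download(buf, &s3.GetObjectInput{
//         Bucket: aws.String("bucket"),
//         Key:    aws.String("key"),
//     })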
func (d Downloader) Download(w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
	return d.DownloadWithContext(aws.BackgroundContext(), w, input, options...)
}

// DownloadWithContext downloads an object from S3 and writes the payload into w
// using concurrent GET requests. The n int64 returned is the size of the object downloaded
// in bytes.
//
// DownloadWithContext is the same as Download with the additional support for
// Context input parameters. The Context must not be nil. A nil Context will
// cause a panic. Use the Context to add deadlines, timeouts, etc. The
// DownloadWithContext may create sub-contexts for individual underlying
// requests.
//
// Additional functional options can be provided to configure the individual
// download. These options are copies of the Downloader instance Download is
// called from. Modifying the options will not impact the original Downloader
// instance. Use the WithDownloaderRequestOptions helper function to pass in request
// options that will be applied to all API operations made with this downloader.
//
// The w io.WriterAt can be satisfied by an os.File to do multipart concurrent
// downloads, or an in-memory []byte wrapper using aws.WriteAtBuffer.
//
// Specifying a Downloader.Concurrency of 1 will cause the Downloader to
// download the parts from S3 sequentially.
//
// It is safe to call this method concurrently across goroutines.
//
// If the GetObjectInput's Range value is provided, the downloader will perform
// a single GetObject request for that object's range. This will cause the part
// size and concurrency configurations to be ignored.
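//
// Example (a sketch; assumes a downloader, an open destination *os.File f,
// and the standard library context and time packages):
//     ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
//     defer cancel()
//     n, err := downloader.DownloadWithContext(ctx, f, &s3.GetObjectInput{
//         Bucket: aws.String("bucket"),
//         Key:    aws.String("key"),
//     })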
func (d Downloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*Downloader)) (n int64, err error) {
	impl := downloader{w: w, in: input, cfg: d, ctx: ctx}

	for _, option := range options {
		option(&impl.cfg)
	}
	impl.cfg.RequestOptions = append(impl.cfg.RequestOptions, request.WithAppendUserAgent("S3Manager"))

	if s, ok := d.S3.(maxRetrier); ok {
		impl.partBodyMaxRetries = s.MaxRetries()
	}

	impl.totalBytes = -1
	if impl.cfg.Concurrency == 0 {
		impl.cfg.Concurrency = DefaultDownloadConcurrency
	}

	if impl.cfg.PartSize == 0 {
		impl.cfg.PartSize = DefaultDownloadPartSize
	}

	return impl.download()
}

// DownloadWithIterator will download a batch of objects from S3 and write them
// to the io.WriterAt specified in the iterator.
//
// Example:
//	svc := s3manager.NewDownloader(session)
//
//	fooFile, err := os.Create("/tmp/foo.file")
//	if err != nil {
//		return err
//	}
//
//	barFile, err := os.Create("/tmp/bar.file")
//	if err != nil {
//		return err
//	}
//
//	objects := []s3manager.BatchDownloadObject {
//		{
//			Object: &s3.GetObjectInput {
//				Bucket: aws.String("bucket"),
//				Key: aws.String("foo"),
//			},
//			Writer: fooFile,
//		},
//		{
//			Object: &s3.GetObjectInput {
//				Bucket: aws.String("bucket"),
//				Key: aws.String("bar"),
//			},
//			Writer: barFile,
//		},
//	}
//
//	iter := &s3manager.DownloadObjectsIterator{Objects: objects}
//	if err := svc.DownloadWithIterator(aws.BackgroundContext(), iter); err != nil {
//		return err
//	}
func (d Downloader) DownloadWithIterator(ctx aws.Context, iter BatchDownloadIterator, opts ...func(*Downloader)) error {
	var errs []Error
	for iter.Next() {
		object := iter.DownloadObject()
		if _, err := d.DownloadWithContext(ctx, object.Writer, object.Object, opts...); err != nil {
			errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
		}

		if object.After == nil {
			continue
		}

		if err := object.After(); err != nil {
			errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
		}
	}

	if len(errs) > 0 {
		return NewBatchError("BatchedDownloadIncomplete", "some objects have failed to download.", errs)
	}
	return nil
}

// downloader is the implementation structure used internally by Downloader.
type downloader struct {
	ctx aws.Context
	cfg Downloader

	in *s3.GetObjectInput
	w  io.WriterAt

	wg sync.WaitGroup
	m  sync.Mutex

	pos        int64
	totalBytes int64
	written    int64
	err        error

	partBodyMaxRetries int
}

// download performs the object download across ranged GETs.
func (d *downloader) download() (n int64, err error) {
	// If a range is specified, fall back to a single download of that range.
	// This enables ranged gets with the downloader, at the cost of giving up
	// multipart downloads.
	if rng := aws.StringValue(d.in.Range); len(rng) > 0 {
		d.downloadRange(rng)
		return d.written, d.err
	}

	// Spin off the first worker to check additional header information.
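	// The first chunk's response also resolves the object's total size via
	// setTotalBytes (from Content-Range or Content-Length), which the
	// producer loop below relies on to know when all chunks are queued.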
	d.getChunk()

	if total := d.getTotalBytes(); total >= 0 {
		// Spin up workers
		ch := make(chan dlchunk, d.cfg.Concurrency)

		for i := 0; i < d.cfg.Concurrency; i++ {
			d.wg.Add(1)
			go d.downloadPart(ch)
		}

		// Assign work
		for d.getErr() == nil {
			if d.pos >= total {
				break // We're finished queuing chunks
			}

			// Queue the next range of bytes to read.
			ch <- dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
			d.pos += d.cfg.PartSize
		}

		// Wait for completion
		close(ch)
		d.wg.Wait()
	} else {
		// Keep downloading chunks until an error occurs.
		for d.err == nil {
			d.getChunk()
		}

		// We expect a 416 error letting us know we are done downloading the
		// total bytes. Since we do not know the content's length, this will
		// keep grabbing chunks of data until the range of bytes specified in
		// the request is out of range of the content. Once this happens, a
		// 416 should occur.
		e, ok := d.err.(awserr.RequestFailure)
		if ok && e.StatusCode() == http.StatusRequestedRangeNotSatisfiable {
			d.err = nil
		}
	}

	// Return error
	return d.written, d.err
}

// downloadPart is an individual goroutine worker reading from the ch channel
// and performing a GetObject request on the data with a given byte range.
//
// If this is the first worker, this operation also resolves the total number
// of bytes to be read so that the worker manager knows when it is finished.
func (d *downloader) downloadPart(ch chan dlchunk) {
	defer d.wg.Done()
	for {
		chunk, ok := <-ch
		if !ok {
			break
		}
		if d.getErr() != nil {
			// Drain the channel if there is an error, to prevent deadlocking
			// the download producer.
			continue
		}

		if err := d.downloadChunk(chunk); err != nil {
			d.setErr(err)
		}
	}
}

// getChunk grabs a chunk of data from the body.
// Not thread safe. Should only be used when grabbing data on a single thread.
func (d *downloader) getChunk() {
	if d.getErr() != nil {
		return
	}

	chunk := dlchunk{w: d.w, start: d.pos, size: d.cfg.PartSize}
	d.pos += d.cfg.PartSize

	if err := d.downloadChunk(chunk); err != nil {
		d.setErr(err)
	}
}

// downloadRange downloads an Object given the passed in Byte-Range value.
// The chunk used to download the range will be configured for that range.
func (d *downloader) downloadRange(rng string) {
	if d.getErr() != nil {
		return
	}

	chunk := dlchunk{w: d.w, start: d.pos}
	// Ranges specified will short circuit the multipart download
	chunk.withRange = rng

	if err := d.downloadChunk(chunk); err != nil {
		d.setErr(err)
	}

	// Update the position based on the amount of data received.
	d.pos = d.written
}

// downloadChunk downloads the chunk from S3.
func (d *downloader) downloadChunk(chunk dlchunk) error {
	in := &s3.GetObjectInput{}
	awsutil.Copy(in, d.in)

	// Get the next byte range of data
	in.Range = aws.String(chunk.ByteRange())

	var n int64
	var err error
	for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
		n, err = d.tryDownloadChunk(in, &chunk)
		if err == nil {
			break
		}
		// Check if the returned error is an errReadingBody.
		// If err is errReadingBody this indicates that an error
		// occurred while copying the http response body.
		// If this occurs we unwrap the err to set the underlying error
		// and attempt any remaining retries.
		if bodyErr, ok := err.(*errReadingBody); ok {
			err = bodyErr.Unwrap()
		} else {
			return err
		}

		chunk.cur = 0
		logMessage(d.cfg.S3, aws.LogDebugWithRequestRetries,
			fmt.Sprintf("DEBUG: object part body download interrupted %s, err, %v, retrying attempt %d",
				aws.StringValue(in.Key), err, retry))
	}

	d.incrWritten(n)

	return err
}

func (d *downloader) tryDownloadChunk(in *s3.GetObjectInput, w io.Writer) (int64, error) {
	cleanup := func() {}
	if d.cfg.BufferProvider != nil {
		w, cleanup = d.cfg.BufferProvider.GetReadFrom(w)
	}
	defer cleanup()

	resp, err := d.cfg.S3.GetObjectWithContext(d.ctx, in, d.cfg.RequestOptions...)
	if err != nil {
		return 0, err
	}
	d.setTotalBytes(resp) // Set total if not yet set.

	n, err := io.Copy(w, resp.Body)
	resp.Body.Close()
	if err != nil {
		return n, &errReadingBody{err: err}
	}

	return n, nil
}

func logMessage(svc s3iface.S3API, level aws.LogLevelType, msg string) {
	s, ok := svc.(*s3.S3)
	if !ok {
		return
	}

	if s.Config.Logger == nil {
		return
	}

	if s.Config.LogLevel.Matches(level) {
		s.Config.Logger.Log(msg)
	}
}

// getTotalBytes is a thread-safe getter for retrieving the total byte status.
func (d *downloader) getTotalBytes() int64 {
	d.m.Lock()
	defer d.m.Unlock()

	return d.totalBytes
}

// setTotalBytes is a thread-safe setter for setting the total byte status.
// The object's total size is extracted from the Content-Range header when the
// response covers a chunked range, or from Content-Length when the response
// does not include a Content-Range, meaning the full object fit within the
// PartSize and was not chunked.
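//
// For example, a Content-Range of "bytes 0-5242879/12582912" yields a total
// of 12582912 bytes, while "bytes 0-5242879/*" leaves the total undefined
// (-1) and chunks are fetched sequentially until a 416 response occurs.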
func (d *downloader) setTotalBytes(resp *s3.GetObjectOutput) {
	d.m.Lock()
	defer d.m.Unlock()

	if d.totalBytes >= 0 {
		return
	}

	if resp.ContentRange == nil {
		// ContentRange is nil when the full file contents are provided, and
		// the response is not chunked. Use ContentLength instead.
		if resp.ContentLength != nil {
			d.totalBytes = *resp.ContentLength
			return
		}
	} else {
		parts := strings.Split(*resp.ContentRange, "/")

		total := int64(-1)
		var err error
		// Check whether a numeric total exists. If one does not, assume the
		// total is -1 (undefined) and sequentially download each chunk until
		// hitting a 416 error.
		totalStr := parts[len(parts)-1]
		if totalStr != "*" {
			total, err = strconv.ParseInt(totalStr, 10, 64)
			if err != nil {
				d.err = err
				return
			}
		}

		d.totalBytes = total
	}
}

func (d *downloader) incrWritten(n int64) {
	d.m.Lock()
	defer d.m.Unlock()

	d.written += n
}

// getErr is a thread-safe getter for the error object
func (d *downloader) getErr() error {
	d.m.Lock()
	defer d.m.Unlock()

	return d.err
}

// setErr is a thread-safe setter for the error object
func (d *downloader) setErr(e error) {
	d.m.Lock()
	defer d.m.Unlock()

	d.err = e
}

// dlchunk represents a single chunk of data to write by the worker routine.
// This structure also implements an io.SectionReader style interface for
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
	w     io.WriterAt
	start int64
	size  int64
	cur   int64

	// specifies the byte range the chunk should be downloaded with.
	withRange string
}

// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
//
// If a range is specified on the dlchunk the size will be ignored when writing,
// as the total size may not be known ahead of time.
func (c *dlchunk) Write(p []byte) (n int, err error) {
	if c.cur >= c.size && len(c.withRange) == 0 {
		return 0, io.EOF
	}

	n, err = c.w.WriteAt(p, c.start+c.cur)
	c.cur += int64(n)

	return
}

// ByteRange returns an HTTP Byte-Range header value that should be used by the
// client to request the chunk's range.
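//
// For example, a chunk with start 0 and size 5242880 (the default part size)
// produces the header value "bytes=0-5242879".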
func (c *dlchunk) ByteRange() string {
	if len(c.withRange) != 0 {
		return c.withRange
	}

	return fmt.Sprintf("bytes=%d-%d", c.start, c.start+c.size-1)
}