298 changes: 218 additions & 80 deletions pkg/sources/s3/s3.go
@@ -60,8 +60,9 @@ type Source struct {

// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.Validator = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.SourceUnitEnumChunker = (*Source)(nil)

// Type returns the type of source
func (s *Source) Type() sourcespb.SourceType { return SourceType }
@@ -439,44 +440,12 @@ func (s *Source) pageChunker(
return
}

// Skip GLACIER and GLACIER_IR objects.
if obj.StorageClass == s3types.ObjectStorageClassGlacier || obj.StorageClass == s3types.ObjectStorageClassGlacierIr {
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", obj.StorageClass)
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class", float64(*obj.Size))
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for glacier object")
}
continue
}

// Ignore large files.
if *obj.Size > s.maxObjectSize {
ctx.Logger().V(5).Info("Skipping large file", "max_object_size", s.maxObjectSize)
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit", float64(*obj.Size))
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for large file")
}
continue
}

// File empty file.
if *obj.Size == 0 {
ctx.Logger().V(5).Info("Skipping empty file")
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file", 0)
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for empty file")
}
continue
}

// Skip incompatible extensions.
if common.SkipFile(*obj.Key) {
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension", float64(*obj.Size))
skipObject, reason := s.shouldSkipObject(ctx, objIdx, obj, metadata)
if skipObject {
s.metricsCollector.RecordObjectSkipped(metadata.bucket, reason, float64(*obj.Size))
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for incompatible file")
ctx.Logger().Error(err, fmt.Sprintf("could not update progress for %s", reason))
}
continue
}

s.jobPool.Go(func() error {
@@ -485,12 +454,6 @@
return ctx.Err()
}

if strings.HasSuffix(*obj.Key, "/") {
ctx.Logger().V(5).Info("Skipping directory")
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "directory", float64(*obj.Size))
return nil
}

path := strings.Split(*obj.Key, "/")
prefix := strings.Join(path[:len(path)-1], "/")

@@ -508,10 +471,7 @@
objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
defer cancel()

res, err := metadata.client.GetObject(objCtx, &s3.GetObjectInput{
Bucket: &metadata.bucket,
Key: obj.Key,
})
res, err := s.getObject(objCtx, metadata.client, *obj.Key, metadata.bucket, *obj.Size)
if err != nil {
if strings.Contains(err.Error(), "AccessDenied") {
ctx.Logger().Error(err, "could not get S3 object; access denied")
@@ -520,13 +480,6 @@
ctx.Logger().Error(err, "could not get S3 object")
s.metricsCollector.RecordObjectError(metadata.bucket)
}
// According to the documentation for GetObjectWithContext,
// the response can be non-nil even if there was an error.
// It's uncertain if the body will be nil in such cases,
// but we'll close it if it's not.
if res != nil && res.Body != nil {
res.Body.Close()
}

nErr, ok := state.errorCount.Load(prefix)
if !ok {
@@ -546,32 +499,8 @@
}
defer res.Body.Close()

email := "Unknown"
if obj.Owner != nil {
email = *obj.Owner.DisplayName
}
modified := obj.LastModified.String()
chunkSkel := &sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_S3{
S3: &source_metadatapb.S3{
Bucket: metadata.bucket,
File: sanitizer.UTF8(*obj.Key),
Link: sanitizer.UTF8(makeS3Link(metadata.bucket, metadata.client.Options().Region, *obj.Key)),
Email: sanitizer.UTF8(email),
Timestamp: sanitizer.UTF8(modified),
},
},
},
Verify: s.verify,
}

if err := handlers.HandleFile(ctx, res.Body, chunkSkel, sources.ChanReporter{Ch: chunksChan}); err != nil {
ctx.Logger().Error(err, "error handling file")
err = s.handleFileChunk(ctx, obj, res, sources.ChanReporter{Ch: chunksChan}, metadata.bucket, metadata.client.Options().Region)
if err != nil {
s.metricsCollector.RecordObjectError(metadata.bucket)
return nil
}
@@ -679,3 +608,212 @@ func (s *Source) visitRoles(
func makeS3Link(bucket, region, key string) string {
return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, region, key)
}

func (s *Source) handleFileChunk(
ctx context.Context,
obj s3types.Object,
objRes *s3.GetObjectOutput,
reporter sources.ChunkReporter,
bucket,
region string,
) error {
email := "Unknown"
if obj.Owner != nil {
email = *obj.Owner.DisplayName
}
modified := obj.LastModified.String()
chunkSkel := &sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
SourceID: s.SourceID(),
JobID: s.JobID(),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_S3{
S3: &source_metadatapb.S3{
Bucket: bucket,
File: sanitizer.UTF8(*obj.Key),
Link: sanitizer.UTF8(makeS3Link(bucket, region, *obj.Key)),
Email: sanitizer.UTF8(email),
Timestamp: sanitizer.UTF8(modified),
},
},
},
Verify: s.verify,
}

if err := handlers.HandleFile(ctx, objRes.Body, chunkSkel, reporter); err != nil {
ctx.Logger().Error(err, "error handling file")
return err
}
return nil
}

// getObject retrieves an S3 object, ensuring the response body is closed if the request fails.
func (s *Source) getObject(ctx context.Context, client *s3.Client, key, bucket string, size int64) (*s3.GetObjectOutput, error) {
res, err := client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &bucket,
Key: &key,
})
if err != nil {
// According to the AWS SDK documentation for GetObject,
// the response can be non-nil even if there was an error.
// It's uncertain whether the body will be nil in such cases,
// but we close it if it's not.
if res != nil && res.Body != nil {
res.Body.Close()
}
return nil, err
}
return res, nil
}

// shouldSkipObject decides whether an object should be skipped while scanning or enumerating.
func (s *Source) shouldSkipObject(ctx context.Context, objIdx int, obj s3types.Object, metadata pageMetadata) (bool, string) {
// Skip GLACIER and GLACIER_IR objects.
if obj.StorageClass == s3types.ObjectStorageClassGlacier || obj.StorageClass == s3types.ObjectStorageClassGlacierIr {
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", obj.StorageClass)
return true, "storage_class"
}

// Ignore large files.
if *obj.Size > s.maxObjectSize {
ctx.Logger().V(5).Info("Skipping large file", "max_object_size", s.maxObjectSize)
return true, "size_limit"
}

// Skip empty files.
if *obj.Size == 0 {
ctx.Logger().V(5).Info("Skipping empty file")
return true, "empty_file"
}

// Skip incompatible extensions.
if common.SkipFile(*obj.Key) {
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
return true, "incompatible_extension"
}

// Skip directory
if strings.HasSuffix(*obj.Key, "/") {
ctx.Logger().V(5).Info("Skipping directory")
return true, "directory"
}

return false, ""
}
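// For illustration (not part of this diff): a package-internal test, e.g. in
// s3_test.go, could exercise the skip logic along these lines. A minimal
// sketch, assuming context.Background() from this repo's context package
// provides a usable logger, and passing a zero pageMetadata since
// shouldSkipObject never reads it:
func TestShouldSkipObject(t *testing.T) {
	s := &Source{maxObjectSize: 1024}
	ctx := context.Background()

	key := "creds.txt"
	tooBig, empty := int64(4096), int64(0)

	// Oversized objects are skipped with the "size_limit" reason.
	if skip, reason := s.shouldSkipObject(ctx, 0, s3types.Object{Key: &key, Size: &tooBig}, pageMetadata{}); !skip || reason != "size_limit" {
		t.Errorf("want skip with size_limit, got %v/%q", skip, reason)
	}

	// Empty objects are skipped with the "empty_file" reason.
	if skip, reason := s.shouldSkipObject(ctx, 0, s3types.Object{Key: &key, Size: &empty}, pageMetadata{}); !skip || reason != "empty_file" {
		t.Errorf("want skip with empty_file, got %v/%q", skip, reason)
	}
}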

type S3SourceUnit struct {
Object s3types.Object
Bucket string
Role string
}

func (s S3SourceUnit) SourceUnitID() (string, sources.SourceUnitKind) {
// The ID is the object key, and the kind is "s3_object".
return *s.Object.Key, "s3_object"
}

func (s S3SourceUnit) Display() string {
return fmt.Sprintf("%s:%s", s.Bucket, *s.Object.Key)
}

var _ sources.SourceUnit = S3SourceUnit{}
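// For example (illustrative values): a unit for key "logs/2024/app.log" in
// bucket "acme-data" has SourceUnitID ("logs/2024/app.log", "s3_object") and
// renders Display() as "acme-data:logs/2024/app.log".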

// Enumerate implements the SourceUnitEnumerator interface. It visits each
// configured role, scans that role's buckets, and reports each S3 object as a source unit.
func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
visitor := func(c context.Context, defaultRegionClient *s3.Client, roleArn string, buckets []string) error {
for _, bucket := range buckets {
if common.IsDone(ctx) {
return ctx.Err()
}

ctx.Logger().V(5).Info("Enumerating bucket")

regionalClient, err := s.getRegionalClientForBucket(ctx, defaultRegionClient, roleArn, bucket)
if err != nil {
ctx.Logger().V(5).Error(err, "could not get regional client for bucket")
continue
}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
paginator := s3.NewListObjectsV2Paginator(regionalClient, input)

pageNumber := 1
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
ctx.Logger().V(5).Error(err, "could not list objects in bucket")
break
}

metadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: output,
}

for objIdx, obj := range output.Contents {

skipObject, _ := s.shouldSkipObject(ctx, objIdx, obj, metadata)
if skipObject {
continue
}

unit := S3SourceUnit{
Object: obj,
Bucket: bucket,
Role: roleArn,
}
if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
}

pageNumber++
}
}
return nil
}

return s.visitRoles(ctx, visitor)
}
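// For illustration (not part of this diff): a minimal UnitReporter that
// buffers enumerated units could look like the sketch below. The
// UnitOk/UnitErr method set is inferred from how the reporter is used above,
// so treat the exact signatures as an assumption.
type unitCollector struct {
	units []sources.SourceUnit
	errs  []error
}

func (c *unitCollector) UnitOk(_ context.Context, u sources.SourceUnit) error {
	c.units = append(c.units, u)
	return nil
}

func (c *unitCollector) UnitErr(_ context.Context, err error) error {
	c.errs = append(c.errs, err)
	return nil
}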

// ChunkUnit implements the SourceUnitChunker interface. It downloads the S3
// object identified by the unit and reports its chunks.
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
s3unit, ok := unit.(S3SourceUnit)
if !ok {
return fmt.Errorf("expected *S3SourceUnit, got %T", unit)
}
objectKey, _ := unit.SourceUnitID()
bucket := s3unit.Bucket
logger := ctx.Logger().WithValues("bucket", bucket, "key", objectKey)

defaultClient, err := s.newClient(ctx, defaultAWSRegion, s3unit.Role)
if err != nil {
return fmt.Errorf("could not create s3 client: %w", err)
}

client, err := s.getRegionalClientForBucket(ctx, defaultClient, s3unit.Role, bucket)
if err != nil {
return reporter.ChunkErr(ctx, fmt.Errorf("unable to get regional client for bucket: %w", err))
}

// Make sure we use a separate context for the GetObject call. This ensures
// that the timeout is isolated and does not affect downstream operations (e.g. HandleFile).
const getObjectTimeout = 30 * time.Second
objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
defer cancel()

res, err := s.getObject(objCtx, client, objectKey, bucket, *s3unit.Object.Size)
if err != nil {
return reporter.ChunkErr(ctx, fmt.Errorf("unable to get object: %w", err))
}
defer res.Body.Close()

logger.V(3).Info("chunking s3 unit")

return s.handleFileChunk(ctx, s3unit.Object, res, reporter, bucket, client.Options().Region)
}
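// For illustration (not part of this diff): enumeration and chunking can be
// composed end to end. A hedged sketch using the hypothetical unitCollector
// above and any sources.ChunkReporter implementation:
func scanAllUnits(ctx context.Context, src *Source, reporter sources.ChunkReporter) error {
	var col unitCollector
	if err := src.Enumerate(ctx, &col); err != nil {
		return err
	}
	for _, u := range col.units {
		if err := src.ChunkUnit(ctx, u, reporter); err != nil {
			return err
		}
	}
	return nil
}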