Skip to content

Commit 5a76e4d

Browse files
committed
stuff
1 parent 18c0fef commit 5a76e4d

File tree

3 files changed

+68
-16
lines changed

3 files changed

+68
-16
lines changed

datacatalog/pkg/repositories/errors/postgres.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ type postgresErrorTransformer struct {
2424
}
2525

2626
const (
27-
unexpectedType = "unexpected error type for: %v"
28-
uniqueConstraintViolation = "value with matching already exists (%s)"
29-
defaultPgError = "failed database operation with code [%s] and msg [%s]"
30-
unsupportedTableOperation = "cannot query with specified table attributes: %s"
27+
unexpectedType = "unexpected error type for: %v"
28+
uUniqueConstraintViolation = "value with matching already exists (%s)"
29+
defaultPgError = "failed database operation with code [%s] and msg [%s]"
30+
unsupportedTableOperation = "cannot query with specified table attributes: %s"
3131
)
3232

3333
func (p *postgresErrorTransformer) fromGormError(err error) error {
@@ -40,9 +40,14 @@ func (p *postgresErrorTransformer) fromGormError(err error) error {
4040
}
4141

4242
func (p *postgresErrorTransformer) ToDataCatalogError(err error) error {
43+
var dce catalogErrors.DataCatalogError
44+
if errors.As(err, &dce) {
45+
return dce // already a data catalog error
46+
}
47+
4348
// First try the stdlib error handling
4449
if database.IsPgErrorWithCode(err, uniqueConstraintViolationCode) {
45-
return catalogErrors.NewDataCatalogErrorf(codes.AlreadyExists, uniqueConstraintViolation, err.Error())
50+
return catalogErrors.NewDataCatalogErrorf(codes.AlreadyExists, uUniqueConstraintViolation, err.Error())
4651
}
4752

4853
if unwrappedErr := errors.Unwrap(err); unwrappedErr != nil {
@@ -58,7 +63,7 @@ func (p *postgresErrorTransformer) ToDataCatalogError(err error) error {
5863

5964
switch pqError.Code {
6065
case uniqueConstraintViolationCode:
61-
return catalogErrors.NewDataCatalogErrorf(codes.AlreadyExists, uniqueConstraintViolation, pqError.Message)
66+
return catalogErrors.NewDataCatalogErrorf(codes.AlreadyExists, uUniqueConstraintViolation, pqError.Message)
6267
case undefinedTable:
6368
return catalogErrors.NewDataCatalogErrorf(codes.InvalidArgument, unsupportedTableOperation, pqError.Message)
6469
default:

datacatalog/pkg/repositories/gormimpl/artifact.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ package gormimpl
33
import (
44
"context"
55

6+
7+
"google.golang.org/grpc/codes"
68
"gorm.io/gorm"
79
"gorm.io/gorm/clause"
810
"k8s.io/utils/clock"
911

1012
"github.com/flyteorg/flyte/datacatalog/pkg/common"
13+
catalogErrors "github.com/flyteorg/flyte/datacatalog/pkg/errors"
1114
"github.com/flyteorg/flyte/datacatalog/pkg/repositories/errors"
1215
"github.com/flyteorg/flyte/datacatalog/pkg/repositories/interfaces"
1316
"github.com/flyteorg/flyte/datacatalog/pkg/repositories/models"
@@ -41,18 +44,28 @@ func (h *artifactRepo) Create(ctx context.Context, artifact models.Artifact) err
4144
timer := h.repoMetrics.CreateDuration.Start(ctx)
4245
defer timer.Stop()
4346

44-
tx := h.db.WithContext(ctx).Begin()
47+
err := h.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
48+
tx = tx.
49+
Where("artifacts.expires_at is null or artifacts.expires_at < ?", h.clock.Now().UTC()).
50+
Order("artifacts.created_at DESC"). // Always pick the most recent
51+
Find(
52+
&models.Artifact{ArtifactKey: artifact.ArtifactKey},
53+
)
4554

46-
tx = tx.Create(&artifact)
55+
if tx.Error != nil {
56+
return tx.Error
57+
}
4758

48-
if tx.Error != nil {
49-
tx.Rollback()
50-
return h.errorTransformer.ToDataCatalogError(tx.Error)
51-
}
59+
if tx.RowsAffected > 0 {
60+
return catalogErrors.NewDataCatalogErrorf(codes.AlreadyExists, "artifact already exists")
61+
}
5262

53-
tx = tx.Commit()
54-
if tx.Error != nil {
55-
return h.errorTransformer.ToDataCatalogError(tx.Error)
63+
tx = tx.Create(&artifact)
64+
return tx.Error
65+
})
66+
67+
if err != nil {
68+
return h.errorTransformer.ToDataCatalogError(err)
5669
}
5770

5871
return nil

datacatalog/pkg/repositories/gormimpl/artifact_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,23 @@ func getDBTagResponse(artifact models.Artifact) []map[string]interface{} {
124124
}
125125

126126
func TestCreateArtifact(t *testing.T) {
127+
testClock := testclock.NewFakeClock(time.Unix(0, 0))
127128
artifact := getTestArtifact()
128129

130+
existingChecked := false
129131
artifactCreated := false
130132
GlobalMock := mocket.Catcher.Reset()
131133
GlobalMock.Logging = true
132134

133135
numArtifactDataCreated := 0
134136
numPartitionsCreated := 0
135137

138+
GlobalMock.NewMock().WithQuery(
139+
`SELECT * FROM "artifacts" WHERE (artifacts.expires_at is null or artifacts.expires_at < $1) AND "artifacts"."dataset_project" = $2 AND "artifacts"."dataset_name" = $3 AND "artifacts"."dataset_domain" = $4 AND "artifacts"."dataset_version" = $5 AND "artifacts"."artifact_id" = $6 ORDER BY artifacts.created_at DESC%!!(string=123)!(string=testVersion)!(string=testDomain)!(string=testName)!(string=testProject)(EXTRA time.Time=1970-01-01 00:00:00 +0000 UTC)`).WithCallback(
140+
func(s string, values []driver.NamedValue) {
141+
existingChecked = true
142+
})
143+
136144
// Only match on queries that append expected filters
137145
GlobalMock.NewMock().WithQuery(
138146
`INSERT INTO "artifacts" ("created_at","updated_at","deleted_at","dataset_project","dataset_name","dataset_domain","dataset_version","artifact_id","dataset_uuid","serialized_metadata","expires_at") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)`).WithCallback(
@@ -173,12 +181,38 @@ func TestCreateArtifact(t *testing.T) {
173181

174182
artifact.Partitions = partitions
175183

176-
artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope(), clock.RealClock{})
184+
artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope(), testClock)
177185
err := artifactRepo.Create(context.Background(), artifact)
178186
assert.NoError(t, err)
179187
assert.True(t, artifactCreated)
180188
assert.Equal(t, 2, numArtifactDataCreated)
181189
assert.Equal(t, 1, numPartitionsCreated)
190+
assert.True(t, existingChecked)
191+
}
192+
193+
func TestCreateArtifactAlreadyExists(t *testing.T) {
194+
testClock := testclock.NewFakeClock(time.Unix(0, 0))
195+
196+
artifact := getTestArtifactWithExpiration(testClock.Now().Add(time.Second))
197+
expectedArtifactResponse := getDBArtifactResponse(artifact)
198+
199+
existingChecked := false
200+
GlobalMock := mocket.Catcher.Reset()
201+
GlobalMock.Logging = true
202+
203+
GlobalMock.NewMock().WithQuery(
204+
`SELECT * FROM "artifacts" WHERE (artifacts.expires_at is null or artifacts.expires_at < $1) AND "artifacts"."dataset_project" = $2 AND "artifacts"."dataset_name" = $3 AND "artifacts"."dataset_domain" = $4 AND "artifacts"."dataset_version" = $5 AND "artifacts"."artifact_id" = $6 ORDER BY artifacts.created_at DESC%!!(string=123)!(string=testVersion)!(string=testDomain)!(string=testName)!(string=testProject)(EXTRA time.Time=1970-01-01 00:00:00 +0000 UTC)`).WithCallback(
205+
func(s string, values []driver.NamedValue) {
206+
existingChecked = true
207+
}).WithReply(expectedArtifactResponse)
208+
209+
artifactRepo := NewArtifactRepo(utils.GetDbForTest(t), errors.NewPostgresErrorTransformer(), promutils.NewTestScope(), testClock)
210+
err := artifactRepo.Create(context.Background(), artifact)
211+
assert.Error(t, err)
212+
dcErr, ok := err.(apiErrors.DataCatalogError)
213+
assert.True(t, ok)
214+
assert.Equal(t, codes.AlreadyExists.String(), dcErr.Code().String())
215+
assert.True(t, existingChecked)
182216
}
183217

184218
func TestGetArtifactNotExpired(t *testing.T) {

0 commit comments

Comments
 (0)