Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 120 additions & 24 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,86 @@ type destinationRule struct {
}

type creationRule struct {
PathRegex string `yaml:"path_regex"`
KMS string
AwsProfile string `yaml:"aws_profile"`
Age string `yaml:"age"`
PGP string
GCPKMS string `yaml:"gcp_kms"`
AzureKeyVault string `yaml:"azure_keyvault"`
VaultURI string `yaml:"hc_vault_transit_uri"`
KeyGroups []keyGroup `yaml:"key_groups"`
ShamirThreshold int `yaml:"shamir_threshold"`
UnencryptedSuffix string `yaml:"unencrypted_suffix"`
EncryptedSuffix string `yaml:"encrypted_suffix"`
UnencryptedRegex string `yaml:"unencrypted_regex"`
EncryptedRegex string `yaml:"encrypted_regex"`
UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"`
EncryptedCommentRegex string `yaml:"encrypted_comment_regex"`
MACOnlyEncrypted bool `yaml:"mac_only_encrypted"`
PathRegex string `yaml:"path_regex"`
KMS interface{} `yaml:"kms"` // string or []string
AwsProfile string `yaml:"aws_profile"`
Age interface{} `yaml:"age"` // string or []string
PGP interface{} `yaml:"pgp"` // string or []string
GCPKMS interface{} `yaml:"gcp_kms"` // string or []string
AzureKeyVault interface{} `yaml:"azure_keyvault"` // string or []string
VaultURI interface{} `yaml:"hc_vault_transit_uri"` // string or []string
KeyGroups []keyGroup `yaml:"key_groups"`
ShamirThreshold int `yaml:"shamir_threshold"`
UnencryptedSuffix string `yaml:"unencrypted_suffix"`
EncryptedSuffix string `yaml:"encrypted_suffix"`
UnencryptedRegex string `yaml:"unencrypted_regex"`
EncryptedRegex string `yaml:"encrypted_regex"`
UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"`
EncryptedCommentRegex string `yaml:"encrypted_comment_regex"`
MACOnlyEncrypted bool `yaml:"mac_only_encrypted"`
}

// Helper methods to safely extract keys as []string
func (c *creationRule) GetKMSKeys() ([]string, error) {
return parseKeyField(c.KMS, "kms")
}

func (c *creationRule) GetAgeKeys() ([]string, error) {
return parseKeyField(c.Age, "age")
}

func (c *creationRule) GetPGPKeys() ([]string, error) {
return parseKeyField(c.PGP, "pgp")
}

func (c *creationRule) GetGCPKMSKeys() ([]string, error) {
return parseKeyField(c.GCPKMS, "gcp_kms")
}

func (c *creationRule) GetAzureKeyVaultKeys() ([]string, error) {
return parseKeyField(c.AzureKeyVault, "azure_keyvault")
}

func (c *creationRule) GetVaultURIs() ([]string, error) {
return parseKeyField(c.VaultURI, "hc_vault_transit_uri")
}

// Utility function to handle both string and []string
func parseKeyField(field interface{}, fieldName string) ([]string, error) {
if field == nil {
return []string{}, nil
}

switch v := field.(type) {
case string:
if v == "" {
return []string{}, nil
}
// Existing CSV parsing logic
keys := strings.Split(v, ",")
result := make([]string, 0, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed != "" { // Skip empty strings (fixes trailing comma issue)
result = append(result, trimmed)
}
}
return result, nil
case []interface{}:
result := make([]string, len(v))
for i, item := range v {
if str, ok := item.(string); ok {
result[i] = str
} else {
return nil, fmt.Errorf("invalid %s key configuration: expected string in list, got %T", fieldName, item)
}
}
return result, nil
case []string:
return v, nil
default:
return nil, fmt.Errorf("invalid %s key configuration: expected string, []string, or nil, got %T", fieldName, field)
}
}

func NewStoresConfig() *StoresConfig {
Expand Down Expand Up @@ -279,6 +342,14 @@ func extractMasterKeys(group keyGroup) (sops.KeyGroup, error) {
return deduplicateKeygroup(keyGroup), nil
}

func getKeysWithValidation(getKeysFunc func() ([]string, error), keyType string) ([]string, error) {
keys, err := getKeysFunc()
if err != nil {
return nil, fmt.Errorf("invalid %s key configuration: %w", keyType, err)
}
return keys, nil
}

func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[string]*string) ([]sops.KeyGroup, error) {
var groups []sops.KeyGroup
if len(cRule.KeyGroups) > 0 {
Expand All @@ -291,8 +362,13 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[
}
} else {
var keyGroup sops.KeyGroup
if cRule.Age != "" {
ageKeys, err := age.MasterKeysFromRecipients(cRule.Age)
ageKeys, err := getKeysWithValidation(cRule.GetAgeKeys, "age")
if err != nil {
return nil, err
}

if len(ageKeys) > 0 {
ageKeys, err := age.MasterKeysFromRecipients(strings.Join(ageKeys, ","))
if err != nil {
return nil, err
} else {
Expand All @@ -301,23 +377,43 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[
}
}
}
for _, k := range pgp.MasterKeysFromFingerprintString(cRule.PGP) {
pgpKeys, err := getKeysWithValidation(cRule.GetPGPKeys, "pgp")
if err != nil {
return nil, err
}
for _, k := range pgp.MasterKeysFromFingerprintString(strings.Join(pgpKeys, ",")) {
keyGroup = append(keyGroup, k)
}
for _, k := range kms.MasterKeysFromArnString(cRule.KMS, kmsEncryptionContext, cRule.AwsProfile) {
kmsKeys, err := getKeysWithValidation(cRule.GetKMSKeys, "kms")
if err != nil {
return nil, err
}
for _, k := range kms.MasterKeysFromArnString(strings.Join(kmsKeys, ","), kmsEncryptionContext, cRule.AwsProfile) {
keyGroup = append(keyGroup, k)
}
for _, k := range gcpkms.MasterKeysFromResourceIDString(cRule.GCPKMS) {
gcpkmsKeys, err := getKeysWithValidation(cRule.GetGCPKMSKeys, "gcpkms")
if err != nil {
return nil, err
}
for _, k := range gcpkms.MasterKeysFromResourceIDString(strings.Join(gcpkmsKeys, ",")) {
keyGroup = append(keyGroup, k)
}
azureKeys, err := azkv.MasterKeysFromURLs(cRule.AzureKeyVault)
azKeys, err := getKeysWithValidation(cRule.GetAzureKeyVaultKeys, "azure_keyvault")
if err != nil {
return nil, err
}
azureKeys, err := azkv.MasterKeysFromURLs(strings.Join(azKeys, ","))
if err != nil {
return nil, err
}
for _, k := range azureKeys {
keyGroup = append(keyGroup, k)
}
vaultKeys, err := hcvault.NewMasterKeysFromURIs(cRule.VaultURI)
vaultKeyUris, err := getKeysWithValidation(cRule.GetVaultURIs, "vault")
if err != nil {
return nil, err
}
vaultKeys, err := hcvault.NewMasterKeysFromURIs(strings.Join(vaultKeyUris, ","))
if err != nil {
return nil, err
}
Expand Down
37 changes: 37 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,40 @@ func TestLoadConfigFileWithVaultDestinationRules(t *testing.T) {
assert.NotNil(t, conf.Destination)
assert.Contains(t, conf.Destination.Path("barfoo"), "/v1/kv/barfoo/barfoo")
}

func TestCreationRuleNativeKeyLists(t *testing.T) {
var sampleConfigWithNativeKeyLists = []byte(`
creation_rules:
- path_regex: native_list*
pgp:
- "85D77543B3D624B63CEA9E6DBC17301B491B3F21" # [email protected]
- "FBC7B9E2A4F9289AC0C1D4843D16CEE4A27381B4" # server_XYZ
kms:
- "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012"
age:
- "age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p"
gcp_kms:
- "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key"
hc_vault_transit_uri:
- "https://vault.example.com:8200/v1/transit/keys/key1"
`)
conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithNativeKeyLists, t), "/conf/path", "native_list_test", nil)
assert.Nil(t, err)
if conf == nil {
t.Fatal("Expected configuration but got nil")
}

assert.True(t, len(conf.KeyGroups) == 1)
assert.True(t, len(conf.KeyGroups[0]) == 6)

keyTypeCounts := make(map[string]int)
for _, key := range conf.KeyGroups[0] {
keyTypeCounts[key.TypeToIdentifier()]++
}

assert.Equal(t, 2, keyTypeCounts["pgp"])
assert.Equal(t, 1, keyTypeCounts["kms"])
assert.Equal(t, 1, keyTypeCounts["age"])
assert.Equal(t, 1, keyTypeCounts["gcp_kms"])
assert.Equal(t, 1, keyTypeCounts["hc_vault"])
}
Loading