Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 87 additions & 79 deletions provider/pihole/pihole_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,26 @@ func TestNewPiholeProvider(t *testing.T) {
}
}

func TestProvider(t *testing.T) {
func TestProvider_InitialState(t *testing.T) {
requests := requestTracker{}
p := &PiholeProvider{
api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
}

records, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(records) != 0 {
t.Fatal("Expected empty list of records, got:", records)
}
}

// Populate the provider with records
records = []*endpoint.Endpoint{
func TestProvider_CreateRecords(t *testing.T) {
requests := requestTracker{}
p := &PiholeProvider{
api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
}
records := []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
Expand Down Expand Up @@ -133,9 +137,6 @@ func TestProvider(t *testing.T) {
}); err != nil {
t.Fatal(err)
}

// Test records are correct on retrieval

newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
Expand All @@ -149,25 +150,26 @@ func TestProvider(t *testing.T) {
if len(requests.deleteRequests) != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
}

for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
}
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
}

if !reflect.DeepEqual(requests.createRequests[idx], record) {
t.Error("Unexpected create request, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
}
}

requests.clear()
}

// Test delete a record

records = []*endpoint.Endpoint{
func TestProvider_DeleteRecords(t *testing.T) {
requests := requestTracker{}
p := &PiholeProvider{
api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
}
records := []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
Expand All @@ -189,6 +191,12 @@ func TestProvider(t *testing.T) {
RecordType: endpoint.RecordTypeAAAA,
},
}
// Create initial records
if err := p.ApplyChanges(context.Background(), &plan.Changes{
Create: records,
}); err != nil {
t.Fatal(err)
}
recordToDeleteA := endpoint.Endpoint{
DNSName: "test3.example.com",
Targets: []string{"192.168.1.3"},
Expand All @@ -213,22 +221,19 @@ func TestProvider(t *testing.T) {
}); err != nil {
t.Fatal(err)
}

// Test records are updated
newRecords, err = p.Records(context.Background())
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
if len(newRecords) != 4 {
t.Fatal("Expected list of 4 records, got:", records)
}
if len(requests.createRequests) != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
if len(requests.createRequests) != 4 {
t.Fatal("Expected 4 create requests, got:", requests.createRequests)
}
if len(requests.deleteRequests) != 2 {
t.Fatal("Expected 2 delete request, got:", requests.deleteRequests)
}

for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
Expand All @@ -237,27 +242,30 @@ func TestProvider(t *testing.T) {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
}
}

if !reflect.DeepEqual(requests.deleteRequests[0], &recordToDeleteA) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[0], "expected:", recordToDeleteA)
}
if !reflect.DeepEqual(requests.deleteRequests[1], &recordToDeleteAAAA) {
t.Error("Unexpected delete request, got:", requests.deleteRequests[1], "expected:", recordToDeleteAAAA)
}

requests.clear()
}

// Test update a record

records = []*endpoint.Endpoint{
func TestProvider_UpdateRecords(t *testing.T) {
requests := requestTracker{}
p := &PiholeProvider{
api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
}
// Create initial records
initialRecords := []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
Expand All @@ -267,61 +275,68 @@ func TestProvider(t *testing.T) {
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
UpdateOld: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
Create: initialRecords,
}); err != nil {
t.Fatal(err)
}
requests.clear()
// Update records
updateOld := []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
UpdateNew: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
},
}
updateNew := []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test1.example.com",
Targets: []string{"fc00::1:192:168:1:1"},
RecordType: endpoint.RecordTypeAAAA,
},
{
DNSName: "test2.example.com",
Targets: []string{"fc00::1:10:0:0:1"},
RecordType: endpoint.RecordTypeAAAA,
},
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
UpdateOld: updateOld,
UpdateNew: updateNew,
}); err != nil {
t.Fatal(err)
}

// Test records are updated
newRecords, err = p.Records(context.Background())
newRecords, err := p.Records(context.Background())
if err != nil {
t.Fatal(err)
}
Expand All @@ -334,16 +349,14 @@ func TestProvider(t *testing.T) {
if len(requests.deleteRequests) != 2 {
t.Fatal("Expected 2 delete request, got:", requests.deleteRequests)
}

for idx, record := range records {
for idx, record := range updateNew {
if newRecords[idx].DNSName != record.DNSName {
t.Error("DNS Name malformed on retrieval, got:", newRecords[idx].DNSName, "expected:", record.DNSName)
}
if newRecords[idx].Targets[0] != record.Targets[0] {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
}
}

expectedCreateA := endpoint.Endpoint{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
Expand All @@ -364,7 +377,6 @@ func TestProvider(t *testing.T) {
Targets: []string{"fc00::1:192:168:1:2"},
RecordType: endpoint.RecordTypeAAAA,
}

for _, request := range requests.createRequests {
switch request.RecordType {
case endpoint.RecordTypeA:
Expand All @@ -375,10 +387,8 @@ func TestProvider(t *testing.T) {
if !reflect.DeepEqual(request, &expectedCreateAAAA) {
t.Error("Unexpected create request, got:", request, "expected:", &expectedCreateAAAA)
}
default:
}
}

for _, request := range requests.deleteRequests {
switch request.RecordType {
case endpoint.RecordTypeA:
Expand All @@ -389,9 +399,7 @@ func TestProvider(t *testing.T) {
if !reflect.DeepEqual(request, &expectedDeleteAAAA) {
t.Error("Unexpected delete request, got:", request, "expected:", &expectedDeleteAAAA)
}
default:
}
}

requests.clear()
}
Loading