Skip to content

Commit 81d6289

Browse files
committed
ConsumptionOptional now takes a type param
1 parent 5134bc3 commit 81d6289

File tree

8 files changed

+52
-32
lines changed

8 files changed

+52
-32
lines changed

api.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,16 +286,23 @@ func MustConsume[T any](fn any) Provider {
286286
// TODO: add ExampleConsumptionOptional
287287

288288
// ConsumptionOptional creates a new provider and annotates it as
289-
// allowed to have some of its return values ignored.
289+
// allowed to have one of its return values ignored.
290290
// Without this annotation, a wrap function will not be included
291291
// if some of its return values are not consumed.
292292
//
293293
// In the downward direction, optional consumption is the default.
294294
//
295295
// When used on an existing Provider, it creates an annotated copy of that provider.
296-
func ConsumptionOptional(fn any) Provider {
296+
//
297+
// ConsumptionOptional calls may be called on the output of ConsumptionOptional calls
298+
// to mark more than one type as optional.
299+
func ConsumptionOptional[T any](fn any) Provider {
297300
return newThing(fn).modify(func(fm *provider) {
298-
fm.consumptionOptional = true
301+
if fm.consumptionOptional == nil {
302+
fm.consumptionOptional = make(map[typeCode]struct{})
303+
}
304+
t := reflect.TypeOf((*T)(nil)).Elem()
305+
fm.consumptionOptional[getTypeCode(t)] = struct{}{}
299306
})
300307
}
301308

bind.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,9 @@ func makeUnusedInputProvider() (*provider, error) {
452452
d.nonFinal = true
453453
d.cacheable = true
454454
d.mustCache = true
455-
d.consumptionOptional = true
455+
d.consumptionOptional = map[typeCode]struct{}{
456+
unusedTypeCode: {},
457+
}
456458
d, err := characterizeFunc(d, charContext{inputsAreStatic: true})
457459
if err != nil {
458460
return nil, fmt.Errorf("internal error #328: problem with unused injectors: %w", err)
@@ -472,7 +474,9 @@ func makeUnusedReturnsProvider() (*provider, error) {
472474
d.isSynthetic = true
473475
d.shun = true
474476
d.required = false
475-
d.consumptionOptional = true
477+
d.consumptionOptional = map[typeCode]struct{}{
478+
unusedTypeCode: {},
479+
}
476480
return d, nil
477481
}
478482

debug.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -240,28 +240,28 @@ func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) st
240240
f += "\t\t\t\t" + extraIndent
241241
closeParens := ""
242242
for annotation, active := range map[string]bool{
243-
"NonFinal": fm.nonFinal,
244-
"Cacheable": fm.cacheable,
245-
"MustCache": fm.mustCache,
246-
"Required": fm.required,
247-
"CallsInner": fm.callsInner,
248-
"Memoize": fm.memoize,
249-
"Reorder": fm.reorder,
250-
"Desired": fm.desired,
251-
"Shun": fm.shun,
252-
"NotCacheable": fm.notCacheable,
253-
"ConsumptionOptional": fm.consumptionOptional,
254-
"Singleton": fm.singleton,
243+
"NonFinal": fm.nonFinal,
244+
"Cacheable": fm.cacheable,
245+
"MustCache": fm.mustCache,
246+
"Required": fm.required,
247+
"CallsInner": fm.callsInner,
248+
"Memoize": fm.memoize,
249+
"Reorder": fm.reorder,
250+
"Desired": fm.desired,
251+
"Shun": fm.shun,
252+
"NotCacheable": fm.notCacheable,
253+
"Singleton": fm.singleton,
255254
} {
256255
if active {
257256
f += annotation + "("
258257
closeParens += ")"
259258
}
260259
}
261260
for anno, m := range map[string]map[typeCode]struct{}{
262-
"ShadowingAllowed": fm.shadowingAllowed,
263-
"Loose": fm.loose,
264-
"MustConsume": fm.mustConsume,
261+
"ShadowingAllowed": fm.shadowingAllowed,
262+
"Loose": fm.loose,
263+
"MustConsume": fm.mustConsume,
264+
"ConsumptionOptional": fm.consumptionOptional,
265265
} {
266266
for tc := range m {
267267
f += anno + "[" + tc.String() + "]("

debug_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func TestDetailedError(t *testing.T) {
8080
Memoize(func(i int32) int32 { return i }),
8181
OverridesError(func(_ func()) error { return nil }),
8282
MustConsume[int64](func(i int32) int64 { return int64(i) }),
83-
ConsumptionOptional(func(i int64) float64 { return float64(i) }),
83+
ConsumptionOptional[float64](func(i int64) float64 { return float64(i) }),
8484
func(_ MyType5) error { return nil },
8585
NonFinal(func() {}),
8686
)
@@ -93,7 +93,7 @@ func TestDetailedError(t *testing.T) {
9393
require.NotEqual(t, -1, index, "contains 'func TestRegression'")
9494
detailed = detailed[index:]
9595

96-
for _, word := range strings.Split("Desired Shun Required Cacheable MustCache Cluster Memoize ShadowingAllowed\\[error\\] MustConsume\\[int64\\] ConsumptionOptional NonFinal Loose\\[string\\]", " ") {
96+
for _, word := range strings.Split("Desired Shun Required Cacheable MustCache Cluster Memoize ShadowingAllowed\\[error\\] MustConsume\\[int64\\] ConsumptionOptional\\[float64\\] NonFinal Loose\\[string\\]", " ") {
9797
re := regexp.MustCompile(fmt.Sprintf(`\b%s\(` /*)*/, word))
9898
if !re.MatchString(detailed) {
9999
t.Errorf("did not find %s( in reproduce output", word) // )

include.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*pro
8080
if fm.mustConsume != nil {
8181
fm.d.mustConsumeFlow[outputParams] = true
8282
}
83-
if !fm.consumptionOptional {
83+
if fm.consumptionOptional == nil {
8484
fm.d.mustConsumeFlow[returnParams] = true
8585
}
8686
if fm.required {
@@ -307,9 +307,14 @@ func checkFlows(funcs []*provider, numFuncs int, canRemoveDesired bool) error {
307307
}
308308
Param:
309309
for _, tc := range tclist {
310-
if param == int(outputParams) {
310+
switch flowType(param) {
311+
case outputParams:
311312
if _, ok := fm.mustConsume[tc]; !ok {
312-
continue
313+
continue Param
314+
}
315+
case returnParams:
316+
if _, ok := fm.consumptionOptional[tc]; ok {
317+
continue Param
313318
}
314319
}
315320
if tc == unusedTypeCode {

nject.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type provider struct {
3030
shun bool
3131
notCacheable bool
3232
mustConsume map[typeCode]struct{}
33-
consumptionOptional bool
33+
consumptionOptional map[typeCode]struct{}
3434
singleton bool
3535
cluster int32
3636
parallel bool
@@ -94,7 +94,7 @@ func (fm *provider) copy() *provider {
9494
shun: fm.shun,
9595
notCacheable: fm.notCacheable,
9696
mustConsume: mapCopy(fm.mustConsume),
97-
consumptionOptional: fm.consumptionOptional,
97+
consumptionOptional: mapCopy(fm.consumptionOptional),
9898
singleton: fm.singleton,
9999
cluster: fm.cluster,
100100
parallel: fm.parallel,

nject_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,14 @@ func TestAnnotateMustConsume(t *testing.T) {
5858
}
5959

6060
func TestAnnotateConsumptionOptional(t *testing.T) {
61+
stc := getTypeCode("foo")
6162
wrapTest(t, func(t *testing.T) {
62-
p := ConsumptionOptional(func() {})
63+
p := ConsumptionOptional[string](func() {})
6364
require.IsType(t, &provider{}, p)
64-
require.True(t, p.(*provider).consumptionOptional)
65-
require.True(t, p.(*provider).copy().consumptionOptional)
65+
require.NotNil(t, p.(*provider).consumptionOptional)
66+
require.Contains(t, p.(*provider).consumptionOptional, stc)
67+
require.NotNil(t, p.(*provider).copy().consumptionOptional)
68+
require.Contains(t, p.(*provider).copy().consumptionOptional, stc)
6669
})
6770
}
6871

reorder.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,13 @@ func reorder(funcs []*provider, initF *provider) ([]*provider, error) {
177177
}
178178
// if you return a T, you're not marked consumptionOptional, then you
179179
// MUST be after a provider that receives a T as a returned value
180+
_, consumptionOptional := fm.consumptionOptional[t]
180181
if num, ok := upTypes[t]; ok {
181-
aAfterB(!fm.consumptionOptional, i, num)
182+
aAfterB(!consumptionOptional, i, num)
182183
} else {
183184
debugln("\tuptype", counter, t)
184185
upTypes[t] = counter
185-
aAfterB(!fm.consumptionOptional, i, counter)
186+
aAfterB(!consumptionOptional, i, counter)
186187
counter++
187188
}
188189
// if you return a T, then you SHOULD be be after all providers that

0 commit comments

Comments
 (0)