-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathvalidator.go
141 lines (120 loc) · 3.84 KB
/
validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package di
import (
"errors"
"fmt"
"reflect"
"github.com/dozm/di/errorx"
"github.com/dozm/di/syncx"
)
// validatorState is the context threaded down the call-site walk. It is passed
// by value, so a change made while visiting one call site is visible only to
// that call site's own subtree.
type validatorState struct {
	// Singleton is the nearest enclosing call site whose cache location is
	// root (set by visitRootCache); nil while no such ancestor has been seen.
	// visitScopeCache reports an error if it reaches a scoped service while
	// this is non-nil.
	Singleton CallSite
}
// CallSiteValidator checks service call-site graphs for lifetime violations:
// a root-cached (singleton) service depending on a scoped one, and scoped
// requirements being resolved from the root scope.
//
// Construct it with newCallSiteValidator. NOTE(review): the zero value leaves
// scopedServices nil and is presumably unusable — confirm.
type CallSiteValidator struct {
	// scopedServices maps a validated service type to the scoped service type
	// it requires: either itself (when it is scoped) or the first scoped
	// dependency found beneath it. It is populated by ValidateCallSite and
	// consulted by ValidateResolution.
	scopedServices *syncx.Map[reflect.Type, reflect.Type]
}
// ValidateCallSite walks the dependency graph rooted at callSite and returns
// an error if a root-cached (singleton) service consumes a scoped one.
//
// When the graph requires a scoped service, that requirement is recorded
// under callSite's service type so ValidateResolution can later reject
// resolving it from the root scope.
func (v *CallSiteValidator) ValidateCallSite(callSite CallSite) error {
	requiredScoped, walkErr := v.visitCallSite(callSite, validatorState{})
	switch {
	case walkErr != nil:
		return walkErr
	case requiredScoped != nil:
		v.scopedServices.Store(callSite.ServiceType(), requiredScoped)
	}
	return nil
}
func (v *CallSiteValidator) ValidateResolution(serviceType reflect.Type, scope Scope, rootScope Scope) (err error) {
if scope == rootScope {
scopedService, ok := v.scopedServices.Load(serviceType)
if !ok {
return
}
if serviceType == scopedService {
return &errorx.ScopedServiceFromRootError{
Message: fmt.Sprintf("cannot resolve scoped service '%v' from root scope", serviceType)}
}
return &errorx.ScopedServiceFromRootError{
Message: fmt.Sprintf("cannot resolve '%v' from root scope because it requires scoped service '%v'", serviceType, scopedService),
}
}
return
}
// visitCallSite dispatches on where callSite's result is cached:
//   - root:    the call site is treated as a singleton for its subtree;
//   - scope:   the call site is a scoped service and is checked against any
//     enclosing singleton;
//   - dispose / none: the call site is walked without changing state.
//
// It returns the first scoped service type required by callSite (nil if
// none), or an error describing the first violation found or an unknown
// cache location.
func (v *CallSiteValidator) visitCallSite(callSite CallSite, state validatorState) (reflect.Type, error) {
	switch callSite.Cache().Location {
	case CacheLocation_Root:
		return v.visitRootCache(callSite, state)
	case CacheLocation_Scope:
		return v.visitScopeCache(callSite, state)
	case CacheLocation_Dispose:
		return v.visitDisposeCache(callSite, state)
	case CacheLocation_None:
		return v.visitNoCache(callSite, state)
	default:
		return nil, errors.New("unknown cache location")
	}
}
// visitCallSiteMain validates how callSite builds its value, independent of
// how that value is cached.
//
// Factory, constant and container call sites are treated as leaves: they are
// not walked further and report no scoped requirement. Slice and constructor
// call sites are walked recursively through their child call sites.
//
// It returns the first scoped service type required beneath callSite (nil if
// none). It returns an error for an unknown kind, or when a call site's
// concrete type does not match the kind it reports. That case would otherwise
// panic on the type assertion.
func (v *CallSiteValidator) visitCallSiteMain(callSite CallSite, state validatorState) (reflect.Type, error) {
	switch callSite.Kind() {
	case CallSiteKind_Factory, CallSiteKind_Constant, CallSiteKind_Container:
		return nil, nil
	case CallSiteKind_Slice:
		slice, ok := callSite.(*SliceCallSite)
		if !ok {
			return nil, fmt.Errorf("slice call site has unexpected type %T", callSite)
		}
		return v.visitSlice(slice, state)
	case CallSiteKind_Constructor:
		ctor, ok := callSite.(*ConstructorCallSite)
		if !ok {
			return nil, fmt.Errorf("constructor call site has unexpected type %T", callSite)
		}
		return v.visitConstructor(ctor, state)
	default:
		return nil, errors.New("unknown call site kind")
	}
}
// visitConstructor validates every parameter call site of a constructor, in
// order, and returns the first scoped service type any of them requires.
//
// Every parameter is visited even after a scoped requirement has been found,
// so that a later parameter can still surface a violation. The first error
// aborts the walk.
func (v *CallSiteValidator) visitConstructor(callSite *ConstructorCallSite, state validatorState) (reflect.Type, error) {
	var firstScoped reflect.Type
	for i := range callSite.Parameters {
		paramScoped, err := v.visitCallSite(callSite.Parameters[i], state)
		if err != nil {
			return nil, err
		}
		if firstScoped == nil && paramScoped != nil {
			firstScoped = paramScoped
		}
	}
	return firstScoped, nil
}
// visitSlice validates each element call site of a slice call site, in order,
// and returns the first scoped service type any element requires.
//
// All elements are visited so that every one is validated. The first error
// aborts the walk.
func (v *CallSiteValidator) visitSlice(callSite *SliceCallSite, state validatorState) (reflect.Type, error) {
	var firstScoped reflect.Type
	for i := range callSite.CallSites {
		elemScoped, err := v.visitCallSite(callSite.CallSites[i], state)
		if err != nil {
			return nil, err
		}
		if firstScoped == nil && elemScoped != nil {
			firstScoped = elemScoped
		}
	}
	return firstScoped, nil
}
// visitRootCache handles a root-cached call site. The call site becomes the
// enclosing singleton for its own subtree, so any scoped service reached
// beneath it is reported as a violation by visitScopeCache.
func (v *CallSiteValidator) visitRootCache(singletonCallSite CallSite, state validatorState) (reflect.Type, error) {
	nested := state
	nested.Singleton = singletonCallSite
	return v.visitCallSiteMain(singletonCallSite, nested)
}
// visitScopeCache handles a scope-cached (scoped) call site.
//
// The scope factory service (ScopeFactoryType) is exempt and reports nothing.
// Any other scoped service is an error when reached beneath a singleton.
// Otherwise its own dependencies are validated, and the scoped service's type
// is returned as the scoped requirement.
func (v *CallSiteValidator) visitScopeCache(scopedCallSite CallSite, state validatorState) (reflect.Type, error) {
	serviceType := scopedCallSite.ServiceType()

	switch {
	case serviceType == ScopeFactoryType:
		return nil, nil
	case state.Singleton != nil:
		return nil, fmt.Errorf("cannot consume scoped service '%v' from singleton '%v'",
			serviceType,
			state.Singleton.ServiceType())
	}

	// Child requirements are subsumed by this service being scoped itself;
	// only errors from the walk matter here.
	if _, err := v.visitCallSiteMain(scopedCallSite, state); err != nil {
		return nil, err
	}
	return serviceType, nil
}
// visitDisposeCache handles a call site with the dispose cache location. It
// adds no lifetime rule of its own and passes state through unchanged to the
// kind-specific walk.
func (v *CallSiteValidator) visitDisposeCache(callSite CallSite, state validatorState) (reflect.Type, error) {
	scoped, err := v.visitCallSiteMain(callSite, state)
	return scoped, err
}
// visitNoCache handles an uncached call site. It adds no lifetime rule of its
// own and passes state through unchanged to the kind-specific walk.
func (v *CallSiteValidator) visitNoCache(callSite CallSite, state validatorState) (reflect.Type, error) {
	scoped, err := v.visitCallSiteMain(callSite, state)
	return scoped, err
}
// newCallSiteValidator returns a CallSiteValidator with an empty record of
// scoped-service requirements.
func newCallSiteValidator() *CallSiteValidator {
	v := new(CallSiteValidator)
	v.scopedServices = syncx.NewMap[reflect.Type, reflect.Type]()
	return v
}