cache.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. package validator
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "sync"
  7. "sync/atomic"
  8. )
  // tagType classifies a parsed cTag node. parseFieldTagsRecursive assigns it
// from the tag text; nodes whose typeof is never assigned keep the zero value.
type tagType uint8

// tagType values. typeDefault must remain the zero value (first in the iota
// sequence): cTag nodes whose typeof is never set, including new(cTag),
// compare equal to typeDefault.
const (
	typeDefault       tagType = iota // a regular validation-function tag
	typeOmitEmpty                    // the omitempty tag
	typeIsDefault                    // the isdefault tag
	typeNoStructLevel                // the noStructLevelTag tag
	typeStructOnly                   // the structOnlyTag tag
	typeDive                         // the diveTag tag
	typeOr                           // one alternative of an orSeparator-joined group
	typeKeys                         // the keysTag tag; its section is parsed into cTag.keys
	typeEndKeys                      // the endKeysTag tag that closes a keys section
)
  21. const (
  22. invalidValidation = "Invalid validation tag on field '%s'"
  23. undefinedValidation = "Undefined validation function '%s' on field '%s'"
  24. keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
  25. )
  26. type structCache struct {
  27. lock sync.Mutex
  28. m atomic.Value // map[reflect.Type]*cStruct
  29. }
  30. func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
  31. c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
  32. return
  33. }
  34. func (sc *structCache) Set(key reflect.Type, value *cStruct) {
  35. m := sc.m.Load().(map[reflect.Type]*cStruct)
  36. nm := make(map[reflect.Type]*cStruct, len(m)+1)
  37. for k, v := range m {
  38. nm[k] = v
  39. }
  40. nm[key] = value
  41. sc.m.Store(nm)
  42. }
  43. type tagCache struct {
  44. lock sync.Mutex
  45. m atomic.Value // map[string]*cTag
  46. }
  47. func (tc *tagCache) Get(key string) (c *cTag, found bool) {
  48. c, found = tc.m.Load().(map[string]*cTag)[key]
  49. return
  50. }
  51. func (tc *tagCache) Set(key string, value *cTag) {
  52. m := tc.m.Load().(map[string]*cTag)
  53. nm := make(map[string]*cTag, len(m)+1)
  54. for k, v := range m {
  55. nm[k] = v
  56. }
  57. nm[key] = value
  58. tc.m.Store(nm)
  59. }
  60. type cStruct struct {
  61. name string
  62. fields []*cField
  63. fn StructLevelFuncCtx
  64. }
  65. type cField struct {
  66. idx int
  67. name string
  68. altName string
  69. namesEqual bool
  70. cTags *cTag
  71. }
  72. type cTag struct {
  73. tag string
  74. aliasTag string
  75. actualAliasTag string
  76. param string
  77. keys *cTag // only populated when using tag's 'keys' and 'endkeys' for map key validation
  78. next *cTag
  79. fn FuncCtx
  80. typeof tagType
  81. hasTag bool
  82. hasAlias bool
  83. hasParam bool // true if parameter used eg. eq= where the equal sign has been set
  84. isBlockEnd bool // indicates the current tag represents the last validation in the block
  85. runValidationWhenNil bool
  86. }
  87. func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
  88. v.structCache.lock.Lock()
  89. defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
  90. typ := current.Type()
  91. // could have been multiple trying to access, but once first is done this ensures struct
  92. // isn't parsed again.
  93. cs, ok := v.structCache.Get(typ)
  94. if ok {
  95. return cs
  96. }
  97. cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
  98. numFields := current.NumField()
  99. rules := v.rules[typ]
  100. var ctag *cTag
  101. var fld reflect.StructField
  102. var tag string
  103. var customName string
  104. for i := 0; i < numFields; i++ {
  105. fld = typ.Field(i)
  106. if !fld.Anonymous && len(fld.PkgPath) > 0 {
  107. continue
  108. }
  109. if rtag, ok := rules[fld.Name]; ok {
  110. tag = rtag
  111. } else {
  112. tag = fld.Tag.Get(v.tagName)
  113. }
  114. if tag == skipValidationTag {
  115. continue
  116. }
  117. customName = fld.Name
  118. if v.hasTagNameFunc {
  119. name := v.tagNameFunc(fld)
  120. if len(name) > 0 {
  121. customName = name
  122. }
  123. }
  124. // NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
  125. // and so only struct level caching can be used instead of combined with Field tag caching
  126. if len(tag) > 0 {
  127. ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
  128. } else {
  129. // even if field doesn't have validations need cTag for traversing to potential inner/nested
  130. // elements of the field.
  131. ctag = new(cTag)
  132. }
  133. cs.fields = append(cs.fields, &cField{
  134. idx: i,
  135. name: fld.Name,
  136. altName: customName,
  137. cTags: ctag,
  138. namesEqual: fld.Name == customName,
  139. })
  140. }
  141. v.structCache.Set(typ, cs)
  142. return cs
  143. }
  144. func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
  145. var t string
  146. noAlias := len(alias) == 0
  147. tags := strings.Split(tag, tagSeparator)
  148. for i := 0; i < len(tags); i++ {
  149. t = tags[i]
  150. if noAlias {
  151. alias = t
  152. }
  153. // check map for alias and process new tags, otherwise process as usual
  154. if tagsVal, found := v.aliases[t]; found {
  155. if i == 0 {
  156. firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  157. } else {
  158. next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
  159. current.next, current = next, curr
  160. }
  161. continue
  162. }
  163. var prevTag tagType
  164. if i == 0 {
  165. current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
  166. firstCtag = current
  167. } else {
  168. prevTag = current.typeof
  169. current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
  170. current = current.next
  171. }
  172. switch t {
  173. case diveTag:
  174. current.typeof = typeDive
  175. continue
  176. case keysTag:
  177. current.typeof = typeKeys
  178. if i == 0 || prevTag != typeDive {
  179. panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
  180. }
  181. current.typeof = typeKeys
  182. // need to pass along only keys tag
  183. // need to increment i to skip over the keys tags
  184. b := make([]byte, 0, 64)
  185. i++
  186. for ; i < len(tags); i++ {
  187. b = append(b, tags[i]...)
  188. b = append(b, ',')
  189. if tags[i] == endKeysTag {
  190. break
  191. }
  192. }
  193. current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
  194. continue
  195. case endKeysTag:
  196. current.typeof = typeEndKeys
  197. // if there are more in tags then there was no keysTag defined
  198. // and an error should be thrown
  199. if i != len(tags)-1 {
  200. panic(keysTagNotDefined)
  201. }
  202. return
  203. case omitempty:
  204. current.typeof = typeOmitEmpty
  205. continue
  206. case structOnlyTag:
  207. current.typeof = typeStructOnly
  208. continue
  209. case noStructLevelTag:
  210. current.typeof = typeNoStructLevel
  211. continue
  212. default:
  213. if t == isdefault {
  214. current.typeof = typeIsDefault
  215. }
  216. // if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
  217. orVals := strings.Split(t, orSeparator)
  218. for j := 0; j < len(orVals); j++ {
  219. vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
  220. if noAlias {
  221. alias = vals[0]
  222. current.aliasTag = alias
  223. } else {
  224. current.actualAliasTag = t
  225. }
  226. if j > 0 {
  227. current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
  228. current = current.next
  229. }
  230. current.hasParam = len(vals) > 1
  231. current.tag = vals[0]
  232. if len(current.tag) == 0 {
  233. panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
  234. }
  235. if wrapper, ok := v.validations[current.tag]; ok {
  236. current.fn = wrapper.fn
  237. current.runValidationWhenNil = wrapper.runValidatinOnNil
  238. } else {
  239. panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
  240. }
  241. if len(orVals) > 1 {
  242. current.typeof = typeOr
  243. }
  244. if len(vals) > 1 {
  245. current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
  246. }
  247. }
  248. current.isBlockEnd = true
  249. }
  250. }
  251. return
  252. }
  253. func (v *Validate) fetchCacheTag(tag string) *cTag {
  254. // find cached tag
  255. ctag, found := v.tagCache.Get(tag)
  256. if !found {
  257. v.tagCache.lock.Lock()
  258. defer v.tagCache.lock.Unlock()
  259. // could have been multiple trying to access, but once first is done this ensures tag
  260. // isn't parsed again.
  261. ctag, found = v.tagCache.Get(tag)
  262. if !found {
  263. ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
  264. v.tagCache.Set(tag, ctag)
  265. }
  266. }
  267. return ctag
  268. }