seen.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. package tracker
  2. import (
  3. "bytes"
  4. "fmt"
  5. "sync"
  6. "github.com/pelletier/go-toml/v2/internal/ast"
  7. )
  8. type keyKind uint8
  9. const (
  10. invalidKind keyKind = iota
  11. valueKind
  12. tableKind
  13. arrayTableKind
  14. )
  15. func (k keyKind) String() string {
  16. switch k {
  17. case invalidKind:
  18. return "invalid"
  19. case valueKind:
  20. return "value"
  21. case tableKind:
  22. return "table"
  23. case arrayTableKind:
  24. return "array table"
  25. }
  26. panic("missing keyKind string mapping")
  27. }
// SeenTracker tracks which keys have been seen with which TOML type to flag
// duplicates and mismatches according to the spec.
//
// Each node in the visited tree is represented by an entry, identified by
// its index in the entries slice. Index 0 always holds the root. As new
// nodes are discovered (referenced for the first time in the TOML document),
// entries are created and linked at the head of their parent's child list.
//
// An entry does not store its parent. Instead, every entry records the index
// of its first child and of its next sibling (-1 meaning "none"). To find
// whether a given key part ([]byte) has already been visited under a parent,
// find walks that parent's child list and compares names.
//
// When encountering [[array tables]] that were already seen, clear detaches
// the descendants of that node so that this branch of the tree can be
// "rediscovered". Detached entries are not removed from the slice: they are
// pushed onto a free list threaded through their next field and headed by
// entries[0].next, and create reuses those slots before growing the slice.
type SeenTracker struct {
	// entries stores every node of the tree; entries[0] is the root.
	entries []entry
	// currentIdx is the index of the table that key/value expressions are
	// currently resolved against. It is 0 (the root) after reset and is
	// updated by checkTable and checkArrayTable.
	currentIdx int
}
// pool holds scratch *SeenTracker values that checkInlineTable borrows to
// validate the contents of an inline table and then returns for reuse.
var pool sync.Pool
  54. func (s *SeenTracker) reset() {
  55. // Always contains a root element at index 0.
  56. s.currentIdx = 0
  57. if len(s.entries) == 0 {
  58. s.entries = make([]entry, 1, 2)
  59. } else {
  60. s.entries = s.entries[:1]
  61. }
  62. s.entries[0].child = -1
  63. s.entries[0].next = -1
  64. }
// entry is one node of the tree of keys seen so far. Entries refer to each
// other by index into SeenTracker.entries.
type entry struct {
	// Use -1 to indicate no child or no sibling.

	// child is the index of the most recently created child of this entry.
	child int
	// next is the index of the next sibling. On the root entry (index 0),
	// and on entries released by clear, it links the free list of reusable
	// slots instead.
	next int

	// name is the key part this entry stands for. clear sets it to nil when
	// the entry is released.
	name []byte
	// kind is what the key has been seen as (value, table or array table).
	kind keyKind
	// explicit is true once the entry counts as explicitly defined: it was
	// created or claimed by a table / array table header, or promoted by
	// setExplicitFlag.
	explicit bool
	// kv is true while the entry has only been created by a key/value
	// expression; setExplicitFlag turns it into explicit.
	kv bool
}
  74. // Find the index of the child of parentIdx with key k. Returns -1 if
  75. // it does not exist.
  76. func (s *SeenTracker) find(parentIdx int, k []byte) int {
  77. for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
  78. if bytes.Equal(s.entries[i].name, k) {
  79. return i
  80. }
  81. }
  82. return -1
  83. }
  84. // Remove all descendants of node at position idx.
  85. func (s *SeenTracker) clear(idx int) {
  86. if idx >= len(s.entries) {
  87. return
  88. }
  89. for i := s.entries[idx].child; i >= 0; {
  90. next := s.entries[i].next
  91. n := s.entries[0].next
  92. s.entries[0].next = i
  93. s.entries[i].next = n
  94. s.entries[i].name = nil
  95. s.clear(i)
  96. i = next
  97. }
  98. s.entries[idx].child = -1
  99. }
  100. func (s *SeenTracker) create(parentIdx int, name []byte, kind keyKind, explicit bool, kv bool) int {
  101. e := entry{
  102. child: -1,
  103. next: s.entries[parentIdx].child,
  104. name: name,
  105. kind: kind,
  106. explicit: explicit,
  107. kv: kv,
  108. }
  109. var idx int
  110. if s.entries[0].next >= 0 {
  111. idx = s.entries[0].next
  112. s.entries[0].next = s.entries[idx].next
  113. s.entries[idx] = e
  114. } else {
  115. idx = len(s.entries)
  116. s.entries = append(s.entries, e)
  117. }
  118. s.entries[parentIdx].child = idx
  119. return idx
  120. }
  121. func (s *SeenTracker) setExplicitFlag(parentIdx int) {
  122. for i := s.entries[parentIdx].child; i >= 0; i = s.entries[i].next {
  123. if s.entries[i].kv {
  124. s.entries[i].explicit = true
  125. s.entries[i].kv = false
  126. }
  127. s.setExplicitFlag(i)
  128. }
  129. }
  130. // CheckExpression takes a top-level node and checks that it does not contain
  131. // keys that have been seen in previous calls, and validates that types are
  132. // consistent.
  133. func (s *SeenTracker) CheckExpression(node *ast.Node) error {
  134. if s.entries == nil {
  135. s.reset()
  136. }
  137. switch node.Kind {
  138. case ast.KeyValue:
  139. return s.checkKeyValue(node)
  140. case ast.Table:
  141. return s.checkTable(node)
  142. case ast.ArrayTable:
  143. return s.checkArrayTable(node)
  144. default:
  145. panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
  146. }
  147. }
  148. func (s *SeenTracker) checkTable(node *ast.Node) error {
  149. if s.currentIdx >= 0 {
  150. s.setExplicitFlag(s.currentIdx)
  151. }
  152. it := node.Key()
  153. parentIdx := 0
  154. // This code is duplicated in checkArrayTable. This is because factoring
  155. // it in a function requires to copy the iterator, or allocate it to the
  156. // heap, which is not cheap.
  157. for it.Next() {
  158. if it.IsLast() {
  159. break
  160. }
  161. k := it.Node().Data
  162. idx := s.find(parentIdx, k)
  163. if idx < 0 {
  164. idx = s.create(parentIdx, k, tableKind, false, false)
  165. } else {
  166. entry := s.entries[idx]
  167. if entry.kind == valueKind {
  168. return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
  169. }
  170. }
  171. parentIdx = idx
  172. }
  173. k := it.Node().Data
  174. idx := s.find(parentIdx, k)
  175. if idx >= 0 {
  176. kind := s.entries[idx].kind
  177. if kind != tableKind {
  178. return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
  179. }
  180. if s.entries[idx].explicit {
  181. return fmt.Errorf("toml: table %s already exists", string(k))
  182. }
  183. s.entries[idx].explicit = true
  184. } else {
  185. idx = s.create(parentIdx, k, tableKind, true, false)
  186. }
  187. s.currentIdx = idx
  188. return nil
  189. }
  190. func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
  191. if s.currentIdx >= 0 {
  192. s.setExplicitFlag(s.currentIdx)
  193. }
  194. it := node.Key()
  195. parentIdx := 0
  196. for it.Next() {
  197. if it.IsLast() {
  198. break
  199. }
  200. k := it.Node().Data
  201. idx := s.find(parentIdx, k)
  202. if idx < 0 {
  203. idx = s.create(parentIdx, k, tableKind, false, false)
  204. } else {
  205. entry := s.entries[idx]
  206. if entry.kind == valueKind {
  207. return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
  208. }
  209. }
  210. parentIdx = idx
  211. }
  212. k := it.Node().Data
  213. idx := s.find(parentIdx, k)
  214. if idx >= 0 {
  215. kind := s.entries[idx].kind
  216. if kind != arrayTableKind {
  217. return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
  218. }
  219. s.clear(idx)
  220. } else {
  221. idx = s.create(parentIdx, k, arrayTableKind, true, false)
  222. }
  223. s.currentIdx = idx
  224. return nil
  225. }
  226. func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
  227. parentIdx := s.currentIdx
  228. it := node.Key()
  229. for it.Next() {
  230. k := it.Node().Data
  231. idx := s.find(parentIdx, k)
  232. if idx < 0 {
  233. idx = s.create(parentIdx, k, tableKind, false, true)
  234. } else {
  235. entry := s.entries[idx]
  236. if it.IsLast() {
  237. return fmt.Errorf("toml: key %s is already defined", string(k))
  238. } else if entry.kind != tableKind {
  239. return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
  240. } else if entry.explicit {
  241. return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
  242. }
  243. }
  244. parentIdx = idx
  245. }
  246. s.entries[parentIdx].kind = valueKind
  247. value := node.Value()
  248. switch value.Kind {
  249. case ast.InlineTable:
  250. return s.checkInlineTable(value)
  251. case ast.Array:
  252. return s.checkArray(value)
  253. }
  254. return nil
  255. }
  256. func (s *SeenTracker) checkArray(node *ast.Node) error {
  257. it := node.Children()
  258. for it.Next() {
  259. n := it.Node()
  260. switch n.Kind {
  261. case ast.InlineTable:
  262. err := s.checkInlineTable(n)
  263. if err != nil {
  264. return err
  265. }
  266. case ast.Array:
  267. err := s.checkArray(n)
  268. if err != nil {
  269. return err
  270. }
  271. }
  272. }
  273. return nil
  274. }
  275. func (s *SeenTracker) checkInlineTable(node *ast.Node) error {
  276. if pool.New == nil {
  277. pool.New = func() interface{} {
  278. return &SeenTracker{}
  279. }
  280. }
  281. s = pool.Get().(*SeenTracker)
  282. s.reset()
  283. it := node.Children()
  284. for it.Next() {
  285. n := it.Node()
  286. err := s.checkKeyValue(n)
  287. if err != nil {
  288. return err
  289. }
  290. }
  291. // As inline tables are self-contained, the tracker does not
  292. // need to retain the details of what they contain. The
  293. // keyValue element that creates the inline table is kept to
  294. // mark the presence of the inline table and prevent
  295. // redefinition of its keys: check* functions cannot walk into
  296. // a value.
  297. pool.Put(s)
  298. return nil
  299. }