compiler.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927
  1. package encoder
  2. import (
  3. "context"
  4. "encoding"
  5. "encoding/json"
  6. "reflect"
  7. "sync/atomic"
  8. "unsafe"
  9. "github.com/goccy/go-json/internal/errors"
  10. "github.com/goccy/go-json/internal/runtime"
  11. )
  12. type marshalerContext interface {
  13. MarshalJSON(context.Context) ([]byte, error)
  14. }
  15. var (
  16. marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
  17. marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
  18. marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
  19. jsonNumberType = reflect.TypeOf(json.Number(""))
  20. cachedOpcodeSets []*OpcodeSet
  21. cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
  22. typeAddr *runtime.TypeAddr
  23. )
  24. func init() {
  25. typeAddr = runtime.AnalyzeTypeAddr()
  26. if typeAddr == nil {
  27. typeAddr = &runtime.TypeAddr{}
  28. }
  29. cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift)
  30. }
  31. func loadOpcodeMap() map[uintptr]*OpcodeSet {
  32. p := atomic.LoadPointer(&cachedOpcodeMap)
  33. return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
  34. }
  35. func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
  36. newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
  37. newOpcodeMap[typ] = set
  38. for k, v := range m {
  39. newOpcodeMap[k] = v
  40. }
  41. atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
  42. }
  43. func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
  44. opcodeMap := loadOpcodeMap()
  45. if codeSet, exists := opcodeMap[typeptr]; exists {
  46. return codeSet, nil
  47. }
  48. codeSet, err := newCompiler().compile(typeptr)
  49. if err != nil {
  50. return nil, err
  51. }
  52. storeOpcodeSet(typeptr, codeSet, opcodeMap)
  53. return codeSet, nil
  54. }
  55. func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
  56. if (ctx.Option.Flag & ContextOption) == 0 {
  57. return codeSet, nil
  58. }
  59. query := FieldQueryFromContext(ctx.Option.Context)
  60. if query == nil {
  61. return codeSet, nil
  62. }
  63. ctx.Option.Flag |= FieldQueryOption
  64. cacheCodeSet := codeSet.getQueryCache(query.Hash())
  65. if cacheCodeSet != nil {
  66. return cacheCodeSet, nil
  67. }
  68. queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
  69. if err != nil {
  70. return nil, err
  71. }
  72. codeSet.setQueryCache(query.Hash(), queryCodeSet)
  73. return queryCodeSet, nil
  74. }
  75. type Compiler struct {
  76. structTypeToCode map[uintptr]*StructCode
  77. }
  78. func newCompiler() *Compiler {
  79. return &Compiler{
  80. structTypeToCode: map[uintptr]*StructCode{},
  81. }
  82. }
  83. func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
  84. // noescape trick for header.typ ( reflect.*rtype )
  85. typ := *(**runtime.Type)(unsafe.Pointer(&typeptr))
  86. code, err := c.typeToCode(typ)
  87. if err != nil {
  88. return nil, err
  89. }
  90. return c.codeToOpcodeSet(typ, code)
  91. }
  92. func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
  93. noescapeKeyCode := c.codeToOpcode(&compileContext{
  94. structTypeToCodes: map[uintptr]Opcodes{},
  95. recursiveCodes: &Opcodes{},
  96. }, typ, code)
  97. if err := noescapeKeyCode.Validate(); err != nil {
  98. return nil, err
  99. }
  100. escapeKeyCode := c.codeToOpcode(&compileContext{
  101. structTypeToCodes: map[uintptr]Opcodes{},
  102. recursiveCodes: &Opcodes{},
  103. escapeKey: true,
  104. }, typ, code)
  105. noescapeKeyCode = copyOpcode(noescapeKeyCode)
  106. escapeKeyCode = copyOpcode(escapeKeyCode)
  107. setTotalLengthToInterfaceOp(noescapeKeyCode)
  108. setTotalLengthToInterfaceOp(escapeKeyCode)
  109. interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
  110. interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
  111. codeLength := noescapeKeyCode.TotalLength()
  112. return &OpcodeSet{
  113. Type: typ,
  114. NoescapeKeyCode: noescapeKeyCode,
  115. EscapeKeyCode: escapeKeyCode,
  116. InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
  117. InterfaceEscapeKeyCode: interfaceEscapeKeyCode,
  118. CodeLength: codeLength,
  119. EndCode: ToEndCode(interfaceNoescapeKeyCode),
  120. Code: code,
  121. QueryCache: map[string]*OpcodeSet{},
  122. }, nil
  123. }
  124. func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
  125. switch {
  126. case c.implementsMarshalJSON(typ):
  127. return c.marshalJSONCode(typ)
  128. case c.implementsMarshalText(typ):
  129. return c.marshalTextCode(typ)
  130. }
  131. isPtr := false
  132. orgType := typ
  133. if typ.Kind() == reflect.Ptr {
  134. typ = typ.Elem()
  135. isPtr = true
  136. }
  137. switch {
  138. case c.implementsMarshalJSON(typ):
  139. return c.marshalJSONCode(orgType)
  140. case c.implementsMarshalText(typ):
  141. return c.marshalTextCode(orgType)
  142. }
  143. switch typ.Kind() {
  144. case reflect.Slice:
  145. elem := typ.Elem()
  146. if elem.Kind() == reflect.Uint8 {
  147. p := runtime.PtrTo(elem)
  148. if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
  149. return c.bytesCode(typ, isPtr)
  150. }
  151. }
  152. return c.sliceCode(typ)
  153. case reflect.Map:
  154. if isPtr {
  155. return c.ptrCode(runtime.PtrTo(typ))
  156. }
  157. return c.mapCode(typ)
  158. case reflect.Struct:
  159. return c.structCode(typ, isPtr)
  160. case reflect.Int:
  161. return c.intCode(typ, isPtr)
  162. case reflect.Int8:
  163. return c.int8Code(typ, isPtr)
  164. case reflect.Int16:
  165. return c.int16Code(typ, isPtr)
  166. case reflect.Int32:
  167. return c.int32Code(typ, isPtr)
  168. case reflect.Int64:
  169. return c.int64Code(typ, isPtr)
  170. case reflect.Uint, reflect.Uintptr:
  171. return c.uintCode(typ, isPtr)
  172. case reflect.Uint8:
  173. return c.uint8Code(typ, isPtr)
  174. case reflect.Uint16:
  175. return c.uint16Code(typ, isPtr)
  176. case reflect.Uint32:
  177. return c.uint32Code(typ, isPtr)
  178. case reflect.Uint64:
  179. return c.uint64Code(typ, isPtr)
  180. case reflect.Float32:
  181. return c.float32Code(typ, isPtr)
  182. case reflect.Float64:
  183. return c.float64Code(typ, isPtr)
  184. case reflect.String:
  185. return c.stringCode(typ, isPtr)
  186. case reflect.Bool:
  187. return c.boolCode(typ, isPtr)
  188. case reflect.Interface:
  189. return c.interfaceCode(typ, isPtr)
  190. default:
  191. if isPtr && typ.Implements(marshalTextType) {
  192. typ = orgType
  193. }
  194. return c.typeToCodeWithPtr(typ, isPtr)
  195. }
  196. }
  197. func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
  198. switch {
  199. case c.implementsMarshalJSON(typ):
  200. return c.marshalJSONCode(typ)
  201. case c.implementsMarshalText(typ):
  202. return c.marshalTextCode(typ)
  203. }
  204. switch typ.Kind() {
  205. case reflect.Ptr:
  206. return c.ptrCode(typ)
  207. case reflect.Slice:
  208. elem := typ.Elem()
  209. if elem.Kind() == reflect.Uint8 {
  210. p := runtime.PtrTo(elem)
  211. if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
  212. return c.bytesCode(typ, false)
  213. }
  214. }
  215. return c.sliceCode(typ)
  216. case reflect.Array:
  217. return c.arrayCode(typ)
  218. case reflect.Map:
  219. return c.mapCode(typ)
  220. case reflect.Struct:
  221. return c.structCode(typ, isPtr)
  222. case reflect.Interface:
  223. return c.interfaceCode(typ, false)
  224. case reflect.Int:
  225. return c.intCode(typ, false)
  226. case reflect.Int8:
  227. return c.int8Code(typ, false)
  228. case reflect.Int16:
  229. return c.int16Code(typ, false)
  230. case reflect.Int32:
  231. return c.int32Code(typ, false)
  232. case reflect.Int64:
  233. return c.int64Code(typ, false)
  234. case reflect.Uint:
  235. return c.uintCode(typ, false)
  236. case reflect.Uint8:
  237. return c.uint8Code(typ, false)
  238. case reflect.Uint16:
  239. return c.uint16Code(typ, false)
  240. case reflect.Uint32:
  241. return c.uint32Code(typ, false)
  242. case reflect.Uint64:
  243. return c.uint64Code(typ, false)
  244. case reflect.Uintptr:
  245. return c.uintCode(typ, false)
  246. case reflect.Float32:
  247. return c.float32Code(typ, false)
  248. case reflect.Float64:
  249. return c.float64Code(typ, false)
  250. case reflect.String:
  251. return c.stringCode(typ, false)
  252. case reflect.Bool:
  253. return c.boolCode(typ, false)
  254. }
  255. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  256. }
  // intSize is the bit width of int/uint on the target platform: 32 or 64.
  const intSize = 32 << (^uint(0) >> 32 & 1)
  258. //nolint:unparam
  259. func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  260. return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  261. }
  262. //nolint:unparam
  263. func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  264. return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  265. }
  266. //nolint:unparam
  267. func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  268. return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  269. }
  270. //nolint:unparam
  271. func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  272. return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  273. }
  274. //nolint:unparam
  275. func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  276. return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  277. }
  278. //nolint:unparam
  279. func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  280. return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  281. }
  282. //nolint:unparam
  283. func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  284. return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  285. }
  286. //nolint:unparam
  287. func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  288. return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  289. }
  290. //nolint:unparam
  291. func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  292. return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  293. }
  294. //nolint:unparam
  295. func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  296. return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  297. }
  298. //nolint:unparam
  299. func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  300. return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  301. }
  302. //nolint:unparam
  303. func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  304. return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  305. }
  306. //nolint:unparam
  307. func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) {
  308. return &StringCode{typ: typ, isPtr: isPtr}, nil
  309. }
  310. //nolint:unparam
  311. func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) {
  312. return &BoolCode{typ: typ, isPtr: isPtr}, nil
  313. }
  314. //nolint:unparam
  315. func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) {
  316. return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil
  317. }
  318. //nolint:unparam
  319. func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) {
  320. return &IntCode{typ: typ, bitSize: 8, isString: true}, nil
  321. }
  322. //nolint:unparam
  323. func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) {
  324. return &IntCode{typ: typ, bitSize: 16, isString: true}, nil
  325. }
  326. //nolint:unparam
  327. func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) {
  328. return &IntCode{typ: typ, bitSize: 32, isString: true}, nil
  329. }
  330. //nolint:unparam
  331. func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) {
  332. return &IntCode{typ: typ, bitSize: 64, isString: true}, nil
  333. }
  334. //nolint:unparam
  335. func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) {
  336. return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil
  337. }
  338. //nolint:unparam
  339. func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) {
  340. return &UintCode{typ: typ, bitSize: 8, isString: true}, nil
  341. }
  342. //nolint:unparam
  343. func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) {
  344. return &UintCode{typ: typ, bitSize: 16, isString: true}, nil
  345. }
  346. //nolint:unparam
  347. func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) {
  348. return &UintCode{typ: typ, bitSize: 32, isString: true}, nil
  349. }
  350. //nolint:unparam
  351. func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) {
  352. return &UintCode{typ: typ, bitSize: 64, isString: true}, nil
  353. }
  354. //nolint:unparam
  355. func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) {
  356. return &BytesCode{typ: typ, isPtr: isPtr}, nil
  357. }
  358. //nolint:unparam
  359. func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) {
  360. return &InterfaceCode{typ: typ, isPtr: isPtr}, nil
  361. }
  362. //nolint:unparam
  363. func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) {
  364. return &MarshalJSONCode{
  365. typ: typ,
  366. isAddrForMarshaler: c.isPtrMarshalJSONType(typ),
  367. isNilableType: c.isNilableType(typ),
  368. isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType),
  369. }, nil
  370. }
  371. //nolint:unparam
  372. func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) {
  373. return &MarshalTextCode{
  374. typ: typ,
  375. isAddrForMarshaler: c.isPtrMarshalTextType(typ),
  376. isNilableType: c.isNilableType(typ),
  377. }, nil
  378. }
  379. func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
  380. code, err := c.typeToCodeWithPtr(typ.Elem(), true)
  381. if err != nil {
  382. return nil, err
  383. }
  384. ptr, ok := code.(*PtrCode)
  385. if ok {
  386. return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil
  387. }
  388. return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil
  389. }
  390. func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) {
  391. elem := typ.Elem()
  392. code, err := c.listElemCode(elem)
  393. if err != nil {
  394. return nil, err
  395. }
  396. if code.Kind() == CodeKindStruct {
  397. structCode := code.(*StructCode)
  398. structCode.enableIndirect()
  399. }
  400. return &SliceCode{typ: typ, value: code}, nil
  401. }
  402. func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) {
  403. elem := typ.Elem()
  404. code, err := c.listElemCode(elem)
  405. if err != nil {
  406. return nil, err
  407. }
  408. if code.Kind() == CodeKindStruct {
  409. structCode := code.(*StructCode)
  410. structCode.enableIndirect()
  411. }
  412. return &ArrayCode{typ: typ, value: code}, nil
  413. }
  414. func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) {
  415. keyCode, err := c.mapKeyCode(typ.Key())
  416. if err != nil {
  417. return nil, err
  418. }
  419. valueCode, err := c.mapValueCode(typ.Elem())
  420. if err != nil {
  421. return nil, err
  422. }
  423. if valueCode.Kind() == CodeKindStruct {
  424. structCode := valueCode.(*StructCode)
  425. structCode.enableIndirect()
  426. }
  427. return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil
  428. }
  429. func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
  430. switch {
  431. case c.isPtrMarshalJSONType(typ):
  432. return c.marshalJSONCode(typ)
  433. case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
  434. return c.marshalTextCode(typ)
  435. case typ.Kind() == reflect.Map:
  436. return c.ptrCode(runtime.PtrTo(typ))
  437. default:
  438. code, err := c.typeToCodeWithPtr(typ, false)
  439. if err != nil {
  440. return nil, err
  441. }
  442. ptr, ok := code.(*PtrCode)
  443. if ok {
  444. if ptr.value.Kind() == CodeKindMap {
  445. ptr.ptrNum++
  446. }
  447. }
  448. return code, nil
  449. }
  450. }
  451. func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
  452. switch {
  453. case c.implementsMarshalJSON(typ):
  454. return c.marshalJSONCode(typ)
  455. case c.implementsMarshalText(typ):
  456. return c.marshalTextCode(typ)
  457. }
  458. switch typ.Kind() {
  459. case reflect.Ptr:
  460. return c.ptrCode(typ)
  461. case reflect.String:
  462. return c.stringCode(typ, false)
  463. case reflect.Int:
  464. return c.intStringCode(typ)
  465. case reflect.Int8:
  466. return c.int8StringCode(typ)
  467. case reflect.Int16:
  468. return c.int16StringCode(typ)
  469. case reflect.Int32:
  470. return c.int32StringCode(typ)
  471. case reflect.Int64:
  472. return c.int64StringCode(typ)
  473. case reflect.Uint:
  474. return c.uintStringCode(typ)
  475. case reflect.Uint8:
  476. return c.uint8StringCode(typ)
  477. case reflect.Uint16:
  478. return c.uint16StringCode(typ)
  479. case reflect.Uint32:
  480. return c.uint32StringCode(typ)
  481. case reflect.Uint64:
  482. return c.uint64StringCode(typ)
  483. case reflect.Uintptr:
  484. return c.uintStringCode(typ)
  485. }
  486. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  487. }
  488. func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
  489. switch typ.Kind() {
  490. case reflect.Map:
  491. return c.ptrCode(runtime.PtrTo(typ))
  492. default:
  493. code, err := c.typeToCodeWithPtr(typ, false)
  494. if err != nil {
  495. return nil, err
  496. }
  497. ptr, ok := code.(*PtrCode)
  498. if ok {
  499. if ptr.value.Kind() == CodeKindMap {
  500. ptr.ptrNum++
  501. }
  502. }
  503. return code, nil
  504. }
  505. }
  506. func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
  507. typeptr := uintptr(unsafe.Pointer(typ))
  508. if code, exists := c.structTypeToCode[typeptr]; exists {
  509. derefCode := *code
  510. derefCode.isRecursive = true
  511. return &derefCode, nil
  512. }
  513. indirect := runtime.IfaceIndir(typ)
  514. code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
  515. c.structTypeToCode[typeptr] = code
  516. fieldNum := typ.NumField()
  517. tags := c.typeToStructTags(typ)
  518. fields := []*StructFieldCode{}
  519. for i, tag := range tags {
  520. isOnlyOneFirstField := i == 0 && fieldNum == 1
  521. field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField)
  522. if err != nil {
  523. return nil, err
  524. }
  525. if field.isAnonymous {
  526. structCode := field.getAnonymousStruct()
  527. if structCode != nil {
  528. structCode.removeFieldsByTags(tags)
  529. if c.isAssignableIndirect(field, isPtr) {
  530. if indirect {
  531. structCode.isIndirect = true
  532. } else {
  533. structCode.isIndirect = false
  534. }
  535. }
  536. }
  537. } else {
  538. structCode := field.getStruct()
  539. if structCode != nil {
  540. if indirect {
  541. // if parent is indirect type, set child indirect property to true
  542. structCode.isIndirect = true
  543. } else {
  544. // if parent is not indirect type, set child indirect property to false.
  545. // but if parent's indirect is false and isPtr is true, then indirect must be true.
  546. // Do this only if indirectConversion is enabled at the end of compileStruct.
  547. structCode.isIndirect = false
  548. }
  549. }
  550. }
  551. fields = append(fields, field)
  552. }
  553. fieldMap := c.getFieldMap(fields)
  554. duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap)
  555. code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap)
  556. if !code.disableIndirectConversion && !indirect && isPtr {
  557. code.enableIndirect()
  558. }
  559. delete(c.structTypeToCode, typeptr)
  560. return code, nil
  561. }
  562. func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) {
  563. field := tag.Field
  564. fieldType := runtime.Type2RType(field.Type)
  565. isIndirectSpecialCase := isPtr && isOnlyOneFirstField
  566. fieldCode := &StructFieldCode{
  567. typ: fieldType,
  568. key: tag.Key,
  569. tag: tag,
  570. offset: field.Offset,
  571. isAnonymous: field.Anonymous && !tag.IsTaggedKey,
  572. isTaggedKey: tag.IsTaggedKey,
  573. isNilableType: c.isNilableType(fieldType),
  574. isNilCheck: true,
  575. }
  576. switch {
  577. case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase):
  578. code, err := c.marshalJSONCode(fieldType)
  579. if err != nil {
  580. return nil, err
  581. }
  582. fieldCode.value = code
  583. fieldCode.isAddrForMarshaler = true
  584. fieldCode.isNilCheck = false
  585. structCode.isIndirect = false
  586. structCode.disableIndirectConversion = true
  587. case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase):
  588. code, err := c.marshalTextCode(fieldType)
  589. if err != nil {
  590. return nil, err
  591. }
  592. fieldCode.value = code
  593. fieldCode.isAddrForMarshaler = true
  594. fieldCode.isNilCheck = false
  595. structCode.isIndirect = false
  596. structCode.disableIndirectConversion = true
  597. case isPtr && c.isPtrMarshalJSONType(fieldType):
  598. // *struct{ field T }
  599. // func (*T) MarshalJSON() ([]byte, error)
  600. code, err := c.marshalJSONCode(fieldType)
  601. if err != nil {
  602. return nil, err
  603. }
  604. fieldCode.value = code
  605. fieldCode.isAddrForMarshaler = true
  606. fieldCode.isNilCheck = false
  607. case isPtr && c.isPtrMarshalTextType(fieldType):
  608. // *struct{ field T }
  609. // func (*T) MarshalText() ([]byte, error)
  610. code, err := c.marshalTextCode(fieldType)
  611. if err != nil {
  612. return nil, err
  613. }
  614. fieldCode.value = code
  615. fieldCode.isAddrForMarshaler = true
  616. fieldCode.isNilCheck = false
  617. default:
  618. code, err := c.typeToCodeWithPtr(fieldType, isPtr)
  619. if err != nil {
  620. return nil, err
  621. }
  622. switch code.Kind() {
  623. case CodeKindPtr, CodeKindInterface:
  624. fieldCode.isNextOpPtrType = true
  625. }
  626. fieldCode.value = code
  627. }
  628. return fieldCode, nil
  629. }
  630. func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool {
  631. if isPtr {
  632. return false
  633. }
  634. codeType := fieldCode.value.Kind()
  635. if codeType == CodeKindMarshalJSON {
  636. return false
  637. }
  638. if codeType == CodeKindMarshalText {
  639. return false
  640. }
  641. return true
  642. }
  643. func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode {
  644. fieldMap := map[string][]*StructFieldCode{}
  645. for _, field := range fields {
  646. if field.isAnonymous {
  647. for k, v := range c.getAnonymousFieldMap(field) {
  648. fieldMap[k] = append(fieldMap[k], v...)
  649. }
  650. continue
  651. }
  652. fieldMap[field.key] = append(fieldMap[field.key], field)
  653. }
  654. return fieldMap
  655. }
  656. func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode {
  657. fieldMap := map[string][]*StructFieldCode{}
  658. structCode := field.getAnonymousStruct()
  659. if structCode == nil || structCode.isRecursive {
  660. fieldMap[field.key] = append(fieldMap[field.key], field)
  661. return fieldMap
  662. }
  663. for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) {
  664. fieldMap[k] = append(fieldMap[k], v...)
  665. }
  666. return fieldMap
  667. }
  668. func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode {
  669. fieldMap := map[string][]*StructFieldCode{}
  670. for _, field := range fields {
  671. if field.isAnonymous {
  672. for k, v := range c.getAnonymousFieldMap(field) {
  673. // Do not handle tagged key when embedding more than once
  674. for _, vv := range v {
  675. vv.isTaggedKey = false
  676. }
  677. fieldMap[k] = append(fieldMap[k], v...)
  678. }
  679. continue
  680. }
  681. fieldMap[field.key] = append(fieldMap[field.key], field)
  682. }
  683. return fieldMap
  684. }
  685. func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} {
  686. duplicatedFieldMap := map[*StructFieldCode]struct{}{}
  687. for _, fields := range fieldMap {
  688. if len(fields) == 1 {
  689. continue
  690. }
  691. if c.isTaggedKeyOnly(fields) {
  692. for _, field := range fields {
  693. if field.isTaggedKey {
  694. continue
  695. }
  696. duplicatedFieldMap[field] = struct{}{}
  697. }
  698. } else {
  699. for _, field := range fields {
  700. duplicatedFieldMap[field] = struct{}{}
  701. }
  702. }
  703. }
  704. return duplicatedFieldMap
  705. }
  706. func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode {
  707. filteredFields := make([]*StructFieldCode, 0, len(fields))
  708. for _, field := range fields {
  709. if field.isAnonymous {
  710. structCode := field.getAnonymousStruct()
  711. if structCode != nil && !structCode.isRecursive {
  712. structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap)
  713. if len(structCode.fields) > 0 {
  714. filteredFields = append(filteredFields, field)
  715. }
  716. continue
  717. }
  718. }
  719. if _, exists := duplicatedFieldMap[field]; exists {
  720. continue
  721. }
  722. filteredFields = append(filteredFields, field)
  723. }
  724. return filteredFields
  725. }
  726. func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool {
  727. var taggedKeyFieldCount int
  728. for _, field := range fields {
  729. if field.isTaggedKey {
  730. taggedKeyFieldCount++
  731. }
  732. }
  733. return taggedKeyFieldCount == 1
  734. }
  735. func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags {
  736. tags := runtime.StructTags{}
  737. fieldNum := typ.NumField()
  738. for i := 0; i < fieldNum; i++ {
  739. field := typ.Field(i)
  740. if runtime.IsIgnoredStructField(field) {
  741. continue
  742. }
  743. tags = append(tags, runtime.StructTagFromField(field))
  744. }
  745. return tags
  746. }
  747. // *struct{ field T } => struct { field *T }
  748. // func (*T) MarshalJSON() ([]byte, error)
  749. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  750. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ)
  751. }
  752. // *struct{ field T } => struct { field *T }
  753. // func (*T) MarshalText() ([]byte, error)
  754. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  755. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ)
  756. }
  757. func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool {
  758. if !c.implementsMarshalJSONType(typ) {
  759. return false
  760. }
  761. if typ.Kind() != reflect.Ptr {
  762. return true
  763. }
  764. // type kind is reflect.Ptr
  765. if !c.implementsMarshalJSONType(typ.Elem()) {
  766. return true
  767. }
  768. // needs to dereference
  769. return false
  770. }
  771. func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool {
  772. if !typ.Implements(marshalTextType) {
  773. return false
  774. }
  775. if typ.Kind() != reflect.Ptr {
  776. return true
  777. }
  778. // type kind is reflect.Ptr
  779. if !typ.Elem().Implements(marshalTextType) {
  780. return true
  781. }
  782. // needs to dereference
  783. return false
  784. }
  785. func (c *Compiler) isNilableType(typ *runtime.Type) bool {
  786. if !runtime.IfaceIndir(typ) {
  787. return true
  788. }
  789. switch typ.Kind() {
  790. case reflect.Ptr:
  791. return true
  792. case reflect.Map:
  793. return true
  794. case reflect.Func:
  795. return true
  796. default:
  797. return false
  798. }
  799. }
  800. func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool {
  801. return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
  802. }
  803. func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool {
  804. return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ))
  805. }
  806. func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool {
  807. return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
  808. }
  809. func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode {
  810. codes := code.ToOpcode(ctx)
  811. codes.Last().Next = newEndOp(ctx, typ)
  812. c.linkRecursiveCode(ctx)
  813. return codes.First()
  814. }
  815. func (c *Compiler) linkRecursiveCode(ctx *compileContext) {
  816. recursiveCodes := map[uintptr]*CompiledCode{}
  817. for _, recursive := range *ctx.recursiveCodes {
  818. typeptr := uintptr(unsafe.Pointer(recursive.Type))
  819. codes := ctx.structTypeToCodes[typeptr]
  820. if recursiveCode, ok := recursiveCodes[typeptr]; ok {
  821. *recursive.Jmp = *recursiveCode
  822. continue
  823. }
  824. code := copyOpcode(codes.First())
  825. code.Op = code.Op.PtrHeadToHead()
  826. lastCode := newEndOp(&compileContext{}, recursive.Type)
  827. lastCode.Op = OpRecursiveEnd
  828. // OpRecursiveEnd must set before call TotalLength
  829. code.End.Next = lastCode
  830. totalLength := code.TotalLength()
  831. // Idx, ElemIdx, Length must set after call TotalLength
  832. lastCode.Idx = uint32((totalLength + 1) * uintptrSize)
  833. lastCode.ElemIdx = lastCode.Idx + uintptrSize
  834. lastCode.Length = lastCode.Idx + 2*uintptrSize
  835. // extend length to alloc slot for elemIdx + length
  836. curTotalLength := uintptr(recursive.TotalLength()) + 3
  837. nextTotalLength := uintptr(totalLength) + 3
  838. compiled := recursive.Jmp
  839. compiled.Code = code
  840. compiled.CurLen = curTotalLength
  841. compiled.NextLen = nextTotalLength
  842. compiled.Linked = true
  843. recursiveCodes[typeptr] = compiled
  844. }
  845. }