client.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package thrift
  2. import (
  3. "context"
  4. "fmt"
  5. )
  6. // ResponseMeta represents the metadata attached to the response.
  7. type ResponseMeta struct {
  8. // The headers in the response, if any.
  9. // If the underlying transport/protocol is not THeader, this will always be nil.
  10. Headers THeaderMap
  11. }
  12. type TClient interface {
  13. Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error)
  14. }
  15. type TStandardClient struct {
  16. seqId int32
  17. iprot, oprot TProtocol
  18. }
  19. // TStandardClient implements TClient, and uses the standard message format for Thrift.
  20. // It is not safe for concurrent use.
  21. func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
  22. return &TStandardClient{
  23. iprot: inputProtocol,
  24. oprot: outputProtocol,
  25. }
  26. }
  // Send writes one CALL message for method, tagged with seqId and carrying
  // args, to oprot and then flushes oprot.
  //
  // When oprot is a *THeaderProtocol, its write headers are first cleared and
  // then repopulated from the headers recorded on ctx, so headers from an
  // earlier call are not carried over.
  func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
  	if hp, isHeader := oprot.(*THeaderProtocol); isHeader {
  		hp.ClearWriteHeaders()
  		keys := GetWriteHeaderList(ctx)
  		for i := range keys {
  			v, found := GetHeader(ctx, keys[i])
  			if !found {
  				continue
  			}
  			hp.SetWriteHeader(keys[i], v)
  		}
  	}
  
  	// Frame the request: message begin, the argument struct, message end.
  	// Each step runs only if the previous one succeeded.
  	err := oprot.WriteMessageBegin(ctx, method, CALL, seqId)
  	if err == nil {
  		err = args.Write(ctx, oprot)
  	}
  	if err == nil {
  		err = oprot.WriteMessageEnd(ctx)
  	}
  	if err != nil {
  		return err
  	}
  	return oprot.Flush(ctx)
  }
  // Recv reads one response message from iprot and decodes it into result.
  //
  // The response is rejected with a TApplicationException when its method
  // name differs from method (WRONG_METHOD_NAME), when its sequence id
  // differs from seqId (BAD_SEQUENCE_ID), or when it is neither a REPLY nor
  // an EXCEPTION message (INVALID_MESSAGE_TYPE_EXCEPTION). An EXCEPTION
  // message is decoded and returned as the error. Any read error from iprot is
  // returned unchanged.
  func (p *TStandardClient) Recv(ctx context.Context, iprot TProtocol, seqId int32, method string, result TStruct) error {
  	gotMethod, gotType, gotSeq, err := iprot.ReadMessageBegin(ctx)
  	if err != nil {
  		return err
  	}
  
  	// Cases are evaluated top to bottom, so identity checks take precedence
  	// over the message-type checks.
  	switch {
  	case gotMethod != method:
  		return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
  	case gotSeq != seqId:
  		return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
  	case gotType == EXCEPTION:
  		// The server sent an application exception instead of a reply:
  		// decode it, finish the message, and surface it as the error.
  		var appErr tApplicationException
  		if err := appErr.Read(ctx, iprot); err != nil {
  			return err
  		}
  		if err := iprot.ReadMessageEnd(ctx); err != nil {
  			return err
  		}
  		return &appErr
  	case gotType != REPLY:
  		return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
  	}
  
  	if err := result.Read(ctx, iprot); err != nil {
  		return err
  	}
  	return iprot.ReadMessageEnd(ctx)
  }
  // Call advances the client's sequence id, sends the request for method with
  // args over p.oprot, and, unless result is nil (a oneway method), reads the
  // reply from p.iprot into result.
  //
  // When p.iprot is a *THeaderProtocol, the returned ResponseMeta carries the
  // headers read from its transport, even if receiving the reply failed.
  func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) (ResponseMeta, error) {
  	p.seqId++
  	id := p.seqId
  
  	if err := p.Send(ctx, p.oprot, id, method, args); err != nil {
  		return ResponseMeta{}, err
  	}
  	// A nil result marks a oneway method: there is no reply to wait for.
  	if result == nil {
  		return ResponseMeta{}, nil
  	}
  
  	var meta ResponseMeta
  	recvErr := p.Recv(ctx, p.iprot, id, method, result)
  	if hp, ok := p.iprot.(*THeaderProtocol); ok {
  		meta.Headers = hp.transport.readHeaders
  	}
  	return meta, recvErr
  }