123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
- package thrift
- import (
- "errors"
- "fmt"
- "io"
- "sync"
- "sync/atomic"
- "time"
- )
var (
	// ErrAbandonRequest is a sentinel error that server handler
	// implementations may return to signal that the request was abandoned.
	//
	// TSimpleServer checks for this error and, when it sees it, closes the
	// client connection rather than writing a response or an error back to
	// the client.
	//
	// Handlers should return it only when they know the client has already
	// given up on the request (for example, after observing that the
	// passed-in context is already canceled).
	ErrAbandonRequest = errors.New("request abandoned")

	// ServerConnectivityCheckInterval is the ticker interval used by the
	// connectivity check in thrift compiled TProcessorFunc implementations.
	//
	// It is a variable rather than a constant so that thrift server
	// implementations can change its value to control the behavior.
	//
	// Setting it to a value <= 0 disables the feature.
	ServerConnectivityCheckInterval = 5 * time.Millisecond
)
- /*
- * This is not a typical TSimpleServer as it is not blocked after accept a socket.
- * It is more like a TThreadedServer that can handle different connections in different goroutines.
- * This will work if golang user implements a conn-pool like thing in client side.
- */
- type TSimpleServer struct {
- closed int32
- wg sync.WaitGroup
- mu sync.Mutex
- processorFactory TProcessorFactory
- serverTransport TServerTransport
- inputTransportFactory TTransportFactory
- outputTransportFactory TTransportFactory
- inputProtocolFactory TProtocolFactory
- outputProtocolFactory TProtocolFactory
- // Headers to auto forward in THeaderProtocol
- forwardHeaders []string
- logger Logger
- }
- func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
- return NewTSimpleServerFactory2(NewTProcessorFactory(processor), serverTransport)
- }
- func NewTSimpleServer4(processor TProcessor, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
- return NewTSimpleServerFactory4(NewTProcessorFactory(processor),
- serverTransport,
- transportFactory,
- protocolFactory,
- )
- }
- func NewTSimpleServer6(processor TProcessor, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
- return NewTSimpleServerFactory6(NewTProcessorFactory(processor),
- serverTransport,
- inputTransportFactory,
- outputTransportFactory,
- inputProtocolFactory,
- outputProtocolFactory,
- )
- }
- func NewTSimpleServerFactory2(processorFactory TProcessorFactory, serverTransport TServerTransport) *TSimpleServer {
- return NewTSimpleServerFactory6(processorFactory,
- serverTransport,
- NewTTransportFactory(),
- NewTTransportFactory(),
- NewTBinaryProtocolFactoryDefault(),
- NewTBinaryProtocolFactoryDefault(),
- )
- }
- func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTransport TServerTransport, transportFactory TTransportFactory, protocolFactory TProtocolFactory) *TSimpleServer {
- return NewTSimpleServerFactory6(processorFactory,
- serverTransport,
- transportFactory,
- transportFactory,
- protocolFactory,
- protocolFactory,
- )
- }
- func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
- return &TSimpleServer{
- processorFactory: processorFactory,
- serverTransport: serverTransport,
- inputTransportFactory: inputTransportFactory,
- outputTransportFactory: outputTransportFactory,
- inputProtocolFactory: inputProtocolFactory,
- outputProtocolFactory: outputProtocolFactory,
- }
- }
- func (p *TSimpleServer) ProcessorFactory() TProcessorFactory {
- return p.processorFactory
- }
- func (p *TSimpleServer) ServerTransport() TServerTransport {
- return p.serverTransport
- }
- func (p *TSimpleServer) InputTransportFactory() TTransportFactory {
- return p.inputTransportFactory
- }
- func (p *TSimpleServer) OutputTransportFactory() TTransportFactory {
- return p.outputTransportFactory
- }
- func (p *TSimpleServer) InputProtocolFactory() TProtocolFactory {
- return p.inputProtocolFactory
- }
- func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory {
- return p.outputProtocolFactory
- }
- func (p *TSimpleServer) Listen() error {
- return p.serverTransport.Listen()
- }
- // SetForwardHeaders sets the list of header keys that will be auto forwarded
- // while using THeaderProtocol.
- //
- // "forward" means that when the server is also a client to other upstream
- // thrift servers, the context object user gets in the processor functions will
- // have both read and write headers set, with write headers being forwarded.
- // Users can always override the write headers by calling SetWriteHeaderList
- // before calling thrift client functions.
- func (p *TSimpleServer) SetForwardHeaders(headers []string) {
- size := len(headers)
- if size == 0 {
- p.forwardHeaders = nil
- return
- }
- keys := make([]string, size)
- copy(keys, headers)
- p.forwardHeaders = keys
- }
- // SetLogger sets the logger used by this TSimpleServer.
- //
- // If no logger was set before Serve is called, a default logger using standard
- // log library will be used.
- func (p *TSimpleServer) SetLogger(logger Logger) {
- p.logger = logger
- }
- func (p *TSimpleServer) innerAccept() (int32, error) {
- client, err := p.serverTransport.Accept()
- p.mu.Lock()
- defer p.mu.Unlock()
- closed := atomic.LoadInt32(&p.closed)
- if closed != 0 {
- return closed, nil
- }
- if err != nil {
- return 0, err
- }
- if client != nil {
- p.wg.Add(1)
- go func() {
- defer p.wg.Done()
- if err := p.processRequests(client); err != nil {
- p.logger(fmt.Sprintf("error processing request: %v", err))
- }
- }()
- }
- return 0, nil
- }
- func (p *TSimpleServer) AcceptLoop() error {
- for {
- closed, err := p.innerAccept()
- if err != nil {
- return err
- }
- if closed != 0 {
- return nil
- }
- }
- }
- func (p *TSimpleServer) Serve() error {
- p.logger = fallbackLogger(p.logger)
- err := p.Listen()
- if err != nil {
- return err
- }
- p.AcceptLoop()
- return nil
- }
- func (p *TSimpleServer) Stop() error {
- p.mu.Lock()
- defer p.mu.Unlock()
- if atomic.LoadInt32(&p.closed) != 0 {
- return nil
- }
- atomic.StoreInt32(&p.closed, 1)
- p.serverTransport.Interrupt()
- p.wg.Wait()
- return nil
- }
- // If err is actually EOF, return nil, otherwise return err as-is.
- func treatEOFErrorsAsNil(err error) error {
- if err == nil {
- return nil
- }
- if errors.Is(err, io.EOF) {
- return nil
- }
- var te TTransportException
- if errors.As(err, &te) && te.TypeId() == END_OF_FILE {
- return nil
- }
- return err
- }
- func (p *TSimpleServer) processRequests(client TTransport) (err error) {
- defer func() {
- err = treatEOFErrorsAsNil(err)
- }()
- processor := p.processorFactory.GetProcessor(client)
- inputTransport, err := p.inputTransportFactory.GetTransport(client)
- if err != nil {
- return err
- }
- inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
- var outputTransport TTransport
- var outputProtocol TProtocol
- // for THeaderProtocol, we must use the same protocol instance for
- // input and output so that the response is in the same dialect that
- // the server detected the request was in.
- headerProtocol, ok := inputProtocol.(*THeaderProtocol)
- if ok {
- outputProtocol = inputProtocol
- } else {
- oTrans, err := p.outputTransportFactory.GetTransport(client)
- if err != nil {
- return err
- }
- outputTransport = oTrans
- outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)
- }
- if inputTransport != nil {
- defer inputTransport.Close()
- }
- if outputTransport != nil {
- defer outputTransport.Close()
- }
- for {
- if atomic.LoadInt32(&p.closed) != 0 {
- return nil
- }
- ctx := SetResponseHelper(
- defaultCtx,
- TResponseHelper{
- THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol),
- },
- )
- if headerProtocol != nil {
- // We need to call ReadFrame here, otherwise we won't
- // get any headers on the AddReadTHeaderToContext call.
- //
- // ReadFrame is safe to be called multiple times so it
- // won't break when it's called again later when we
- // actually start to read the message.
- if err := headerProtocol.ReadFrame(ctx); err != nil {
- return err
- }
- ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders())
- ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
- }
- ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
- if errors.Is(err, ErrAbandonRequest) {
- return client.Close()
- }
- if errors.As(err, new(TTransportException)) && err != nil {
- return err
- }
- var tae TApplicationException
- if errors.As(err, &tae) && tae.TypeId() == UNKNOWN_METHOD {
- continue
- }
- if !ok {
- break
- }
- }
- return nil
- }
|