endless.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. package endless
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/http"
  9. "os"
  10. "os/exec"
  11. "os/signal"
  12. "runtime"
  13. "strings"
  14. "sync"
  15. "syscall"
  16. "time"
  17. // "github.com/fvbock/uds-go/introspect"
  18. )
  // Hook phases for RegisterSignalHook: PRE_SIGNAL hooks run before endless's
  // own handling of a signal, POST_SIGNAL hooks run after it.
  const (
  	PRE_SIGNAL  = 0
  	POST_SIGNAL = 1
  )

  // Server lifecycle states kept in endlessServer.state. A server moves
  // STATE_INIT (NewServer) -> STATE_RUNNING (Serve) -> STATE_SHUTTING_DOWN
  // (shutdown) -> STATE_TERMINATE (Serve, once all connections are done).
  // They are numbered 2-5, continuing after the hook phases.
  const (
  	STATE_INIT          = 2
  	STATE_RUNNING       = 3
  	STATE_SHUTTING_DOWN = 4
  	STATE_TERMINATE     = 5
  )
  27. var (
  28. runningServerReg sync.RWMutex
  29. runningServers map[string]*endlessServer
  30. runningServersOrder []string
  31. socketPtrOffsetMap map[string]uint
  32. runningServersForked bool
  33. DefaultReadTimeOut time.Duration
  34. DefaultWriteTimeOut time.Duration
  35. DefaultMaxHeaderBytes int
  36. DefaultHammerTime time.Duration
  37. isChild bool
  38. socketOrder string
  39. hookableSignals []os.Signal
  40. )
  41. func init() {
  42. runningServerReg = sync.RWMutex{}
  43. runningServers = make(map[string]*endlessServer)
  44. runningServersOrder = []string{}
  45. socketPtrOffsetMap = make(map[string]uint)
  46. DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
  47. // after a restart the parent will finish ongoing requests before
  48. // shutting down. set to a negative value to disable
  49. DefaultHammerTime = 60 * time.Second
  50. hookableSignals = []os.Signal{
  51. syscall.SIGHUP,
  52. syscall.SIGUSR1,
  53. syscall.SIGUSR2,
  54. syscall.SIGINT,
  55. syscall.SIGTERM,
  56. syscall.SIGTSTP,
  57. }
  58. }
  59. type endlessServer struct {
  60. http.Server
  61. EndlessListener net.Listener
  62. SignalHooks map[int]map[os.Signal][]func()
  63. tlsInnerListener *endlessListener
  64. wg sync.WaitGroup
  65. sigChan chan os.Signal
  66. isChild bool
  67. state uint8
  68. lock *sync.RWMutex
  69. BeforeBegin func(add string)
  70. }
  71. /*
  72. NewServer returns an intialized endlessServer Object. Calling Serve on it will
  73. actually "start" the server.
  74. */
  75. func NewServer(addr string, handler http.Handler) (srv *endlessServer) {
  76. runningServerReg.Lock()
  77. defer runningServerReg.Unlock()
  78. socketOrder = os.Getenv("ENDLESS_SOCKET_ORDER")
  79. isChild = os.Getenv("ENDLESS_CONTINUE") != ""
  80. if len(socketOrder) > 0 {
  81. for i, addr := range strings.Split(socketOrder, ",") {
  82. socketPtrOffsetMap[addr] = uint(i)
  83. }
  84. } else {
  85. socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
  86. }
  87. srv = &endlessServer{
  88. wg: sync.WaitGroup{},
  89. sigChan: make(chan os.Signal),
  90. isChild: isChild,
  91. SignalHooks: map[int]map[os.Signal][]func(){
  92. PRE_SIGNAL: map[os.Signal][]func(){
  93. syscall.SIGHUP: []func(){},
  94. syscall.SIGUSR1: []func(){},
  95. syscall.SIGUSR2: []func(){},
  96. syscall.SIGINT: []func(){},
  97. syscall.SIGTERM: []func(){},
  98. syscall.SIGTSTP: []func(){},
  99. },
  100. POST_SIGNAL: map[os.Signal][]func(){
  101. syscall.SIGHUP: []func(){},
  102. syscall.SIGUSR1: []func(){},
  103. syscall.SIGUSR2: []func(){},
  104. syscall.SIGINT: []func(){},
  105. syscall.SIGTERM: []func(){},
  106. syscall.SIGTSTP: []func(){},
  107. },
  108. },
  109. state: STATE_INIT,
  110. lock: &sync.RWMutex{},
  111. }
  112. srv.Server.Addr = addr
  113. srv.Server.ReadTimeout = DefaultReadTimeOut
  114. srv.Server.WriteTimeout = DefaultWriteTimeOut
  115. srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
  116. srv.Server.Handler = handler
  117. srv.BeforeBegin = func(addr string) {
  118. log.Println(syscall.Getpid(), addr)
  119. }
  120. runningServersOrder = append(runningServersOrder, addr)
  121. runningServers[addr] = srv
  122. return
  123. }
  124. /*
  125. ListenAndServe listens on the TCP network address addr and then calls Serve
  126. with handler to handle requests on incoming connections. Handler is typically
  127. nil, in which case the DefaultServeMux is used.
  128. */
  129. func ListenAndServe(addr string, handler http.Handler) error {
  130. server := NewServer(addr, handler)
  131. return server.ListenAndServe()
  132. }
  133. /*
  134. ListenAndServeTLS acts identically to ListenAndServe, except that it expects
  135. HTTPS connections. Additionally, files containing a certificate and matching
  136. private key for the server must be provided. If the certificate is signed by a
  137. certificate authority, the certFile should be the concatenation of the server's
  138. certificate followed by the CA's certificate.
  139. */
  140. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
  141. server := NewServer(addr, handler)
  142. return server.ListenAndServeTLS(certFile, keyFile)
  143. }
  144. func (srv *endlessServer) getState() uint8 {
  145. srv.lock.RLock()
  146. defer srv.lock.RUnlock()
  147. return srv.state
  148. }
  149. func (srv *endlessServer) setState(st uint8) {
  150. srv.lock.Lock()
  151. defer srv.lock.Unlock()
  152. srv.state = st
  153. }
  154. /*
  155. Serve accepts incoming HTTP connections on the listener l, creating a new
  156. service goroutine for each. The service goroutines read requests and then call
  157. handler to reply to them. Handler is typically nil, in which case the
  158. DefaultServeMux is used.
  159. In addition to the stl Serve behaviour each connection is added to a
  160. sync.Waitgroup so that all outstanding connections can be served before shutting
  161. down the server.
  162. */
  163. func (srv *endlessServer) Serve() (err error) {
  164. defer log.Println(syscall.Getpid(), "Serve() returning...")
  165. srv.setState(STATE_RUNNING)
  166. err = srv.Server.Serve(srv.EndlessListener)
  167. log.Println(syscall.Getpid(), "Waiting for connections to finish...")
  168. srv.wg.Wait()
  169. srv.setState(STATE_TERMINATE)
  170. return
  171. }
  172. /*
  173. ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
  174. to handle requests on incoming connections. If srv.Addr is blank, ":http" is
  175. used.
  176. */
  177. func (srv *endlessServer) ListenAndServe() (err error) {
  178. addr := srv.Addr
  179. if addr == "" {
  180. addr = ":http"
  181. }
  182. go srv.handleSignals()
  183. l, err := srv.getListener(addr)
  184. if err != nil {
  185. log.Println(err)
  186. return
  187. }
  188. srv.EndlessListener = newEndlessListener(l, srv)
  189. if srv.isChild {
  190. syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  191. }
  192. srv.BeforeBegin(srv.Addr)
  193. return srv.Serve()
  194. }
  195. /*
  196. ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
  197. Serve to handle requests on incoming TLS connections.
  198. Filenames containing a certificate and matching private key for the server must
  199. be provided. If the certificate is signed by a certificate authority, the
  200. certFile should be the concatenation of the server's certificate followed by the
  201. CA's certificate.
  202. If srv.Addr is blank, ":https" is used.
  203. */
  204. func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
  205. addr := srv.Addr
  206. if addr == "" {
  207. addr = ":https"
  208. }
  209. config := &tls.Config{}
  210. if srv.TLSConfig != nil {
  211. *config = *srv.TLSConfig
  212. }
  213. if config.NextProtos == nil {
  214. config.NextProtos = []string{"http/1.1"}
  215. }
  216. config.Certificates = make([]tls.Certificate, 1)
  217. config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  218. if err != nil {
  219. return
  220. }
  221. go srv.handleSignals()
  222. l, err := srv.getListener(addr)
  223. if err != nil {
  224. log.Println(err)
  225. return
  226. }
  227. srv.tlsInnerListener = newEndlessListener(l, srv)
  228. srv.EndlessListener = tls.NewListener(srv.tlsInnerListener, config)
  229. if srv.isChild {
  230. syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  231. }
  232. log.Println(syscall.Getpid(), srv.Addr)
  233. return srv.Serve()
  234. }
  235. /*
  236. getListener either opens a new socket to listen on, or takes the acceptor socket
  237. it got passed when restarted.
  238. */
  239. func (srv *endlessServer) getListener(laddr string) (l net.Listener, err error) {
  240. if srv.isChild {
  241. var ptrOffset uint = 0
  242. runningServerReg.RLock()
  243. defer runningServerReg.RUnlock()
  244. if len(socketPtrOffsetMap) > 0 {
  245. ptrOffset = socketPtrOffsetMap[laddr]
  246. // log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
  247. }
  248. f := os.NewFile(uintptr(3+ptrOffset), "")
  249. l, err = net.FileListener(f)
  250. if err != nil {
  251. err = fmt.Errorf("net.FileListener error: %v", err)
  252. return
  253. }
  254. } else {
  255. l, err = net.Listen("tcp", laddr)
  256. if err != nil {
  257. err = fmt.Errorf("net.Listen error: %v", err)
  258. return
  259. }
  260. }
  261. return
  262. }
  263. /*
  264. handleSignals listens for os Signals and calls any hooked in function that the
  265. user had registered with the signal.
  266. */
  267. func (srv *endlessServer) handleSignals() {
  268. var sig os.Signal
  269. signal.Notify(
  270. srv.sigChan,
  271. hookableSignals...,
  272. )
  273. pid := syscall.Getpid()
  274. for {
  275. sig = <-srv.sigChan
  276. srv.signalHooks(PRE_SIGNAL, sig)
  277. switch sig {
  278. case syscall.SIGHUP:
  279. log.Println(pid, "Received SIGHUP. forking.")
  280. err := srv.fork()
  281. if err != nil {
  282. log.Println("Fork err:", err)
  283. }
  284. case syscall.SIGUSR1:
  285. log.Println(pid, "Received SIGUSR1.")
  286. case syscall.SIGUSR2:
  287. log.Println(pid, "Received SIGUSR2.")
  288. srv.hammerTime(0 * time.Second)
  289. case syscall.SIGINT:
  290. log.Println(pid, "Received SIGINT.")
  291. srv.shutdown()
  292. case syscall.SIGTERM:
  293. log.Println(pid, "Received SIGTERM.")
  294. srv.shutdown()
  295. case syscall.SIGTSTP:
  296. log.Println(pid, "Received SIGTSTP.")
  297. default:
  298. log.Printf("Received %v: nothing i care about...\n", sig)
  299. }
  300. srv.signalHooks(POST_SIGNAL, sig)
  301. }
  302. }
  303. func (srv *endlessServer) signalHooks(ppFlag int, sig os.Signal) {
  304. if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
  305. return
  306. }
  307. for _, f := range srv.SignalHooks[ppFlag][sig] {
  308. f()
  309. }
  310. return
  311. }
  312. /*
  313. shutdown closes the listener so that no new connections are accepted. it also
  314. starts a goroutine that will hammer (stop all running requests) the server
  315. after DefaultHammerTime.
  316. */
  317. func (srv *endlessServer) shutdown() {
  318. if srv.getState() != STATE_RUNNING {
  319. return
  320. }
  321. srv.setState(STATE_SHUTTING_DOWN)
  322. if DefaultHammerTime >= 0 {
  323. go srv.hammerTime(DefaultHammerTime)
  324. }
  325. // disable keep-alives on existing connections
  326. srv.SetKeepAlivesEnabled(false)
  327. err := srv.EndlessListener.Close()
  328. if err != nil {
  329. log.Println(syscall.Getpid(), "Listener.Close() error:", err)
  330. } else {
  331. log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener closed.")
  332. }
  333. }
  334. /*
  335. hammerTime forces the server to shutdown in a given timeout - whether it
  336. finished outstanding requests or not. if Read/WriteTimeout are not set or the
  337. max header size is very big a connection could hang...
  338. srv.Serve() will not return until all connections are served. this will
  339. unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to
  340. return.
  341. */
  342. func (srv *endlessServer) hammerTime(d time.Duration) {
  343. defer func() {
  344. // we are calling srv.wg.Done() until it panics which means we called
  345. // Done() when the counter was already at 0 and we're done.
  346. // (and thus Serve() will return and the parent will exit)
  347. if r := recover(); r != nil {
  348. log.Println("WaitGroup at 0", r)
  349. }
  350. }()
  351. if srv.getState() != STATE_SHUTTING_DOWN {
  352. return
  353. }
  354. time.Sleep(d)
  355. log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
  356. for {
  357. if srv.getState() == STATE_TERMINATE {
  358. break
  359. }
  360. srv.wg.Done()
  361. runtime.Gosched()
  362. }
  363. }
  364. func (srv *endlessServer) fork() (err error) {
  365. runningServerReg.Lock()
  366. defer runningServerReg.Unlock()
  367. // only one server instance should fork!
  368. if runningServersForked {
  369. return errors.New("Another process already forked. Ignoring this one.")
  370. }
  371. runningServersForked = true
  372. var files = make([]*os.File, len(runningServers))
  373. var orderArgs = make([]string, len(runningServers))
  374. // get the accessor socket fds for _all_ server instances
  375. for _, srvPtr := range runningServers {
  376. // introspect.PrintTypeDump(srvPtr.EndlessListener)
  377. switch srvPtr.EndlessListener.(type) {
  378. case *endlessListener:
  379. // normal listener
  380. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.EndlessListener.(*endlessListener).File()
  381. default:
  382. // tls listener
  383. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
  384. }
  385. orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
  386. }
  387. env := append(
  388. os.Environ(),
  389. "ENDLESS_CONTINUE=1",
  390. )
  391. if len(runningServers) > 1 {
  392. env = append(env, fmt.Sprintf(`ENDLESS_SOCKET_ORDER=%s`, strings.Join(orderArgs, ",")))
  393. }
  394. // log.Println(files)
  395. path := os.Args[0]
  396. var args []string
  397. if len(os.Args) > 1 {
  398. args = os.Args[1:]
  399. }
  400. cmd := exec.Command(path, args...)
  401. cmd.Stdout = os.Stdout
  402. cmd.Stderr = os.Stderr
  403. cmd.ExtraFiles = files
  404. cmd.Env = env
  405. // cmd.SysProcAttr = &syscall.SysProcAttr{
  406. // Setsid: true,
  407. // Setctty: true,
  408. // Ctty: ,
  409. // }
  410. err = cmd.Start()
  411. if err != nil {
  412. log.Fatalf("Restart: Failed to launch, error: %v", err)
  413. }
  414. return
  415. }
  416. type endlessListener struct {
  417. net.Listener
  418. stopped bool
  419. server *endlessServer
  420. }
  421. func (el *endlessListener) Accept() (c net.Conn, err error) {
  422. tc, err := el.Listener.(*net.TCPListener).AcceptTCP()
  423. if err != nil {
  424. return
  425. }
  426. tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
  427. tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
  428. c = endlessConn{
  429. Conn: tc,
  430. server: el.server,
  431. }
  432. el.server.wg.Add(1)
  433. return
  434. }
  435. func newEndlessListener(l net.Listener, srv *endlessServer) (el *endlessListener) {
  436. el = &endlessListener{
  437. Listener: l,
  438. server: srv,
  439. }
  440. return
  441. }
  442. func (el *endlessListener) Close() error {
  443. if el.stopped {
  444. return syscall.EINVAL
  445. }
  446. el.stopped = true
  447. return el.Listener.Close()
  448. }
  449. func (el *endlessListener) File() *os.File {
  450. // returns a dup(2) - FD_CLOEXEC flag *not* set
  451. tl := el.Listener.(*net.TCPListener)
  452. fl, _ := tl.File()
  453. return fl
  454. }
  455. type endlessConn struct {
  456. net.Conn
  457. server *endlessServer
  458. }
  459. func (w endlessConn) Close() error {
  460. err := w.Conn.Close()
  461. if err == nil {
  462. w.server.wg.Done()
  463. }
  464. return err
  465. }
  466. /*
  467. RegisterSignalHook registers a function to be run PRE_SIGNAL or POST_SIGNAL for
  468. a given signal. PRE or POST in this case means before or after the signal
  469. related code endless itself runs
  470. */
  471. func (srv *endlessServer) RegisterSignalHook(prePost int, sig os.Signal, f func()) (err error) {
  472. if prePost != PRE_SIGNAL && prePost != POST_SIGNAL {
  473. err = fmt.Errorf("Cannot use %v for prePost arg. Must be endless.PRE_SIGNAL or endless.POST_SIGNAL.", sig)
  474. return
  475. }
  476. for _, s := range hookableSignals {
  477. if s == sig {
  478. srv.SignalHooks[prePost][sig] = append(srv.SignalHooks[prePost][sig], f)
  479. return
  480. }
  481. }
  482. err = fmt.Errorf("Signal %v is not supported.", sig)
  483. return
  484. }