recovery.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. // Copyright 2014 Manu Martinez-Almeida. All rights reserved.
  2. // Use of this source code is governed by a MIT style
  3. // license that can be found in the LICENSE file.
  4. package gin
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "io/ioutil"
  11. "log"
  12. "net"
  13. "net/http"
  14. "net/http/httputil"
  15. "os"
  16. "runtime"
  17. "strings"
  18. "time"
  19. )
  20. var (
  21. dunno = []byte("???")
  22. centerDot = []byte("·")
  23. dot = []byte(".")
  24. slash = []byte("/")
  25. )
  26. // RecoveryFunc defines the function passable to CustomRecovery.
  27. type RecoveryFunc func(c *Context, err any)
  28. // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
  29. func Recovery() HandlerFunc {
  30. return RecoveryWithWriter(DefaultErrorWriter)
  31. }
  32. // CustomRecovery returns a middleware that recovers from any panics and calls the provided handle func to handle it.
  33. func CustomRecovery(handle RecoveryFunc) HandlerFunc {
  34. return RecoveryWithWriter(DefaultErrorWriter, handle)
  35. }
  36. // RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one.
  37. func RecoveryWithWriter(out io.Writer, recovery ...RecoveryFunc) HandlerFunc {
  38. if len(recovery) > 0 {
  39. return CustomRecoveryWithWriter(out, recovery[0])
  40. }
  41. return CustomRecoveryWithWriter(out, defaultHandleRecovery)
  42. }
  43. // CustomRecoveryWithWriter returns a middleware for a given writer that recovers from any panics and calls the provided handle func to handle it.
  44. func CustomRecoveryWithWriter(out io.Writer, handle RecoveryFunc) HandlerFunc {
  45. var logger *log.Logger
  46. if out != nil {
  47. logger = log.New(out, "\n\n\x1b[31m", log.LstdFlags)
  48. }
  49. return func(c *Context) {
  50. defer func() {
  51. if err := recover(); err != nil {
  52. // Check for a broken connection, as it is not really a
  53. // condition that warrants a panic stack trace.
  54. var brokenPipe bool
  55. if ne, ok := err.(*net.OpError); ok {
  56. var se *os.SyscallError
  57. if errors.As(ne, &se) {
  58. if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
  59. brokenPipe = true
  60. }
  61. }
  62. }
  63. if logger != nil {
  64. stack := stack(3)
  65. httpRequest, _ := httputil.DumpRequest(c.Request, false)
  66. headers := strings.Split(string(httpRequest), "\r\n")
  67. for idx, header := range headers {
  68. current := strings.Split(header, ":")
  69. if current[0] == "Authorization" {
  70. headers[idx] = current[0] + ": *"
  71. }
  72. }
  73. headersToStr := strings.Join(headers, "\r\n")
  74. if brokenPipe {
  75. logger.Printf("%s\n%s%s", err, headersToStr, reset)
  76. } else if IsDebugging() {
  77. logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
  78. timeFormat(time.Now()), headersToStr, err, stack, reset)
  79. } else {
  80. logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s",
  81. timeFormat(time.Now()), err, stack, reset)
  82. }
  83. }
  84. if brokenPipe {
  85. // If the connection is dead, we can't write a status to it.
  86. c.Error(err.(error)) // nolint: errcheck
  87. c.Abort()
  88. } else {
  89. handle(c, err)
  90. }
  91. }
  92. }()
  93. c.Next()
  94. }
  95. }
  96. func defaultHandleRecovery(c *Context, err any) {
  97. c.AbortWithStatus(http.StatusInternalServerError)
  98. }
  99. // stack returns a nicely formatted stack frame, skipping skip frames.
  100. func stack(skip int) []byte {
  101. buf := new(bytes.Buffer) // the returned data
  102. // As we loop, we open files and read them. These variables record the currently
  103. // loaded file.
  104. var lines [][]byte
  105. var lastFile string
  106. for i := skip; ; i++ { // Skip the expected number of frames
  107. pc, file, line, ok := runtime.Caller(i)
  108. if !ok {
  109. break
  110. }
  111. // Print this much at least. If we can't find the source, it won't show.
  112. fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc)
  113. if file != lastFile {
  114. data, err := ioutil.ReadFile(file)
  115. if err != nil {
  116. continue
  117. }
  118. lines = bytes.Split(data, []byte{'\n'})
  119. lastFile = file
  120. }
  121. fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line))
  122. }
  123. return buf.Bytes()
  124. }
  125. // source returns a space-trimmed slice of the n'th line.
  126. func source(lines [][]byte, n int) []byte {
  127. n-- // in stack trace, lines are 1-indexed but our array is 0-indexed
  128. if n < 0 || n >= len(lines) {
  129. return dunno
  130. }
  131. return bytes.TrimSpace(lines[n])
  132. }
  133. // function returns, if possible, the name of the function containing the PC.
  134. func function(pc uintptr) []byte {
  135. fn := runtime.FuncForPC(pc)
  136. if fn == nil {
  137. return dunno
  138. }
  139. name := []byte(fn.Name())
  140. // The name includes the path name to the package, which is unnecessary
  141. // since the file name is already included. Plus, it has center dots.
  142. // That is, we see
  143. // runtime/debug.*T·ptrmethod
  144. // and want
  145. // *T.ptrmethod
  146. // Also the package path might contain dot (e.g. code.google.com/...),
  147. // so first eliminate the path prefix
  148. if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 {
  149. name = name[lastSlash+1:]
  150. }
  151. if period := bytes.Index(name, dot); period >= 0 {
  152. name = name[period+1:]
  153. }
  154. name = bytes.Replace(name, centerDot, dot, -1)
  155. return name
  156. }
  157. // timeFormat returns a customized time string for logger.
  158. func timeFormat(t time.Time) string {
  159. return t.Format("2006/01/02 - 15:04:05")
  160. }