endless.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. // +build linux
  2. package endless
  3. import (
  4. "crypto/tls"
  5. "errors"
  6. "flag"
  7. "fmt"
  8. "log"
  9. "net"
  10. "net/http"
  11. "os"
  12. "os/exec"
  13. "os/signal"
  14. "runtime"
  15. "strings"
  16. "sync"
  17. "syscall"
  18. "time"
  19. // "github.com/fvbock/uds-go/introspect"
  20. )
  // Phases at which a hook registered with RegisterSignalHook runs, relative to
// endless' own handling of the signal.
const (
	PRE_SIGNAL  = iota // before endless reacts to the signal
	POST_SIGNAL        // after endless reacted to the signal
)

// Lifecycle states of an endlessServer (see getState/setState). The numbering
// continues after the hook phases, so STATE_INIT is 2 as it always was.
const (
	STATE_INIT = iota + POST_SIGNAL + 1
	STATE_RUNNING
	STATE_SHUTTING_DOWN
	STATE_TERMINATE
)
  29. var (
  30. runningServerReg sync.RWMutex
  31. runningServers map[string]*endlessServer
  32. runningServersOrder []string
  33. socketPtrOffsetMap map[string]uint
  34. runningServersForked bool
  35. DefaultReadTimeOut time.Duration
  36. DefaultWriteTimeOut time.Duration
  37. DefaultMaxHeaderBytes int
  38. DefaultHammerTime time.Duration
  39. isChild bool
  40. socketOrder string
  41. hookableSignals []os.Signal
  42. )
  43. func init() {
  44. flag.BoolVar(&isChild, "continue", false, "listen on open fd (after forking)")
  45. flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
  46. runningServerReg = sync.RWMutex{}
  47. runningServers = make(map[string]*endlessServer)
  48. runningServersOrder = []string{}
  49. socketPtrOffsetMap = make(map[string]uint)
  50. DefaultMaxHeaderBytes = 0 // use http.DefaultMaxHeaderBytes - which currently is 1 << 20 (1MB)
  51. // after a restart the parent will finish ongoing requests before
  52. // shutting down. set to a negative value to disable
  53. DefaultHammerTime = 60 * time.Second
  54. hookableSignals = []os.Signal{
  55. syscall.SIGHUP,
  56. syscall.SIGUSR1,
  57. syscall.SIGUSR2,
  58. syscall.SIGINT,
  59. syscall.SIGTERM,
  60. syscall.SIGTSTP,
  61. }
  62. }
  63. type endlessServer struct {
  64. http.Server
  65. EndlessListener net.Listener
  66. SignalHooks map[int]map[os.Signal][]func()
  67. tlsInnerListener *endlessListener
  68. wg sync.WaitGroup
  69. sigChan chan os.Signal
  70. isChild bool
  71. state uint8
  72. lock *sync.RWMutex
  73. }
  74. /*
  75. NewServer returns an intialized endlessServer Object. Calling Serve on it will
  76. actually "start" the server.
  77. */
  78. func NewServer(addr string, handler http.Handler, fn func()) (srv *endlessServer) {
  79. runningServerReg.Lock()
  80. defer runningServerReg.Unlock()
  81. if !flag.Parsed() {
  82. flag.Parse()
  83. }
  84. if len(socketOrder) > 0 {
  85. for i, addr := range strings.Split(socketOrder, ",") {
  86. socketPtrOffsetMap[addr] = uint(i)
  87. }
  88. } else {
  89. socketPtrOffsetMap[addr] = uint(len(runningServersOrder))
  90. }
  91. srv = &endlessServer{
  92. wg: sync.WaitGroup{},
  93. sigChan: make(chan os.Signal),
  94. isChild: isChild,
  95. SignalHooks: map[int]map[os.Signal][]func(){
  96. PRE_SIGNAL: map[os.Signal][]func(){
  97. syscall.SIGHUP: []func(){
  98. fn,
  99. },
  100. syscall.SIGUSR1: []func(){},
  101. syscall.SIGUSR2: []func(){},
  102. syscall.SIGINT: []func(){},
  103. syscall.SIGTERM: []func(){},
  104. syscall.SIGTSTP: []func(){},
  105. },
  106. POST_SIGNAL: map[os.Signal][]func(){
  107. syscall.SIGHUP: []func(){},
  108. syscall.SIGUSR1: []func(){},
  109. syscall.SIGUSR2: []func(){},
  110. syscall.SIGINT: []func(){},
  111. syscall.SIGTERM: []func(){},
  112. syscall.SIGTSTP: []func(){},
  113. },
  114. },
  115. state: STATE_INIT,
  116. lock: &sync.RWMutex{},
  117. }
  118. srv.Server.Addr = addr
  119. srv.Server.ReadTimeout = DefaultReadTimeOut
  120. srv.Server.WriteTimeout = DefaultWriteTimeOut
  121. srv.Server.MaxHeaderBytes = DefaultMaxHeaderBytes
  122. srv.Server.Handler = handler
  123. runningServersOrder = append(runningServersOrder, addr)
  124. runningServers[addr] = srv
  125. return
  126. }
  127. func NetListen(addr string, handler http.Handler, fn func()) (net.Listener, error) {
  128. server := NewServer(addr, handler, fn)
  129. return server.ListenAndServe(false)
  130. }
  131. /*
  132. ListenAndServe listens on the TCP network address addr and then calls Serve
  133. with handler to handle requests on incoming connections. Handler is typically
  134. nil, in which case the DefaultServeMux is used.
  135. */
  136. func ListenAndServe(addr string, handler http.Handler, fn func()) error {
  137. server := NewServer(addr, handler, fn)
  138. _, err := server.ListenAndServe(true)
  139. return err
  140. }
  141. /*
  142. ListenAndServeTLS acts identically to ListenAndServe, except that it expects
  143. HTTPS connections. Additionally, files containing a certificate and matching
  144. private key for the server must be provided. If the certificate is signed by a
  145. certificate authority, the certFile should be the concatenation of the server's
  146. certificate followed by the CA's certificate.
  147. */
  148. func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler, destoryfn func()) error {
  149. server := NewServer(addr, handler, destoryfn)
  150. return server.ListenAndServeTLS(certFile, keyFile, true)
  151. }
  152. func (srv *endlessServer) getState() uint8 {
  153. srv.lock.RLock()
  154. defer srv.lock.RUnlock()
  155. return srv.state
  156. }
  157. func (srv *endlessServer) setState(st uint8) {
  158. srv.lock.Lock()
  159. defer srv.lock.Unlock()
  160. srv.state = st
  161. }
  162. /*
  163. Serve accepts incoming HTTP connections on the listener l, creating a new
  164. service goroutine for each. The service goroutines read requests and then call
  165. handler to reply to them. Handler is typically nil, in which case the
  166. DefaultServeMux is used.
  167. In addition to the stl Serve behaviour each connection is added to a
  168. sync.Waitgroup so that all outstanding connections can be served before shutting
  169. down the server.
  170. */
  171. func (srv *endlessServer) Serve() (err error) {
  172. defer log.Println(syscall.Getpid(), "Serve() returning...")
  173. srv.setState(STATE_RUNNING)
  174. err = srv.Server.Serve(srv.EndlessListener)
  175. log.Println(syscall.Getpid(), "Waiting for connections to finish...")
  176. srv.wg.Wait()
  177. srv.setState(STATE_TERMINATE)
  178. return
  179. }
  180. /*
  181. ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
  182. to handle requests on incoming connections. If srv.Addr is blank, ":http" is
  183. used.
  184. */
  185. func (srv *endlessServer) ListenAndServe(isServer bool) (l net.Listener, err error) {
  186. addr := srv.Addr
  187. if addr == "" {
  188. addr = ":http"
  189. }
  190. go srv.handleSignals(isServer)
  191. l, err = srv.getListener(addr)
  192. if err != nil {
  193. log.Println(err)
  194. return
  195. }
  196. srv.EndlessListener = newEndlessListener(l, srv)
  197. if srv.isChild {
  198. syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  199. }
  200. log.Println(syscall.Getpid(), srv.Addr)
  201. if isServer {
  202. return l, srv.Serve()
  203. } else {
  204. return l, err
  205. }
  206. }
  207. /*
  208. ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
  209. Serve to handle requests on incoming TLS connections.
  210. Filenames containing a certificate and matching private key for the server must
  211. be provided. If the certificate is signed by a certificate authority, the
  212. certFile should be the concatenation of the server's certificate followed by the
  213. CA's certificate.
  214. If srv.Addr is blank, ":https" is used.
  215. */
  216. func (srv *endlessServer) ListenAndServeTLS(certFile, keyFile string, isServer bool) (err error) {
  217. addr := srv.Addr
  218. if addr == "" {
  219. addr = ":https"
  220. }
  221. config := &tls.Config{}
  222. if srv.TLSConfig != nil {
  223. *config = *srv.TLSConfig
  224. }
  225. if config.NextProtos == nil {
  226. config.NextProtos = []string{"http/1.1"}
  227. }
  228. config.Certificates = make([]tls.Certificate, 1)
  229. config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
  230. if err != nil {
  231. return
  232. }
  233. go srv.handleSignals(isServer)
  234. l, err := srv.getListener(addr)
  235. if err != nil {
  236. log.Println(err)
  237. return
  238. }
  239. srv.tlsInnerListener = newEndlessListener(l, srv)
  240. srv.EndlessListener = tls.NewListener(srv.tlsInnerListener, config)
  241. if srv.isChild {
  242. syscall.Kill(syscall.Getppid(), syscall.SIGTERM)
  243. }
  244. log.Println(syscall.Getpid(), srv.Addr)
  245. return srv.Serve()
  246. }
  247. /*
  248. getListener either opens a new socket to listen on, or takes the acceptor socket
  249. it got passed when restarted.
  250. */
  251. func (srv *endlessServer) getListener(laddr string) (l net.Listener, err error) {
  252. if srv.isChild {
  253. var ptrOffset uint = 0
  254. runningServerReg.RLock()
  255. defer runningServerReg.RUnlock()
  256. if len(socketPtrOffsetMap) > 0 {
  257. ptrOffset = socketPtrOffsetMap[laddr]
  258. // log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
  259. }
  260. f := os.NewFile(uintptr(3+ptrOffset), "")
  261. l, err = net.FileListener(f)
  262. if err != nil {
  263. err = fmt.Errorf("net.FileListener error: %v", err)
  264. return
  265. }
  266. } else {
  267. l, err = net.Listen("tcp", laddr)
  268. if err != nil {
  269. err = fmt.Errorf("net.Listen error: %v", err)
  270. return
  271. }
  272. }
  273. return
  274. }
  275. /*
  276. handleSignals listens for os Signals and calls any hooked in function that the
  277. user had registered with the signal.
  278. */
  279. func (srv *endlessServer) handleSignals(isServer bool) {
  280. var sig os.Signal
  281. signal.Notify(
  282. srv.sigChan,
  283. hookableSignals...,
  284. )
  285. pid := syscall.Getpid()
  286. for {
  287. sig = <-srv.sigChan
  288. srv.signalHooks(PRE_SIGNAL, sig)
  289. switch sig {
  290. case syscall.SIGHUP:
  291. log.Println(pid, "Received SIGHUP. forking.")
  292. err := srv.fork()
  293. if err != nil {
  294. log.Println("Fork err:", err)
  295. }
  296. case syscall.SIGUSR1:
  297. log.Println(pid, "Received SIGUSR1.")
  298. case syscall.SIGUSR2:
  299. log.Println(pid, "Received SIGUSR2.")
  300. srv.hammerTime(0 * time.Second)
  301. case syscall.SIGINT:
  302. log.Println(pid, "Received SIGINT.")
  303. srv.shutdown(isServer)
  304. case syscall.SIGTERM:
  305. log.Println(pid, "Received SIGTERM.")
  306. srv.shutdown(isServer)
  307. case syscall.SIGTSTP:
  308. log.Println(pid, "Received SIGTSTP.")
  309. default:
  310. log.Printf("Received %v: nothing i care about...\n", sig)
  311. }
  312. srv.signalHooks(POST_SIGNAL, sig)
  313. }
  314. }
  315. func (srv *endlessServer) signalHooks(ppFlag int, sig os.Signal) {
  316. if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
  317. return
  318. }
  319. for _, f := range srv.SignalHooks[ppFlag][sig] {
  320. f()
  321. }
  322. return
  323. }
  324. /*
  325. shutdown closes the listener so that no new connections are accepted. it also
  326. starts a goroutine that will hammer (stop all running requests) the server
  327. after DefaultHammerTime.
  328. */
  329. func (srv *endlessServer) shutdown(isServer bool) {
  330. if isServer && srv.getState() != STATE_RUNNING {
  331. return
  332. }
  333. srv.setState(STATE_SHUTTING_DOWN)
  334. if DefaultHammerTime >= 0 {
  335. go srv.hammerTime(DefaultHammerTime)
  336. }
  337. // disable keep-alives on existing connections
  338. srv.SetKeepAlivesEnabled(false)
  339. err := srv.EndlessListener.Close()
  340. if err != nil {
  341. log.Println(syscall.Getpid(), "Listener.Close() error:", err)
  342. } else {
  343. log.Println(syscall.Getpid(), srv.EndlessListener.Addr(), "Listener closed.")
  344. }
  345. }
  346. /*
  347. hammerTime forces the server to shutdown in a given timeout - whether it
  348. finished outstanding requests or not. if Read/WriteTimeout are not set or the
  349. max header size is very big a connection could hang...
  350. srv.Serve() will not return until all connections are served. this will
  351. unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to
  352. return.
  353. */
  354. func (srv *endlessServer) hammerTime(d time.Duration) {
  355. defer func() {
  356. // we are calling srv.wg.Done() until it panics which means we called
  357. // Done() when the counter was already at 0 and we're done.
  358. // (and thus Serve() will return and the parent will exit)
  359. if r := recover(); r != nil {
  360. log.Println("WaitGroup at 0", r)
  361. }
  362. }()
  363. if srv.getState() != STATE_SHUTTING_DOWN {
  364. return
  365. }
  366. time.Sleep(d)
  367. log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
  368. for {
  369. if srv.getState() == STATE_TERMINATE {
  370. break
  371. }
  372. srv.wg.Done()
  373. runtime.Gosched()
  374. }
  375. }
  376. func (srv *endlessServer) fork() (err error) {
  377. runningServerReg.Lock()
  378. defer runningServerReg.Unlock()
  379. // only one server isntance should fork!
  380. if runningServersForked {
  381. return errors.New("Another process already forked. Ignoring this one.")
  382. }
  383. runningServersForked = true
  384. var files = make([]*os.File, len(runningServers))
  385. var orderArgs = make([]string, len(runningServers))
  386. // get the accessor socket fds for _all_ server instances
  387. for _, srvPtr := range runningServers {
  388. // introspect.PrintTypeDump(srvPtr.EndlessListener)
  389. switch srvPtr.EndlessListener.(type) {
  390. case *endlessListener:
  391. // normal listener
  392. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.EndlessListener.(*endlessListener).File()
  393. default:
  394. // tls listener
  395. files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
  396. }
  397. orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
  398. }
  399. // log.Println(files)
  400. path := os.Args[0]
  401. var args []string
  402. if len(os.Args) > 1 {
  403. for _, arg := range os.Args[1:] {
  404. if arg == "-continue" {
  405. break
  406. }
  407. args = append(args, arg)
  408. }
  409. }
  410. args = append(args, "-continue")
  411. if len(runningServers) > 1 {
  412. args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
  413. // log.Println(args)
  414. }
  415. cmd := exec.Command(path, args...)
  416. cmd.Stdout = os.Stdout
  417. cmd.Stderr = os.Stderr
  418. cmd.ExtraFiles = files
  419. // cmd.SysProcAttr = &syscall.SysProcAttr{
  420. // Setsid: true,
  421. // Setctty: true,
  422. // Ctty: ,
  423. // }
  424. err = cmd.Start()
  425. if err != nil {
  426. log.Fatalf("Restart: Failed to launch, error: %v", err)
  427. }
  428. return
  429. }
  430. type endlessListener struct {
  431. net.Listener
  432. stopped bool
  433. server *endlessServer
  434. }
  435. func (el *endlessListener) Accept() (c net.Conn, err error) {
  436. tc, err := el.Listener.(*net.TCPListener).AcceptTCP()
  437. if err != nil {
  438. return
  439. }
  440. tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
  441. tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
  442. c = endlessConn{
  443. Conn: tc,
  444. server: el.server,
  445. }
  446. el.server.wg.Add(1)
  447. return
  448. }
  449. func newEndlessListener(l net.Listener, srv *endlessServer) (el *endlessListener) {
  450. el = &endlessListener{
  451. Listener: l,
  452. server: srv,
  453. }
  454. return
  455. }
  456. func (el *endlessListener) Close() error {
  457. if el.stopped {
  458. return syscall.EINVAL
  459. }
  460. el.stopped = true
  461. return el.Listener.Close()
  462. }
  463. func (el *endlessListener) File() *os.File {
  464. // returns a dup(2) - FD_CLOEXEC flag *not* set
  465. tl := el.Listener.(*net.TCPListener)
  466. fl, _ := tl.File()
  467. return fl
  468. }
  469. type endlessConn struct {
  470. net.Conn
  471. server *endlessServer
  472. }
  473. func (w endlessConn) Close() error {
  474. err := w.Conn.Close()
  475. if err == nil {
  476. w.server.wg.Done()
  477. }
  478. return err
  479. }
  480. /*
  481. RegisterSignalHook registers a function to be run PRE_SIGNAL or POST_SIGNAL for
  482. a given signal. PRE or POST in this case means before or after the signal
  483. related code endless itself runs
  484. */
  485. func (srv *endlessServer) RegisterSignalHook(prePost int, sig os.Signal, f func()) (err error) {
  486. if prePost != PRE_SIGNAL && prePost != POST_SIGNAL {
  487. err = fmt.Errorf("Cannot use %v for prePost arg. Must be endless.PRE_SIGNAL or endless.POST_SIGNAL.")
  488. return
  489. }
  490. for _, s := range hookableSignals {
  491. if s == sig {
  492. srv.SignalHooks[prePost][sig] = append(srv.SignalHooks[prePost][sig], f)
  493. return
  494. }
  495. }
  496. err = fmt.Errorf("Signal %v is not supported.")
  497. return
  498. }