3abd0179 by astaxie

split into small files

1 parent ae376893
1 package grace
2
3 import "net"
4
5 type graceConn struct {
6 net.Conn
7 server *graceServer
8 }
9
10 func (c graceConn) Close() error {
11 c.server.wg.Done()
12 return c.Conn.Close()
13 }
...@@ -42,15 +42,9 @@ ...@@ -42,15 +42,9 @@
42 package grace 42 package grace
43 43
44 import ( 44 import (
45 "crypto/tls"
46 "flag" 45 "flag"
47 "fmt"
48 "log"
49 "net"
50 "net/http" 46 "net/http"
51 "os" 47 "os"
52 "os/exec"
53 "os/signal"
54 "strings" 48 "strings"
55 "sync" 49 "sync"
56 "syscall" 50 "syscall"
...@@ -93,25 +87,10 @@ func init() { ...@@ -93,25 +87,10 @@ func init() {
93 87
94 DefaultMaxHeaderBytes = 0 88 DefaultMaxHeaderBytes = 0
95 89
96 // after a restart the parent will finish ongoing requests before
97 // shutting down. set to a negative value to disable
98 DefaultTimeout = 60 * time.Second 90 DefaultTimeout = 60 * time.Second
99 } 91 }
100 92
101 type graceServer struct { 93 // NewServer returns a new graceServer.
102 *http.Server
103 GraceListener net.Listener
104 SignalHooks map[int]map[os.Signal][]func()
105 tlsInnerListener *graceListener
106 wg sync.WaitGroup
107 sigChan chan os.Signal
108 isChild bool
109 state uint8
110 Network string
111 }
112
113 // NewServer returns an intialized graceServer. Calling Serve on it will
114 // actually "start" the server.
115 func NewServer(addr string, handler http.Handler) (srv *graceServer) { 94 func NewServer(addr string, handler http.Handler) (srv *graceServer) {
116 regLock.Lock() 95 regLock.Lock()
117 defer regLock.Unlock() 96 defer regLock.Unlock()
...@@ -158,364 +137,14 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) { ...@@ -158,364 +137,14 @@ func NewServer(addr string, handler http.Handler) (srv *graceServer) {
158 return 137 return
159 } 138 }
160 139
161 // ListenAndServe listens on the TCP network address addr 140 // refer http.ListenAndServe
162 // and then calls Serve to handle requests on incoming connections.
163 func ListenAndServe(addr string, handler http.Handler) error { 141 func ListenAndServe(addr string, handler http.Handler) error {
164 server := NewServer(addr, handler) 142 server := NewServer(addr, handler)
165 return server.ListenAndServe() 143 return server.ListenAndServe()
166 } 144 }
167 145
168 // ListenAndServeTLS listens on the TCP network address addr and then calls 146 // refer http.ListenAndServeTLS
169 // Serve to handle requests on incoming TLS connections.
170 //
171 // Filenames containing a certificate and matching private key for the server must be provided.
172 // If the certificate is signed by a certificate authority,
173 // the certFile should be the concatenation of the server's certificate followed by the CA's certificate.
174 func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error { 147 func ListenAndServeTLS(addr string, certFile string, keyFile string, handler http.Handler) error {
175 server := NewServer(addr, handler) 148 server := NewServer(addr, handler)
176 return server.ListenAndServeTLS(certFile, keyFile) 149 return server.ListenAndServeTLS(certFile, keyFile)
177 } 150 }
178
179 // Serve accepts incoming connections on the Listener l,
180 // creating a new service goroutine for each.
181 // The service goroutines read requests and then call srv.Handler to reply to them.
182 func (srv *graceServer) Serve() (err error) {
183 srv.state = STATE_RUNNING
184 err = srv.Server.Serve(srv.GraceListener)
185 log.Println(syscall.Getpid(), "Waiting for connections to finish...")
186 srv.wg.Wait()
187 srv.state = STATE_TERMINATE
188 return
189 }
190
191 // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
192 // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
193 // used.
194 func (srv *graceServer) ListenAndServe() (err error) {
195 addr := srv.Addr
196 if addr == "" {
197 addr = ":http"
198 }
199
200 go srv.handleSignals()
201
202 l, err := srv.getListener(addr)
203 if err != nil {
204 log.Println(err)
205 return err
206 }
207
208 srv.GraceListener = newGraceListener(l, srv)
209
210 if srv.isChild {
211 process, err := os.FindProcess(os.Getppid())
212 if err != nil {
213 log.Println(err)
214 return err
215 }
216 err = process.Kill()
217 if err != nil {
218 return err
219 }
220 }
221
222 log.Println(os.Getpid(), srv.Addr)
223 return srv.Serve()
224 }
225
226 // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
227 // Serve to handle requests on incoming TLS connections.
228 //
229 // Filenames containing a certificate and matching private key for the server must
230 // be provided. If the certificate is signed by a certificate authority, the
231 // certFile should be the concatenation of the server's certificate followed by the
232 // CA's certificate.
233 //
234 // If srv.Addr is blank, ":https" is used.
235 func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
236 addr := srv.Addr
237 if addr == "" {
238 addr = ":https"
239 }
240
241 config := &tls.Config{}
242 if srv.TLSConfig != nil {
243 *config = *srv.TLSConfig
244 }
245 if config.NextProtos == nil {
246 config.NextProtos = []string{"http/1.1"}
247 }
248
249 config.Certificates = make([]tls.Certificate, 1)
250 config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
251 if err != nil {
252 return
253 }
254
255 go srv.handleSignals()
256
257 l, err := srv.getListener(addr)
258 if err != nil {
259 log.Println(err)
260 return err
261 }
262
263 srv.tlsInnerListener = newGraceListener(l, srv)
264 srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)
265
266 if srv.isChild {
267 process, err := os.FindProcess(os.Getppid())
268 if err != nil {
269 log.Println(err)
270 return err
271 }
272 err = process.Kill()
273 if err != nil {
274 return err
275 }
276 }
277 log.Println(os.Getpid(), srv.Addr)
278 return srv.Serve()
279 }
280
281 // getListener either opens a new socket to listen on, or takes the acceptor socket
282 // it got passed when restarted.
283 func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
284 if srv.isChild {
285 var ptrOffset uint = 0
286 if len(socketPtrOffsetMap) > 0 {
287 ptrOffset = socketPtrOffsetMap[laddr]
288 log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
289 }
290
291 f := os.NewFile(uintptr(3+ptrOffset), "")
292 l, err = net.FileListener(f)
293 if err != nil {
294 err = fmt.Errorf("net.FileListener error: %v", err)
295 return
296 }
297 } else {
298 l, err = net.Listen(srv.Network, laddr)
299 if err != nil {
300 err = fmt.Errorf("net.Listen error: %v", err)
301 return
302 }
303 }
304 return
305 }
306
307 // handleSignals listens for os Signals and calls any hooked in function that the
308 // user had registered with the signal.
309 func (srv *graceServer) handleSignals() {
310 var sig os.Signal
311
312 signal.Notify(
313 srv.sigChan,
314 syscall.SIGHUP,
315 syscall.SIGINT,
316 syscall.SIGTERM,
317 )
318
319 pid := syscall.Getpid()
320 for {
321 sig = <-srv.sigChan
322 srv.signalHooks(PRE_SIGNAL, sig)
323 switch sig {
324 case syscall.SIGHUP:
325 log.Println(pid, "Received SIGHUP. forking.")
326 err := srv.fork()
327 if err != nil {
328 log.Println("Fork err:", err)
329 }
330 case syscall.SIGINT:
331 log.Println(pid, "Received SIGINT.")
332 srv.shutdown()
333 case syscall.SIGTERM:
334 log.Println(pid, "Received SIGTERM.")
335 srv.shutdown()
336 default:
337 log.Printf("Received %v: nothing i care about...\n", sig)
338 }
339 srv.signalHooks(POST_SIGNAL, sig)
340 }
341 }
342
343 func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
344 if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
345 return
346 }
347 for _, f := range srv.SignalHooks[ppFlag][sig] {
348 f()
349 }
350 return
351 }
352
353 // shutdown closes the listener so that no new connections are accepted. it also
354 // starts a goroutine that will hammer (stop all running requests) the server
355 // after DefaultTimeout.
356 func (srv *graceServer) shutdown() {
357 if srv.state != STATE_RUNNING {
358 return
359 }
360
361 srv.state = STATE_SHUTTING_DOWN
362 if DefaultTimeout >= 0 {
363 go srv.serverTimeout(DefaultTimeout)
364 }
365 err := srv.GraceListener.Close()
366 if err != nil {
367 log.Println(syscall.Getpid(), "Listener.Close() error:", err)
368 } else {
369 log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
370 }
371 }
372
373 // hammerTime forces the server to shutdown in a given timeout - whether it
374 // finished outstanding requests or not. if Read/WriteTimeout are not set or the
375 // max header size is very big a connection could hang...
376 //
377 // srv.Serve() will not return until all connections are served. this will
378 // unblock the srv.wg.Wait() in Serve() thus causing ListenAndServe(TLS) to
379 // return.
380 func (srv *graceServer) serverTimeout(d time.Duration) {
381 defer func() {
382 // we are calling srv.wg.Done() until it panics which means we called
383 // Done() when the counter was already at 0 and we're done.
384 // (and thus Serve() will return and the parent will exit)
385 if r := recover(); r != nil {
386 log.Println("WaitGroup at 0", r)
387 }
388 }()
389 if srv.state != STATE_SHUTTING_DOWN {
390 return
391 }
392 time.Sleep(d)
393 log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
394 for {
395 if srv.state == STATE_TERMINATE {
396 break
397 }
398 srv.wg.Done()
399 }
400 }
401
402 func (srv *graceServer) fork() (err error) {
403 // only one server isntance should fork!
404 regLock.Lock()
405 defer regLock.Unlock()
406 if runningServersForked {
407 return
408 }
409 runningServersForked = true
410
411 var files = make([]*os.File, len(runningServers))
412 var orderArgs = make([]string, len(runningServers))
413 // get the accessor socket fds for _all_ server instances
414 for _, srvPtr := range runningServers {
415 // introspect.PrintTypeDump(srvPtr.EndlessListener)
416 switch srvPtr.GraceListener.(type) {
417 case *graceListener:
418 // normal listener
419 files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
420 default:
421 // tls listener
422 files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
423 }
424 orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
425 }
426
427 log.Println(files)
428 path := os.Args[0]
429 var args []string
430 if len(os.Args) > 1 {
431 for _, arg := range os.Args[1:] {
432 if arg == "-graceful" {
433 break
434 }
435 args = append(args, arg)
436 }
437 }
438 args = append(args, "-graceful")
439 if len(runningServers) > 1 {
440 args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
441 log.Println(args)
442 }
443 cmd := exec.Command(path, args...)
444 cmd.Stdout = os.Stdout
445 cmd.Stderr = os.Stderr
446 cmd.ExtraFiles = files
447 err = cmd.Start()
448 if err != nil {
449 log.Fatalf("Restart: Failed to launch, error: %v", err)
450 }
451
452 return
453 }
454
455 type graceListener struct {
456 net.Listener
457 stop chan error
458 stopped bool
459 server *graceServer
460 }
461
462 func (gl *graceListener) Accept() (c net.Conn, err error) {
463 tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
464 if err != nil {
465 return
466 }
467
468 tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
469 tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
470
471 c = graceConn{
472 Conn: tc,
473 server: gl.server,
474 }
475
476 gl.server.wg.Add(1)
477 return
478 }
479
480 func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) {
481 el = &graceListener{
482 Listener: l,
483 stop: make(chan error),
484 server: srv,
485 }
486
487 // Starting the listener for the stop signal here because Accept blocks on
488 // el.Listener.(*net.TCPListener).AcceptTCP()
489 // The goroutine will unblock it by closing the listeners fd
490 go func() {
491 _ = <-el.stop
492 el.stopped = true
493 el.stop <- el.Listener.Close()
494 }()
495 return
496 }
497
498 func (el *graceListener) Close() error {
499 if el.stopped {
500 return syscall.EINVAL
501 }
502 el.stop <- nil
503 return <-el.stop
504 }
505
506 func (el *graceListener) File() *os.File {
507 // returns a dup(2) - FD_CLOEXEC flag *not* set
508 tl := el.Listener.(*net.TCPListener)
509 fl, _ := tl.File()
510 return fl
511 }
512
513 type graceConn struct {
514 net.Conn
515 server *graceServer
516 }
517
518 func (c graceConn) Close() error {
519 c.server.wg.Done()
520 return c.Conn.Close()
521 }
......
1 package grace
2
3 import (
4 "net"
5 "os"
6 "syscall"
7 "time"
8 )
9
10 type graceListener struct {
11 net.Listener
12 stop chan error
13 stopped bool
14 server *graceServer
15 }
16
17 func newGraceListener(l net.Listener, srv *graceServer) (el *graceListener) {
18 el = &graceListener{
19 Listener: l,
20 stop: make(chan error),
21 server: srv,
22 }
23 go func() {
24 _ = <-el.stop
25 el.stopped = true
26 el.stop <- el.Listener.Close()
27 }()
28 return
29 }
30
31 func (gl *graceListener) Accept() (c net.Conn, err error) {
32 tc, err := gl.Listener.(*net.TCPListener).AcceptTCP()
33 if err != nil {
34 return
35 }
36
37 tc.SetKeepAlive(true)
38 tc.SetKeepAlivePeriod(3 * time.Minute)
39
40 c = graceConn{
41 Conn: tc,
42 server: gl.server,
43 }
44
45 gl.server.wg.Add(1)
46 return
47 }
48
49 func (el *graceListener) Close() error {
50 if el.stopped {
51 return syscall.EINVAL
52 }
53 el.stop <- nil
54 return <-el.stop
55 }
56
57 func (el *graceListener) File() *os.File {
58 // returns a dup(2) - FD_CLOEXEC flag *not* set
59 tl := el.Listener.(*net.TCPListener)
60 fl, _ := tl.File()
61 return fl
62 }
1 package grace
2
3 import (
4 "crypto/tls"
5 "fmt"
6 "log"
7 "net"
8 "net/http"
9 "os"
10 "os/exec"
11 "os/signal"
12 "strings"
13 "sync"
14 "syscall"
15 "time"
16 )
17
18 type graceServer struct {
19 *http.Server
20 GraceListener net.Listener
21 SignalHooks map[int]map[os.Signal][]func()
22 tlsInnerListener *graceListener
23 wg sync.WaitGroup
24 sigChan chan os.Signal
25 isChild bool
26 state uint8
27 Network string
28 }
29
30 // Serve accepts incoming connections on the Listener l,
31 // creating a new service goroutine for each.
32 // The service goroutines read requests and then call srv.Handler to reply to them.
33 func (srv *graceServer) Serve() (err error) {
34 srv.state = STATE_RUNNING
35 err = srv.Server.Serve(srv.GraceListener)
36 log.Println(syscall.Getpid(), "Waiting for connections to finish...")
37 srv.wg.Wait()
38 srv.state = STATE_TERMINATE
39 return
40 }
41
42 // ListenAndServe listens on the TCP network address srv.Addr and then calls Serve
43 // to handle requests on incoming connections. If srv.Addr is blank, ":http" is
44 // used.
45 func (srv *graceServer) ListenAndServe() (err error) {
46 addr := srv.Addr
47 if addr == "" {
48 addr = ":http"
49 }
50
51 go srv.handleSignals()
52
53 l, err := srv.getListener(addr)
54 if err != nil {
55 log.Println(err)
56 return err
57 }
58
59 srv.GraceListener = newGraceListener(l, srv)
60
61 if srv.isChild {
62 process, err := os.FindProcess(os.Getppid())
63 if err != nil {
64 log.Println(err)
65 return err
66 }
67 err = process.Kill()
68 if err != nil {
69 return err
70 }
71 }
72
73 log.Println(os.Getpid(), srv.Addr)
74 return srv.Serve()
75 }
76
77 // ListenAndServeTLS listens on the TCP network address srv.Addr and then calls
78 // Serve to handle requests on incoming TLS connections.
79 //
80 // Filenames containing a certificate and matching private key for the server must
81 // be provided. If the certificate is signed by a certificate authority, the
82 // certFile should be the concatenation of the server's certificate followed by the
83 // CA's certificate.
84 //
85 // If srv.Addr is blank, ":https" is used.
86 func (srv *graceServer) ListenAndServeTLS(certFile, keyFile string) (err error) {
87 addr := srv.Addr
88 if addr == "" {
89 addr = ":https"
90 }
91
92 config := &tls.Config{}
93 if srv.TLSConfig != nil {
94 *config = *srv.TLSConfig
95 }
96 if config.NextProtos == nil {
97 config.NextProtos = []string{"http/1.1"}
98 }
99
100 config.Certificates = make([]tls.Certificate, 1)
101 config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
102 if err != nil {
103 return
104 }
105
106 go srv.handleSignals()
107
108 l, err := srv.getListener(addr)
109 if err != nil {
110 log.Println(err)
111 return err
112 }
113
114 srv.tlsInnerListener = newGraceListener(l, srv)
115 srv.GraceListener = tls.NewListener(srv.tlsInnerListener, config)
116
117 if srv.isChild {
118 process, err := os.FindProcess(os.Getppid())
119 if err != nil {
120 log.Println(err)
121 return err
122 }
123 err = process.Kill()
124 if err != nil {
125 return err
126 }
127 }
128 log.Println(os.Getpid(), srv.Addr)
129 return srv.Serve()
130 }
131
132 // getListener either opens a new socket to listen on, or takes the acceptor socket
133 // it got passed when restarted.
134 func (srv *graceServer) getListener(laddr string) (l net.Listener, err error) {
135 if srv.isChild {
136 var ptrOffset uint = 0
137 if len(socketPtrOffsetMap) > 0 {
138 ptrOffset = socketPtrOffsetMap[laddr]
139 log.Println("laddr", laddr, "ptr offset", socketPtrOffsetMap[laddr])
140 }
141
142 f := os.NewFile(uintptr(3+ptrOffset), "")
143 l, err = net.FileListener(f)
144 if err != nil {
145 err = fmt.Errorf("net.FileListener error: %v", err)
146 return
147 }
148 } else {
149 l, err = net.Listen(srv.Network, laddr)
150 if err != nil {
151 err = fmt.Errorf("net.Listen error: %v", err)
152 return
153 }
154 }
155 return
156 }
157
158 // handleSignals listens for os Signals and calls any hooked in function that the
159 // user had registered with the signal.
160 func (srv *graceServer) handleSignals() {
161 var sig os.Signal
162
163 signal.Notify(
164 srv.sigChan,
165 syscall.SIGHUP,
166 syscall.SIGINT,
167 syscall.SIGTERM,
168 )
169
170 pid := syscall.Getpid()
171 for {
172 sig = <-srv.sigChan
173 srv.signalHooks(PRE_SIGNAL, sig)
174 switch sig {
175 case syscall.SIGHUP:
176 log.Println(pid, "Received SIGHUP. forking.")
177 err := srv.fork()
178 if err != nil {
179 log.Println("Fork err:", err)
180 }
181 case syscall.SIGINT:
182 log.Println(pid, "Received SIGINT.")
183 srv.shutdown()
184 case syscall.SIGTERM:
185 log.Println(pid, "Received SIGTERM.")
186 srv.shutdown()
187 default:
188 log.Printf("Received %v: nothing i care about...\n", sig)
189 }
190 srv.signalHooks(POST_SIGNAL, sig)
191 }
192 }
193
194 func (srv *graceServer) signalHooks(ppFlag int, sig os.Signal) {
195 if _, notSet := srv.SignalHooks[ppFlag][sig]; !notSet {
196 return
197 }
198 for _, f := range srv.SignalHooks[ppFlag][sig] {
199 f()
200 }
201 return
202 }
203
204 // shutdown closes the listener so that no new connections are accepted. it also
205 // starts a goroutine that will serverTimeout (stop all running requests) the server
206 // after DefaultTimeout.
207 func (srv *graceServer) shutdown() {
208 if srv.state != STATE_RUNNING {
209 return
210 }
211
212 srv.state = STATE_SHUTTING_DOWN
213 if DefaultTimeout >= 0 {
214 go srv.serverTimeout(DefaultTimeout)
215 }
216 err := srv.GraceListener.Close()
217 if err != nil {
218 log.Println(syscall.Getpid(), "Listener.Close() error:", err)
219 } else {
220 log.Println(syscall.Getpid(), srv.GraceListener.Addr(), "Listener closed.")
221 }
222 }
223
224 // serverTimeout forces the server to shutdown in a given timeout - whether it
225 // finished outstanding requests or not. if Read/WriteTimeout are not set or the
226 // max header size is very big a connection could hang
227 func (srv *graceServer) serverTimeout(d time.Duration) {
228 defer func() {
229 if r := recover(); r != nil {
230 log.Println("WaitGroup at 0", r)
231 }
232 }()
233 if srv.state != STATE_SHUTTING_DOWN {
234 return
235 }
236 time.Sleep(d)
237 log.Println("[STOP - Hammer Time] Forcefully shutting down parent")
238 for {
239 if srv.state == STATE_TERMINATE {
240 break
241 }
242 srv.wg.Done()
243 }
244 }
245
246 func (srv *graceServer) fork() (err error) {
247 regLock.Lock()
248 defer regLock.Unlock()
249 if runningServersForked {
250 return
251 }
252 runningServersForked = true
253
254 var files = make([]*os.File, len(runningServers))
255 var orderArgs = make([]string, len(runningServers))
256 for _, srvPtr := range runningServers {
257 switch srvPtr.GraceListener.(type) {
258 case *graceListener:
259 files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.GraceListener.(*graceListener).File()
260 default:
261 files[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.tlsInnerListener.File()
262 }
263 orderArgs[socketPtrOffsetMap[srvPtr.Server.Addr]] = srvPtr.Server.Addr
264 }
265
266 log.Println(files)
267 path := os.Args[0]
268 var args []string
269 if len(os.Args) > 1 {
270 for _, arg := range os.Args[1:] {
271 if arg == "-graceful" {
272 break
273 }
274 args = append(args, arg)
275 }
276 }
277 args = append(args, "-graceful")
278 if len(runningServers) > 1 {
279 args = append(args, fmt.Sprintf(`-socketorder=%s`, strings.Join(orderArgs, ",")))
280 log.Println(args)
281 }
282 cmd := exec.Command(path, args...)
283 cmd.Stdout = os.Stdout
284 cmd.Stderr = os.Stderr
285 cmd.ExtraFiles = files
286 err = cmd.Start()
287 if err != nil {
288 log.Fatalf("Restart: Failed to launch, error: %v", err)
289 }
290
291 return
292 }
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!