scgi.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package xweb
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "net/http/cgi"
  11. "strconv"
  12. "strings"
  13. )
  // scgiBody is the Body of an *http.Request built from an SCGI connection.
// Reads are served from reader, which shares the client connection with the
// response.
//
// Close only marks the body as closed; it deliberately does NOT close conn.
// The response is written to that same connection after (and possibly while)
// the handler consumes the body. Closing the connection here would silently
// drop the response for any handler that does `defer r.Body.Close()`.
// The connection itself is closed by the SCGI request handler once the
// request is finished.
type scgiBody struct {
	reader io.Reader          // source of request-body bytes
	conn   io.ReadWriteCloser // client connection; not owned (and not closed) by the body
	closed bool               // set by Close; any later Read fails
}

// Read reads body bytes from the underlying reader, or fails once the body
// has been closed.
func (b *scgiBody) Read(p []byte) (n int, err error) {
	if b.closed {
		return 0, errors.New("SCGI read after close")
	}
	return b.reader.Read(p)
}

// Close marks the body closed so that subsequent Reads fail. It is safe to
// call more than once and always returns nil.
func (b *scgiBody) Close() error {
	b.closed = true
	return nil
}
  29. type scgiConn struct {
  30. fd io.ReadWriteCloser
  31. req *http.Request
  32. headers http.Header
  33. wroteHeaders bool
  34. }
  35. func (conn *scgiConn) WriteHeader(status int) {
  36. if !conn.wroteHeaders {
  37. conn.wroteHeaders = true
  38. var buf bytes.Buffer
  39. text := statusText[status]
  40. fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text)
  41. for k, v := range conn.headers {
  42. for _, i := range v {
  43. buf.WriteString(k + ": " + i + "\r\n")
  44. }
  45. }
  46. buf.WriteString("\r\n")
  47. conn.fd.Write(buf.Bytes())
  48. }
  49. }
  50. func (conn *scgiConn) Header() http.Header {
  51. return conn.headers
  52. }
  53. func (conn *scgiConn) Write(data []byte) (n int, err error) {
  54. if !conn.wroteHeaders {
  55. conn.WriteHeader(200)
  56. }
  57. if conn.req.Method == "HEAD" {
  58. return 0, errors.New("Body Not Allowed")
  59. }
  60. return conn.fd.Write(data)
  61. }
  62. func (conn *scgiConn) Close() { conn.fd.Close() }
  63. func (conn *scgiConn) finishRequest() error {
  64. var buf bytes.Buffer
  65. if !conn.wroteHeaders {
  66. conn.wroteHeaders = true
  67. for k, v := range conn.headers {
  68. for _, i := range v {
  69. buf.WriteString(k + ": " + i + "\r\n")
  70. }
  71. }
  72. buf.WriteString("\r\n")
  73. conn.fd.Write(buf.Bytes())
  74. }
  75. return nil
  76. }
  77. func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) {
  78. reader := bufio.NewReader(fd)
  79. line, err := reader.ReadString(':')
  80. if err != nil {
  81. s.Logger.Println("Error during SCGI read: ", err.Error())
  82. }
  83. length, _ := strconv.Atoi(line[0 : len(line)-1])
  84. if length > 16384 {
  85. s.Logger.Println("Error: max header size is 16k")
  86. }
  87. headerData := make([]byte, length)
  88. _, err = reader.Read(headerData)
  89. if err != nil {
  90. return nil, err
  91. }
  92. b, err := reader.ReadByte()
  93. if err != nil {
  94. return nil, err
  95. }
  96. // discard the trailing comma
  97. if b != ',' {
  98. return nil, errors.New("SCGI protocol error: missing comma")
  99. }
  100. headerList := bytes.Split(headerData, []byte{0})
  101. headers := map[string]string{}
  102. for i := 0; i < len(headerList)-1; i += 2 {
  103. headers[string(headerList[i])] = string(headerList[i+1])
  104. }
  105. httpReq, err := cgi.RequestFromMap(headers)
  106. if err != nil {
  107. return nil, err
  108. }
  109. if httpReq.ContentLength > 0 {
  110. httpReq.Body = &scgiBody{
  111. reader: io.LimitReader(reader, httpReq.ContentLength),
  112. conn: fd,
  113. }
  114. } else {
  115. httpReq.Body = &scgiBody{reader: reader, conn: fd}
  116. }
  117. return httpReq, nil
  118. }
  119. func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) {
  120. req, err := s.readScgiRequest(fd)
  121. if err != nil {
  122. s.Logger.Println("SCGI error: %q", err.Error())
  123. }
  124. sc := scgiConn{fd, req, make(map[string][]string), false}
  125. for _, app := range s.Apps {
  126. app.routeHandler(req, &sc)
  127. }
  128. sc.finishRequest()
  129. fd.Close()
  130. }
  131. func (s *Server) listenAndServeScgi(addr string) error {
  132. var l net.Listener
  133. var err error
  134. //if the path begins with a "/", assume it's a unix address
  135. if strings.HasPrefix(addr, "/") {
  136. l, err = net.Listen("unix", addr)
  137. } else {
  138. l, err = net.Listen("tcp", addr)
  139. }
  140. //save the listener so it can be closed
  141. s.l = l
  142. if err != nil {
  143. s.Logger.Println("SCGI listen error", err.Error())
  144. return err
  145. }
  146. for {
  147. fd, err := l.Accept()
  148. if err != nil {
  149. s.Logger.Println("SCGI accept error", err.Error())
  150. return err
  151. }
  152. go s.handleScgiRequest(fd)
  153. }
  154. return nil
  155. }