websocket_test.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. // Copyright 2009 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bytes"
  7. "fmt"
  8. "io"
  9. "log"
  10. "net"
  11. "net/http"
  12. "net/http/httptest"
  13. "net/url"
  14. "strings"
  15. "sync"
  16. "testing"
  17. )
  18. var serverAddr string
  19. var once sync.Once
  20. func echoServer(ws *Conn) { io.Copy(ws, ws) }
  21. type Count struct {
  22. S string
  23. N int
  24. }
  25. func countServer(ws *Conn) {
  26. for {
  27. var count Count
  28. err := JSON.Receive(ws, &count)
  29. if err != nil {
  30. return
  31. }
  32. count.N++
  33. count.S = strings.Repeat(count.S, count.N)
  34. err = JSON.Send(ws, count)
  35. if err != nil {
  36. return
  37. }
  38. }
  39. }
  40. func subProtocolHandshake(config *Config, req *http.Request) error {
  41. for _, proto := range config.Protocol {
  42. if proto == "chat" {
  43. config.Protocol = []string{proto}
  44. return nil
  45. }
  46. }
  47. return ErrBadWebSocketProtocol
  48. }
  49. func subProtoServer(ws *Conn) {
  50. for _, proto := range ws.Config().Protocol {
  51. io.WriteString(ws, proto)
  52. }
  53. }
  54. func startServer() {
  55. http.Handle("/echo", Handler(echoServer))
  56. http.Handle("/count", Handler(countServer))
  57. subproto := Server{
  58. Handshake: subProtocolHandshake,
  59. Handler: Handler(subProtoServer),
  60. }
  61. http.Handle("/subproto", subproto)
  62. server := httptest.NewServer(nil)
  63. serverAddr = server.Listener.Addr().String()
  64. log.Print("Test WebSocket server listening on ", serverAddr)
  65. }
  66. func newConfig(t *testing.T, path string) *Config {
  67. config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
  68. return config
  69. }
  70. func TestEcho(t *testing.T) {
  71. once.Do(startServer)
  72. // websocket.Dial()
  73. client, err := net.Dial("tcp", serverAddr)
  74. if err != nil {
  75. t.Fatal("dialing", err)
  76. }
  77. conn, err := NewClient(newConfig(t, "/echo"), client)
  78. if err != nil {
  79. t.Errorf("WebSocket handshake error: %v", err)
  80. return
  81. }
  82. msg := []byte("hello, world\n")
  83. if _, err := conn.Write(msg); err != nil {
  84. t.Errorf("Write: %v", err)
  85. }
  86. var actual_msg = make([]byte, 512)
  87. n, err := conn.Read(actual_msg)
  88. if err != nil {
  89. t.Errorf("Read: %v", err)
  90. }
  91. actual_msg = actual_msg[0:n]
  92. if !bytes.Equal(msg, actual_msg) {
  93. t.Errorf("Echo: expected %q got %q", msg, actual_msg)
  94. }
  95. conn.Close()
  96. }
  97. func TestAddr(t *testing.T) {
  98. once.Do(startServer)
  99. // websocket.Dial()
  100. client, err := net.Dial("tcp", serverAddr)
  101. if err != nil {
  102. t.Fatal("dialing", err)
  103. }
  104. conn, err := NewClient(newConfig(t, "/echo"), client)
  105. if err != nil {
  106. t.Errorf("WebSocket handshake error: %v", err)
  107. return
  108. }
  109. ra := conn.RemoteAddr().String()
  110. if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
  111. t.Errorf("Bad remote addr: %v", ra)
  112. }
  113. la := conn.LocalAddr().String()
  114. if !strings.HasPrefix(la, "http://") {
  115. t.Errorf("Bad local addr: %v", la)
  116. }
  117. conn.Close()
  118. }
  119. func TestCount(t *testing.T) {
  120. once.Do(startServer)
  121. // websocket.Dial()
  122. client, err := net.Dial("tcp", serverAddr)
  123. if err != nil {
  124. t.Fatal("dialing", err)
  125. }
  126. conn, err := NewClient(newConfig(t, "/count"), client)
  127. if err != nil {
  128. t.Errorf("WebSocket handshake error: %v", err)
  129. return
  130. }
  131. var count Count
  132. count.S = "hello"
  133. if err := JSON.Send(conn, count); err != nil {
  134. t.Errorf("Write: %v", err)
  135. }
  136. if err := JSON.Receive(conn, &count); err != nil {
  137. t.Errorf("Read: %v", err)
  138. }
  139. if count.N != 1 {
  140. t.Errorf("count: expected %d got %d", 1, count.N)
  141. }
  142. if count.S != "hello" {
  143. t.Errorf("count: expected %q got %q", "hello", count.S)
  144. }
  145. if err := JSON.Send(conn, count); err != nil {
  146. t.Errorf("Write: %v", err)
  147. }
  148. if err := JSON.Receive(conn, &count); err != nil {
  149. t.Errorf("Read: %v", err)
  150. }
  151. if count.N != 2 {
  152. t.Errorf("count: expected %d got %d", 2, count.N)
  153. }
  154. if count.S != "hellohello" {
  155. t.Errorf("count: expected %q got %q", "hellohello", count.S)
  156. }
  157. conn.Close()
  158. }
  159. func TestWithQuery(t *testing.T) {
  160. once.Do(startServer)
  161. client, err := net.Dial("tcp", serverAddr)
  162. if err != nil {
  163. t.Fatal("dialing", err)
  164. }
  165. config := newConfig(t, "/echo")
  166. config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
  167. if err != nil {
  168. t.Fatal("location url", err)
  169. }
  170. ws, err := NewClient(config, client)
  171. if err != nil {
  172. t.Errorf("WebSocket handshake: %v", err)
  173. return
  174. }
  175. ws.Close()
  176. }
  177. func testWithProtocol(t *testing.T, subproto []string) (string, error) {
  178. once.Do(startServer)
  179. client, err := net.Dial("tcp", serverAddr)
  180. if err != nil {
  181. t.Fatal("dialing", err)
  182. }
  183. config := newConfig(t, "/subproto")
  184. config.Protocol = subproto
  185. ws, err := NewClient(config, client)
  186. if err != nil {
  187. return "", err
  188. }
  189. msg := make([]byte, 16)
  190. n, err := ws.Read(msg)
  191. if err != nil {
  192. return "", err
  193. }
  194. ws.Close()
  195. return string(msg[:n]), nil
  196. }
  197. func TestWithProtocol(t *testing.T) {
  198. proto, err := testWithProtocol(t, []string{"chat"})
  199. if err != nil {
  200. t.Errorf("SubProto: unexpected error: %v", err)
  201. }
  202. if proto != "chat" {
  203. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  204. }
  205. }
  206. func TestWithTwoProtocol(t *testing.T) {
  207. proto, err := testWithProtocol(t, []string{"test", "chat"})
  208. if err != nil {
  209. t.Errorf("SubProto: unexpected error: %v", err)
  210. }
  211. if proto != "chat" {
  212. t.Errorf("SubProto: expected %q, got %q", "chat", proto)
  213. }
  214. }
  215. func TestWithBadProtocol(t *testing.T) {
  216. _, err := testWithProtocol(t, []string{"test"})
  217. if err != ErrBadStatus {
  218. t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
  219. }
  220. }
  221. func TestHTTP(t *testing.T) {
  222. once.Do(startServer)
  223. // If the client did not send a handshake that matches the protocol
  224. // specification, the server MUST return an HTTP response with an
  225. // appropriate error code (such as 400 Bad Request)
  226. resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
  227. if err != nil {
  228. t.Errorf("Get: error %#v", err)
  229. return
  230. }
  231. if resp == nil {
  232. t.Error("Get: resp is null")
  233. return
  234. }
  235. if resp.StatusCode != http.StatusBadRequest {
  236. t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
  237. }
  238. }
  239. func TestTrailingSpaces(t *testing.T) {
  240. // http://code.google.com/p/go/issues/detail?id=955
  241. // The last runs of this create keys with trailing spaces that should not be
  242. // generated by the client.
  243. once.Do(startServer)
  244. config := newConfig(t, "/echo")
  245. for i := 0; i < 30; i++ {
  246. // body
  247. ws, err := DialConfig(config)
  248. if err != nil {
  249. t.Errorf("Dial #%d failed: %v", i, err)
  250. break
  251. }
  252. ws.Close()
  253. }
  254. }
  255. func TestDialConfigBadVersion(t *testing.T) {
  256. once.Do(startServer)
  257. config := newConfig(t, "/echo")
  258. config.Version = 1234
  259. _, err := DialConfig(config)
  260. if dialerr, ok := err.(*DialError); ok {
  261. if dialerr.Err != ErrBadProtocolVersion {
  262. t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
  263. }
  264. }
  265. }
  266. func TestSmallBuffer(t *testing.T) {
  267. // http://code.google.com/p/go/issues/detail?id=1145
  268. // Read should be able to handle reading a fragment of a frame.
  269. once.Do(startServer)
  270. // websocket.Dial()
  271. client, err := net.Dial("tcp", serverAddr)
  272. if err != nil {
  273. t.Fatal("dialing", err)
  274. }
  275. conn, err := NewClient(newConfig(t, "/echo"), client)
  276. if err != nil {
  277. t.Errorf("WebSocket handshake error: %v", err)
  278. return
  279. }
  280. msg := []byte("hello, world\n")
  281. if _, err := conn.Write(msg); err != nil {
  282. t.Errorf("Write: %v", err)
  283. }
  284. var small_msg = make([]byte, 8)
  285. n, err := conn.Read(small_msg)
  286. if err != nil {
  287. t.Errorf("Read: %v", err)
  288. }
  289. if !bytes.Equal(msg[:len(small_msg)], small_msg) {
  290. t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
  291. }
  292. var second_msg = make([]byte, len(msg))
  293. n, err = conn.Read(second_msg)
  294. if err != nil {
  295. t.Errorf("Read: %v", err)
  296. }
  297. second_msg = second_msg[0:n]
  298. if !bytes.Equal(msg[len(small_msg):], second_msg) {
  299. t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
  300. }
  301. conn.Close()
  302. }
  303. var parseAuthorityTests = []struct {
  304. in *url.URL
  305. out string
  306. }{
  307. {
  308. &url.URL{
  309. Scheme: "ws",
  310. Host: "www.google.com",
  311. },
  312. "www.google.com:80",
  313. },
  314. {
  315. &url.URL{
  316. Scheme: "wss",
  317. Host: "www.google.com",
  318. },
  319. "www.google.com:443",
  320. },
  321. {
  322. &url.URL{
  323. Scheme: "ws",
  324. Host: "www.google.com:80",
  325. },
  326. "www.google.com:80",
  327. },
  328. {
  329. &url.URL{
  330. Scheme: "wss",
  331. Host: "www.google.com:443",
  332. },
  333. "www.google.com:443",
  334. },
  335. // some invalid ones for parseAuthority. parseAuthority doesn't
  336. // concern itself with the scheme unless it actually knows about it
  337. {
  338. &url.URL{
  339. Scheme: "http",
  340. Host: "www.google.com",
  341. },
  342. "www.google.com",
  343. },
  344. {
  345. &url.URL{
  346. Scheme: "http",
  347. Host: "www.google.com:80",
  348. },
  349. "www.google.com:80",
  350. },
  351. {
  352. &url.URL{
  353. Scheme: "asdf",
  354. Host: "127.0.0.1",
  355. },
  356. "127.0.0.1",
  357. },
  358. {
  359. &url.URL{
  360. Scheme: "asdf",
  361. Host: "www.google.com",
  362. },
  363. "www.google.com",
  364. },
  365. }
  366. func TestParseAuthority(t *testing.T) {
  367. for _, tt := range parseAuthorityTests {
  368. out := parseAuthority(tt.in)
  369. if out != tt.out {
  370. t.Errorf("got %v; want %v", out, tt.out)
  371. }
  372. }
  373. }