encoder_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. // Copyright (C) MongoDB, Inc. 2017-present.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"); you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
  6. package bson
  7. import (
  8. "bytes"
  9. "errors"
  10. "reflect"
  11. "testing"
  12. "go.mongodb.org/mongo-driver/bson/bsoncodec"
  13. "go.mongodb.org/mongo-driver/bson/bsonrw"
  14. "go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
  15. )
  16. func TestBasicEncode(t *testing.T) {
  17. for _, tc := range marshalingTestCases {
  18. t.Run(tc.name, func(t *testing.T) {
  19. got := make(bsonrw.SliceWriter, 0, 1024)
  20. vw, err := bsonrw.NewBSONValueWriter(&got)
  21. noerr(t, err)
  22. reg := DefaultRegistry
  23. encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val))
  24. noerr(t, err)
  25. err = encoder.EncodeValue(bsoncodec.EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val))
  26. noerr(t, err)
  27. if !bytes.Equal(got, tc.want) {
  28. t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
  29. t.Errorf("Bytes:\n%v\n%v", got, tc.want)
  30. }
  31. })
  32. }
  33. }
  34. func TestEncoderEncode(t *testing.T) {
  35. for _, tc := range marshalingTestCases {
  36. t.Run(tc.name, func(t *testing.T) {
  37. got := make(bsonrw.SliceWriter, 0, 1024)
  38. vw, err := bsonrw.NewBSONValueWriter(&got)
  39. noerr(t, err)
  40. enc, err := NewEncoder(vw)
  41. noerr(t, err)
  42. err = enc.Encode(tc.val)
  43. noerr(t, err)
  44. if !bytes.Equal(got, tc.want) {
  45. t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
  46. t.Errorf("Bytes:\n%v\n%v", got, tc.want)
  47. }
  48. })
  49. }
  50. t.Run("Marshaler", func(t *testing.T) {
  51. testCases := []struct {
  52. name string
  53. buf []byte
  54. err error
  55. wanterr error
  56. vw bsonrw.ValueWriter
  57. }{
  58. {
  59. "error",
  60. nil,
  61. errors.New("Marshaler error"),
  62. errors.New("Marshaler error"),
  63. &bsonrwtest.ValueReaderWriter{},
  64. },
  65. {
  66. "copy error",
  67. []byte{0x05, 0x00, 0x00, 0x00, 0x00},
  68. nil,
  69. errors.New("copy error"),
  70. &bsonrwtest.ValueReaderWriter{Err: errors.New("copy error"), ErrAfter: bsonrwtest.WriteDocument},
  71. },
  72. {
  73. "success",
  74. []byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00},
  75. nil,
  76. nil,
  77. nil,
  78. },
  79. }
  80. for _, tc := range testCases {
  81. t.Run(tc.name, func(t *testing.T) {
  82. marshaler := testMarshaler{buf: tc.buf, err: tc.err}
  83. var vw bsonrw.ValueWriter
  84. var err error
  85. b := make(bsonrw.SliceWriter, 0, 100)
  86. compareVW := false
  87. if tc.vw != nil {
  88. vw = tc.vw
  89. } else {
  90. compareVW = true
  91. vw, err = bsonrw.NewBSONValueWriter(&b)
  92. noerr(t, err)
  93. }
  94. enc, err := NewEncoder(vw)
  95. noerr(t, err)
  96. got := enc.Encode(marshaler)
  97. want := tc.wanterr
  98. if !compareErrors(got, want) {
  99. t.Errorf("Did not receive expected error. got %v; want %v", got, want)
  100. }
  101. if compareVW {
  102. buf := b
  103. if !bytes.Equal(buf, tc.buf) {
  104. t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf)
  105. }
  106. }
  107. })
  108. }
  109. })
  110. }
  // testMarshaler is a fake MarshalBSON implementation whose output is fixed
  // at construction time.
  type testMarshaler struct {
  	buf []byte // returned verbatim as the marshaled bytes
  	err error  // returned verbatim as the marshaling error
  }

  // MarshalBSON hands back the stored buf and err without inspecting or
  // copying either of them.
  func (m testMarshaler) MarshalBSON() ([]byte, error) {
  	out, failure := m.buf, m.err
  	return out, failure
  }
  116. func docToBytes(d interface{}) []byte {
  117. b, err := Marshal(d)
  118. if err != nil {
  119. panic(err)
  120. }
  121. return b
  122. }
  // byteMarshaler is a byte slice whose MarshalBSON result is the slice itself.
  type byteMarshaler []byte

  // MarshalBSON returns the receiver's bytes (same backing array, no copy)
  // and a nil error.
  func (bm byteMarshaler) MarshalBSON() ([]byte, error) {
  	raw := []byte(bm)
  	return raw, nil
  }
  type (
  	// _Interface is a one-method interface type used as a test fixture.
  	_Interface interface {
  		method()
  	}

  	// _impl is a fixture struct that satisfies _Interface.
  	_impl struct {
  		Foo string
  	}
  )

  // method is a no-op; it exists only so that _impl implements _Interface.
  func (_impl) method() {}