validation.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. package validation
  2. import (
  3. "fmt"
  4. "reflect"
  5. "regexp"
  6. "strings"
  7. "github.com/coscms/tagfast"
  8. )
  9. type ValidFormer interface {
  10. Valid(*Validation)
  11. }
  12. type ValidationError struct {
  13. Message, Key, Name, Field, Tmpl string
  14. Value interface{}
  15. LimitValue interface{}
  16. }
  17. // Returns the Message.
  18. func (e *ValidationError) String() string {
  19. if e == nil {
  20. return ""
  21. }
  22. return e.Message
  23. }
  24. // A ValidationResult is returned from every validation method.
  25. // It provides an indication of success, and a pointer to the Error (if any).
  26. type ValidationResult struct {
  27. Error *ValidationError
  28. Ok bool
  29. }
  30. func (r *ValidationResult) Key(key string) *ValidationResult {
  31. if r.Error != nil {
  32. r.Error.Key = key
  33. }
  34. return r
  35. }
  36. func (r *ValidationResult) Message(message string, args ...interface{}) *ValidationResult {
  37. if r.Error != nil {
  38. if len(args) == 0 {
  39. r.Error.Message = message
  40. } else {
  41. r.Error.Message = fmt.Sprintf(message, args...)
  42. }
  43. }
  44. return r
  45. }
  46. // A Validation context manages data validation and error messages.
  47. type Validation struct {
  48. Errors []*ValidationError
  49. ErrorsMap map[string]*ValidationError
  50. }
  51. func (v *Validation) Clear() {
  52. v.Errors = []*ValidationError{}
  53. }
  54. func (v *Validation) HasErrors() bool {
  55. return len(v.Errors) > 0
  56. }
  57. // Return the errors mapped by key.
  58. // If there are multiple validation errors associated with a single key, the
  59. // first one "wins". (Typically the first validation will be the more basic).
  60. func (v *Validation) ErrorMap() map[string]*ValidationError {
  61. return v.ErrorsMap
  62. }
  63. // Add an error to the validation context.
  64. func (v *Validation) Error(message string, args ...interface{}) *ValidationResult {
  65. result := (&ValidationResult{
  66. Ok: false,
  67. Error: &ValidationError{},
  68. }).Message(message, args...)
  69. v.Errors = append(v.Errors, result.Error)
  70. return result
  71. }
  72. // Test that the argument is non-nil and non-empty (if string or list)
  73. func (v *Validation) Required(obj interface{}, key string) *ValidationResult {
  74. return v.apply(Required{key}, obj)
  75. }
  76. // Test that the obj is greater than min if obj's type is int
  77. func (v *Validation) Min(obj interface{}, min int, key string) *ValidationResult {
  78. return v.apply(Min{min, key}, obj)
  79. }
  80. // Test that the obj is less than max if obj's type is int
  81. func (v *Validation) Max(obj interface{}, max int, key string) *ValidationResult {
  82. return v.apply(Max{max, key}, obj)
  83. }
  84. // Test that the obj is between mni and max if obj's type is int
  85. func (v *Validation) Range(obj interface{}, min, max int, key string) *ValidationResult {
  86. return v.apply(Range{Min{Min: min}, Max{Max: max}, key}, obj)
  87. }
  88. func (v *Validation) MinSize(obj interface{}, min int, key string) *ValidationResult {
  89. return v.apply(MinSize{min, key}, obj)
  90. }
  91. func (v *Validation) MaxSize(obj interface{}, max int, key string) *ValidationResult {
  92. return v.apply(MaxSize{max, key}, obj)
  93. }
  94. func (v *Validation) Length(obj interface{}, n int, key string) *ValidationResult {
  95. return v.apply(Length{n, key}, obj)
  96. }
  97. func (v *Validation) Alpha(obj interface{}, key string) *ValidationResult {
  98. return v.apply(Alpha{key}, obj)
  99. }
  100. func (v *Validation) Numeric(obj interface{}, key string) *ValidationResult {
  101. return v.apply(Numeric{key}, obj)
  102. }
  103. func (v *Validation) AlphaNumeric(obj interface{}, key string) *ValidationResult {
  104. return v.apply(AlphaNumeric{key}, obj)
  105. }
  106. func (v *Validation) Match(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
  107. return v.apply(Match{regex, key}, obj)
  108. }
  109. func (v *Validation) NoMatch(obj interface{}, regex *regexp.Regexp, key string) *ValidationResult {
  110. return v.apply(NoMatch{Match{Regexp: regex}, key}, obj)
  111. }
  112. func (v *Validation) AlphaDash(obj interface{}, key string) *ValidationResult {
  113. return v.apply(AlphaDash{NoMatch{Match: Match{Regexp: alphaDashPattern}}, key}, obj)
  114. }
  115. func (v *Validation) Email(obj interface{}, key string) *ValidationResult {
  116. return v.apply(Email{Match{Regexp: emailPattern}, key}, obj)
  117. }
  118. func (v *Validation) IP(obj interface{}, key string) *ValidationResult {
  119. return v.apply(IP{Match{Regexp: ipPattern}, key}, obj)
  120. }
  121. func (v *Validation) Base64(obj interface{}, key string) *ValidationResult {
  122. return v.apply(Base64{Match{Regexp: base64Pattern}, key}, obj)
  123. }
  124. func (v *Validation) Mobile(obj interface{}, key string) *ValidationResult {
  125. return v.apply(Mobile{Match{Regexp: mobilePattern}, key}, obj)
  126. }
  127. func (v *Validation) Tel(obj interface{}, key string) *ValidationResult {
  128. return v.apply(Tel{Match{Regexp: telPattern}, key}, obj)
  129. }
  130. func (v *Validation) Phone(obj interface{}, key string) *ValidationResult {
  131. return v.apply(Phone{Mobile{Match: Match{Regexp: mobilePattern}},
  132. Tel{Match: Match{Regexp: telPattern}}, key}, obj)
  133. }
  134. func (v *Validation) ZipCode(obj interface{}, key string) *ValidationResult {
  135. return v.apply(ZipCode{Match{Regexp: zipCodePattern}, key}, obj)
  136. }
  137. func (v *Validation) apply(chk Validator, obj interface{}) *ValidationResult {
  138. if chk.IsSatisfied(obj) {
  139. return &ValidationResult{Ok: true}
  140. }
  141. // Add the error to the validation context.
  142. key := chk.GetKey()
  143. Field := key
  144. Name := ""
  145. parts := strings.Split(key, "|")
  146. if len(parts) == 2 {
  147. Field = parts[0]
  148. Name = parts[1]
  149. }
  150. err := &ValidationError{
  151. Message: chk.DefaultMessage(),
  152. Key: key,
  153. Name: Name,
  154. Field: Field,
  155. Value: obj,
  156. Tmpl: MessageTmpls[Name],
  157. LimitValue: chk.GetLimitValue(),
  158. }
  159. v.setError(err)
  160. // Also return it in the result.
  161. return &ValidationResult{
  162. Ok: false,
  163. Error: err,
  164. }
  165. }
  166. func (v *Validation) setError(err *ValidationError) {
  167. v.Errors = append(v.Errors, err)
  168. if v.ErrorsMap == nil {
  169. v.ErrorsMap = make(map[string]*ValidationError)
  170. }
  171. if _, ok := v.ErrorsMap[err.Field]; !ok {
  172. v.ErrorsMap[err.Field] = err
  173. }
  174. }
  175. func (v *Validation) SetError(fieldName string, errMsg string) *ValidationError {
  176. err := &ValidationError{Key: fieldName, Field: fieldName, Tmpl: errMsg, Message: errMsg}
  177. v.setError(err)
  178. return err
  179. }
  180. // Apply a group of validators to a field, in order, and return the
  181. // ValidationResult from the first one that fails, or the last one that
  182. // succeeds.
  183. func (v *Validation) Check(obj interface{}, checks ...Validator) *ValidationResult {
  184. var result *ValidationResult
  185. for _, check := range checks {
  186. result = v.apply(check, obj)
  187. if !result.Ok {
  188. return result
  189. }
  190. }
  191. return result
  192. }
  193. // the obj parameter must be a struct or a struct pointer
  194. func (v *Validation) Valid(obj interface{}, args ...string) (b bool, err error) {
  195. err = v.validExec(obj, "", args...)
  196. if err != nil {
  197. fmt.Println(err)
  198. return
  199. }
  200. if !v.HasErrors() {
  201. if form, ok := obj.(ValidFormer); ok {
  202. form.Valid(v)
  203. }
  204. }
  205. return !v.HasErrors(), nil
  206. }
  207. func (v *Validation) validExec(obj interface{}, baseName string, args ...string) (err error) {
  208. objT := reflect.TypeOf(obj)
  209. objV := reflect.ValueOf(obj)
  210. switch {
  211. case isStruct(objT):
  212. case isStructPtr(objT):
  213. objT = objT.Elem()
  214. objV = objV.Elem()
  215. default:
  216. err = fmt.Errorf("%v must be a struct or a struct pointer", obj)
  217. return
  218. }
  219. var chkFields map[string][]string = make(map[string][]string)
  220. var pNum int = len(args)
  221. //fmt.Println(objT.Name(), ":[Struct NumIn]", pNum)
  222. if pNum > 0 {
  223. //aa.b.c,ab.b.c
  224. for _, v := range args {
  225. arr := strings.SplitN(v, ".", 2)
  226. if _, ok := chkFields[arr[0]]; !ok {
  227. chkFields[arr[0]] = make([]string, 0)
  228. }
  229. if len(arr) > 1 {
  230. chkFields[arr[0]] = append(chkFields[arr[0]], arr[1])
  231. }
  232. }
  233. }
  234. args = make([]string, 0)
  235. if len(chkFields) > 0 { //检测指定字段
  236. for field, args := range chkFields {
  237. f, ok := objT.FieldByName(field)
  238. if !ok {
  239. err = fmt.Errorf("No name for the '%s' field", field)
  240. return
  241. }
  242. tag := tagfast.Tag(objT, f, VALIDTAG)
  243. if tag == "-" {
  244. continue
  245. }
  246. var vfs []ValidFunc
  247. var fName string
  248. if baseName == "" {
  249. fName = f.Name
  250. } else {
  251. fName = strings.Join([]string{baseName, f.Name}, ".")
  252. }
  253. fv := objV.FieldByName(field)
  254. if isStruct(f.Type) || isStructPtr(f.Type) {
  255. if fv.CanInterface() {
  256. err = v.validExec(fv.Interface(), fName, args...)
  257. }
  258. continue
  259. }
  260. if vfs, err = getValidFuncs(f, objT, fName); err != nil {
  261. return
  262. }
  263. for _, vf := range vfs {
  264. if _, err = funcs.Call(vf.Name,
  265. mergeParam(v, fv.Interface(), vf.Params)...); err != nil {
  266. return
  267. }
  268. }
  269. }
  270. } else { //检测全部字段
  271. for i := 0; i < objT.NumField(); i++ {
  272. tag := tagfast.Tag(objT, objT.Field(i), VALIDTAG)
  273. if tag == "-" {
  274. continue
  275. }
  276. var vfs []ValidFunc
  277. var fName string
  278. if baseName == "" {
  279. fName = objT.Field(i).Name
  280. } else {
  281. fName = strings.Join([]string{baseName, objT.Field(i).Name}, ".")
  282. }
  283. //fmt.Println(fName, ":[Type]:", objT.Field(i).Type.Kind())
  284. if isStruct(objT.Field(i).Type) || isStructPtr(objT.Field(i).Type) {
  285. if objV.Field(i).CanInterface() {
  286. err = v.validExec(objV.Field(i).Interface(), fName)
  287. }
  288. continue
  289. }
  290. if vfs, err = getValidFuncs(objT.Field(i), objT, fName); err != nil {
  291. return
  292. }
  293. for _, vf := range vfs {
  294. if _, err = funcs.Call(vf.Name,
  295. mergeParam(v, objV.Field(i).Interface(), vf.Params)...); err != nil {
  296. return
  297. }
  298. }
  299. }
  300. }
  301. return
  302. }