struct.go raw

   1  // Copyright 2016 Qiang Xue. All rights reserved.
   2  // Use of this source code is governed by a MIT-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package validation
   6  
   7  import (
   8  	"context"
   9  	"errors"
  10  	"fmt"
  11  	"reflect"
  12  	"strings"
  13  )
  14  
  15  var (
  16  	// ErrStructPointer is the error that a struct being validated is not specified as a pointer.
  17  	ErrStructPointer = errors.New("only a pointer to a struct can be validated")
  18  )
  19  
  20  type (
  21  	// ErrFieldPointer is the error that a field is not specified as a pointer.
  22  	ErrFieldPointer int
  23  
  24  	// ErrFieldNotFound is the error that a field cannot be found in the struct.
  25  	ErrFieldNotFound int
  26  
  27  	// FieldRules represents a rule set associated with a struct field.
  28  	FieldRules struct {
  29  		fieldPtr interface{}
  30  		rules    []Rule
  31  	}
  32  )
  33  
  34  // Error returns the error string of ErrFieldPointer.
  35  func (e ErrFieldPointer) Error() string {
  36  	return fmt.Sprintf("field #%v must be specified as a pointer", int(e))
  37  }
  38  
  39  // Error returns the error string of ErrFieldNotFound.
  40  func (e ErrFieldNotFound) Error() string {
  41  	return fmt.Sprintf("field #%v cannot be found in the struct", int(e))
  42  }
  43  
  44  // ValidateStruct validates a struct by checking the specified struct fields against the corresponding validation rules.
  45  // Note that the struct being validated must be specified as a pointer to it. If the pointer is nil, it is considered valid.
  46  // Use Field() to specify struct fields that need to be validated. Each Field() call specifies a single field which
  47  // should be specified as a pointer to the field. A field can be associated with multiple rules.
  48  // For example,
  49  //
  50  //    value := struct {
  51  //        Name  string
  52  //        Value string
  53  //    }{"name", "demo"}
  54  //    err := validation.ValidateStruct(&value,
  55  //        validation.Field(&a.Name, validation.Required),
  56  //        validation.Field(&a.Value, validation.Required, validation.Length(5, 10)),
  57  //    )
  58  //    fmt.Println(err)
  59  //    // Value: the length must be between 5 and 10.
  60  //
  61  // An error will be returned if validation fails.
  62  func ValidateStruct(structPtr interface{}, fields ...*FieldRules) error {
  63  	return ValidateStructWithContext(nil, structPtr, fields...)
  64  }
  65  
  66  // ValidateStructWithContext validates a struct with the given context.
  67  // The only difference between ValidateStructWithContext and ValidateStruct is that the former will
  68  // validate struct fields with the provided context.
  69  // Please refer to ValidateStruct for the detailed instructions on how to use this function.
  70  func ValidateStructWithContext(ctx context.Context, structPtr interface{}, fields ...*FieldRules) error {
  71  	value := reflect.ValueOf(structPtr)
  72  	if value.Kind() != reflect.Ptr || !value.IsNil() && value.Elem().Kind() != reflect.Struct {
  73  		// must be a pointer to a struct
  74  		return NewInternalError(ErrStructPointer)
  75  	}
  76  	if value.IsNil() {
  77  		// treat a nil struct pointer as valid
  78  		return nil
  79  	}
  80  	value = value.Elem()
  81  
  82  	errs := Errors{}
  83  
  84  	for i, fr := range fields {
  85  		fv := reflect.ValueOf(fr.fieldPtr)
  86  		if fv.Kind() != reflect.Ptr {
  87  			return NewInternalError(ErrFieldPointer(i))
  88  		}
  89  		ft := findStructField(value, fv)
  90  		if ft == nil {
  91  			return NewInternalError(ErrFieldNotFound(i))
  92  		}
  93  		var err error
  94  		if ctx == nil {
  95  			err = Validate(fv.Elem().Interface(), fr.rules...)
  96  		} else {
  97  			err = ValidateWithContext(ctx, fv.Elem().Interface(), fr.rules...)
  98  		}
  99  		if err != nil {
 100  			if ie, ok := err.(InternalError); ok && ie.InternalError() != nil {
 101  				return err
 102  			}
 103  			if ft.Anonymous {
 104  				// merge errors from anonymous struct field
 105  				if es, ok := err.(Errors); ok {
 106  					for name, value := range es {
 107  						errs[name] = value
 108  					}
 109  					continue
 110  				}
 111  			}
 112  			errs[getErrorFieldName(ft)] = err
 113  		}
 114  	}
 115  
 116  	if len(errs) > 0 {
 117  		return errs
 118  	}
 119  	return nil
 120  }
 121  
 122  // Field specifies a struct field and the corresponding validation rules.
 123  // The struct field must be specified as a pointer to it.
 124  func Field(fieldPtr interface{}, rules ...Rule) *FieldRules {
 125  	return &FieldRules{
 126  		fieldPtr: fieldPtr,
 127  		rules:    rules,
 128  	}
 129  }
 130  
 131  // findStructField looks for a field in the given struct.
 132  // The field being looked for should be a pointer to the actual struct field.
 133  // If found, the field info will be returned. Otherwise, nil will be returned.
 134  func findStructField(structValue reflect.Value, fieldValue reflect.Value) *reflect.StructField {
 135  	ptr := fieldValue.Pointer()
 136  	for i := structValue.NumField() - 1; i >= 0; i-- {
 137  		sf := structValue.Type().Field(i)
 138  		if ptr == structValue.Field(i).UnsafeAddr() {
 139  			// do additional type comparison because it's possible that the address of
 140  			// an embedded struct is the same as the first field of the embedded struct
 141  			if sf.Type == fieldValue.Elem().Type() {
 142  				return &sf
 143  			}
 144  		}
 145  		if sf.Anonymous {
 146  			// delve into anonymous struct to look for the field
 147  			fi := structValue.Field(i)
 148  			if sf.Type.Kind() == reflect.Ptr {
 149  				fi = fi.Elem()
 150  			}
 151  			if fi.Kind() == reflect.Struct {
 152  				if f := findStructField(fi, fieldValue); f != nil {
 153  					return f
 154  				}
 155  			}
 156  		}
 157  	}
 158  	return nil
 159  }
 160  
 161  // getErrorFieldName returns the name that should be used to represent the validation error of a struct field.
 162  func getErrorFieldName(f *reflect.StructField) string {
 163  	if tag := f.Tag.Get(ErrorTag); tag != "" && tag != "-" {
 164  		if cps := strings.SplitN(tag, ",", 2); cps[0] != "" {
 165  			return cps[0]
 166  		}
 167  	}
 168  	return f.Name
 169  }
 170