cache.go raw

   1  package validator
   2  
   3  import (
   4  	"fmt"
   5  	"reflect"
   6  	"strings"
   7  	"sync"
   8  	"sync/atomic"
   9  )
  10  
// tagType classifies a parsed tag node; it is the type of cTag.typeof and is assigned
// while a tag string is parsed in parseFieldTagsRecursive.
type tagType uint8
  12  
// Values of tagType. Each one is assigned by parseFieldTagsRecursive according to the
// tag text it encounters.
const (
	// typeDefault is the zero value; a plain validation tag that is not part of an
	// or-group keeps this type.
	typeDefault tagType = iota
	// typeOmitEmpty marks a node created for the omitempty tag.
	typeOmitEmpty
	// typeIsDefault marks a node created for the isdefault tag.
	typeIsDefault
	// typeNoStructLevel marks a node created for noStructLevelTag.
	typeNoStructLevel
	// typeStructOnly marks a node created for structOnlyTag.
	typeStructOnly
	// typeDive marks a node created for diveTag.
	typeDive
	// typeOr marks every validation that came from a tag element containing more than
	// one orSeparator-separated alternative.
	typeOr
	// typeKeys marks a node created for keysTag; its keys field holds the parsed key tags.
	typeKeys
	// typeEndKeys marks a node created for endKeysTag.
	typeEndKeys
	// typeOmitNil marks a node created for the omitnil tag.
	typeOmitNil
)
  25  
// Panic messages used while parsing tags in parseFieldTagsRecursive:
//   - invalidValidation: format string (field name) used when a tag element has an
//     empty validation name.
//   - undefinedValidation: format string (validation name, field name) used when the
//     validation name is not present in the validator's registered validations.
//   - keysTagNotDefined: used when endKeysTag is found anywhere other than the last
//     position of the tag list being parsed.
const (
	invalidValidation   = "Invalid validation tag on field '%s'"
	undefinedValidation = "Undefined validation function '%s' on field '%s'"
	keysTagNotDefined   = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag"
)
  31  
// structCache caches parsed struct metadata keyed by reflect.Type.
//
// Get reads the map held in m without locking. Set replaces m with a freshly built
// map and does not lock either: lock is taken by extractStructCache around its
// Get-then-Set sequence.
type structCache struct {
	lock sync.Mutex   // held by extractStructCache while it parses and stores a struct
	m    atomic.Value // map[reflect.Type]*cStruct
}
  36  
  37  func (sc *structCache) Get(key reflect.Type) (c *cStruct, found bool) {
  38  	c, found = sc.m.Load().(map[reflect.Type]*cStruct)[key]
  39  	return
  40  }
  41  
  42  func (sc *structCache) Set(key reflect.Type, value *cStruct) {
  43  	m := sc.m.Load().(map[reflect.Type]*cStruct)
  44  	nm := make(map[reflect.Type]*cStruct, len(m)+1)
  45  	for k, v := range m {
  46  		nm[k] = v
  47  	}
  48  	nm[key] = value
  49  	sc.m.Store(nm)
  50  }
  51  
// tagCache caches parsed tag chains keyed by the raw tag string.
//
// Get reads the map held in m without locking. Set replaces m with a freshly built
// map and does not lock either: lock is taken by fetchCacheTag around its
// re-check-then-Set sequence.
type tagCache struct {
	lock sync.Mutex   // held by fetchCacheTag while it parses and stores a tag
	m    atomic.Value // map[string]*cTag
}
  56  
  57  func (tc *tagCache) Get(key string) (c *cTag, found bool) {
  58  	c, found = tc.m.Load().(map[string]*cTag)[key]
  59  	return
  60  }
  61  
  62  func (tc *tagCache) Set(key string, value *cTag) {
  63  	m := tc.m.Load().(map[string]*cTag)
  64  	nm := make(map[string]*cTag, len(m)+1)
  65  	for k, v := range m {
  66  		nm[k] = v
  67  	}
  68  	nm[key] = value
  69  	tc.m.Store(nm)
  70  }
  71  
// cStruct is the cached parse result for one struct type, built by extractStructCache.
type cStruct struct {
	name   string             // the sName passed to extractStructCache when the entry was built
	fields []*cField          // one entry per field that was not skipped, in field-index order
	fn     StructLevelFuncCtx // struct-level function looked up for the type; nil map entries yield nil
}
  77  
// cField is the cached description of a single struct field, built by extractStructCache.
type cField struct {
	idx        int    // index of the field within its struct type
	name       string // Go field name (reflect.StructField.Name)
	altName    string // name returned by the validator's tag-name function when non-empty, otherwise name
	namesEqual bool   // true when name == altName
	cTags      *cTag  // head of the parsed tag chain; an empty cTag when the field has no tag
}
  85  
// cTag is one node in the singly linked list produced by parseFieldTagsRecursive for a
// tag string. Control tags (dive, keys, omitempty, ...) only set typeof; validation
// tags additionally carry the validation name, parameter and function.
type cTag struct {
	tag                  string  // validation name (the text before tagKeySeparator)
	aliasTag             string  // alias name when parsed through an alias, otherwise derived from the tag text
	actualAliasTag       string  // set to the full tag element when parsed through an alias
	param                string  // text after tagKeySeparator with utf8HexComma/utf8Pipe replaced by "," and "|"
	keys                 *cTag   // only populated when using tag's 'keys' and 'endkeys' for map key validation
	next                 *cTag   // next node in the chain, nil at the end
	fn                   FuncCtx // validation function registered under tag
	typeof               tagType // kind of node, see the tagType constants
	hasTag               bool    // true for every node created from tag text; false for the empty placeholder node
	hasAlias             bool    // true when the node was produced while expanding an alias
	hasParam             bool    // true if parameter used eg. eq= where the equal sign has been set
	isBlockEnd           bool    // indicates the current tag represents the last validation in the block
	runValidationWhenNil bool    // copied from the registered validation's runValidationOnNil setting
}
 101  
 102  func (v *Validate) extractStructCache(current reflect.Value, sName string) *cStruct {
 103  	v.structCache.lock.Lock()
 104  	defer v.structCache.lock.Unlock() // leave as defer! because if inner panics, it will never get unlocked otherwise!
 105  
 106  	typ := current.Type()
 107  
 108  	// could have been multiple trying to access, but once first is done this ensures struct
 109  	// isn't parsed again.
 110  	cs, ok := v.structCache.Get(typ)
 111  	if ok {
 112  		return cs
 113  	}
 114  
 115  	cs = &cStruct{name: sName, fields: make([]*cField, 0), fn: v.structLevelFuncs[typ]}
 116  
 117  	numFields := current.NumField()
 118  	rules := v.rules[typ]
 119  
 120  	var ctag *cTag
 121  	var fld reflect.StructField
 122  	var tag string
 123  	var customName string
 124  
 125  	for i := 0; i < numFields; i++ {
 126  
 127  		fld = typ.Field(i)
 128  
 129  		if !v.privateFieldValidation && !fld.Anonymous && len(fld.PkgPath) > 0 {
 130  			continue
 131  		}
 132  
 133  		if rtag, ok := rules[fld.Name]; ok {
 134  			tag = rtag
 135  		} else {
 136  			tag = fld.Tag.Get(v.tagName)
 137  		}
 138  
 139  		if tag == skipValidationTag {
 140  			continue
 141  		}
 142  
 143  		customName = fld.Name
 144  
 145  		if v.hasTagNameFunc {
 146  			name := v.tagNameFunc(fld)
 147  			if len(name) > 0 {
 148  				customName = name
 149  			}
 150  		}
 151  
 152  		// NOTE: cannot use shared tag cache, because tags may be equal, but things like alias may be different
 153  		// and so only struct level caching can be used instead of combined with Field tag caching
 154  
 155  		if len(tag) > 0 {
 156  			ctag, _ = v.parseFieldTagsRecursive(tag, fld.Name, "", false)
 157  		} else {
 158  			// even if field doesn't have validations need cTag for traversing to potential inner/nested
 159  			// elements of the field.
 160  			ctag = new(cTag)
 161  		}
 162  
 163  		cs.fields = append(cs.fields, &cField{
 164  			idx:        i,
 165  			name:       fld.Name,
 166  			altName:    customName,
 167  			cTags:      ctag,
 168  			namesEqual: fld.Name == customName,
 169  		})
 170  	}
 171  	v.structCache.Set(typ, cs)
 172  	return cs
 173  }
 174  
 175  func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias string, hasAlias bool) (firstCtag *cTag, current *cTag) {
 176  	var t string
 177  	noAlias := len(alias) == 0
 178  	tags := strings.Split(tag, tagSeparator)
 179  
 180  	for i := 0; i < len(tags); i++ {
 181  		t = tags[i]
 182  		if noAlias {
 183  			alias = t
 184  		}
 185  
 186  		// check map for alias and process new tags, otherwise process as usual
 187  		if tagsVal, found := v.aliases[t]; found {
 188  			if i == 0 {
 189  				firstCtag, current = v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
 190  			} else {
 191  				next, curr := v.parseFieldTagsRecursive(tagsVal, fieldName, t, true)
 192  				current.next, current = next, curr
 193  
 194  			}
 195  			continue
 196  		}
 197  
 198  		var prevTag tagType
 199  
 200  		if i == 0 {
 201  			current = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true, typeof: typeDefault}
 202  			firstCtag = current
 203  		} else {
 204  			prevTag = current.typeof
 205  			current.next = &cTag{aliasTag: alias, hasAlias: hasAlias, hasTag: true}
 206  			current = current.next
 207  		}
 208  
 209  		switch t {
 210  		case diveTag:
 211  			current.typeof = typeDive
 212  			continue
 213  
 214  		case keysTag:
 215  			current.typeof = typeKeys
 216  
 217  			if i == 0 || prevTag != typeDive {
 218  				panic(fmt.Sprintf("'%s' tag must be immediately preceded by the '%s' tag", keysTag, diveTag))
 219  			}
 220  
 221  			current.typeof = typeKeys
 222  
 223  			// need to pass along only keys tag
 224  			// need to increment i to skip over the keys tags
 225  			b := make([]byte, 0, 64)
 226  
 227  			i++
 228  
 229  			for ; i < len(tags); i++ {
 230  
 231  				b = append(b, tags[i]...)
 232  				b = append(b, ',')
 233  
 234  				if tags[i] == endKeysTag {
 235  					break
 236  				}
 237  			}
 238  
 239  			current.keys, _ = v.parseFieldTagsRecursive(string(b[:len(b)-1]), fieldName, "", false)
 240  			continue
 241  
 242  		case endKeysTag:
 243  			current.typeof = typeEndKeys
 244  
 245  			// if there are more in tags then there was no keysTag defined
 246  			// and an error should be thrown
 247  			if i != len(tags)-1 {
 248  				panic(keysTagNotDefined)
 249  			}
 250  			return
 251  
 252  		case omitempty:
 253  			current.typeof = typeOmitEmpty
 254  			continue
 255  
 256  		case omitnil:
 257  			current.typeof = typeOmitNil
 258  			continue
 259  
 260  		case structOnlyTag:
 261  			current.typeof = typeStructOnly
 262  			continue
 263  
 264  		case noStructLevelTag:
 265  			current.typeof = typeNoStructLevel
 266  			continue
 267  
 268  		default:
 269  			if t == isdefault {
 270  				current.typeof = typeIsDefault
 271  			}
 272  			// if a pipe character is needed within the param you must use the utf8Pipe representation "0x7C"
 273  			orVals := strings.Split(t, orSeparator)
 274  
 275  			for j := 0; j < len(orVals); j++ {
 276  				vals := strings.SplitN(orVals[j], tagKeySeparator, 2)
 277  				if noAlias {
 278  					alias = vals[0]
 279  					current.aliasTag = alias
 280  				} else {
 281  					current.actualAliasTag = t
 282  				}
 283  
 284  				if j > 0 {
 285  					current.next = &cTag{aliasTag: alias, actualAliasTag: current.actualAliasTag, hasAlias: hasAlias, hasTag: true}
 286  					current = current.next
 287  				}
 288  				current.hasParam = len(vals) > 1
 289  
 290  				current.tag = vals[0]
 291  				if len(current.tag) == 0 {
 292  					panic(strings.TrimSpace(fmt.Sprintf(invalidValidation, fieldName)))
 293  				}
 294  
 295  				if wrapper, ok := v.validations[current.tag]; ok {
 296  					current.fn = wrapper.fn
 297  					current.runValidationWhenNil = wrapper.runValidationOnNil
 298  				} else {
 299  					panic(strings.TrimSpace(fmt.Sprintf(undefinedValidation, current.tag, fieldName)))
 300  				}
 301  
 302  				if len(orVals) > 1 {
 303  					current.typeof = typeOr
 304  				}
 305  
 306  				if len(vals) > 1 {
 307  					current.param = strings.Replace(strings.Replace(vals[1], utf8HexComma, ",", -1), utf8Pipe, "|", -1)
 308  				}
 309  			}
 310  			current.isBlockEnd = true
 311  		}
 312  	}
 313  	return
 314  }
 315  
 316  func (v *Validate) fetchCacheTag(tag string) *cTag {
 317  	// find cached tag
 318  	ctag, found := v.tagCache.Get(tag)
 319  	if !found {
 320  		v.tagCache.lock.Lock()
 321  		defer v.tagCache.lock.Unlock()
 322  
 323  		// could have been multiple trying to access, but once first is done this ensures tag
 324  		// isn't parsed again.
 325  		ctag, found = v.tagCache.Get(tag)
 326  		if !found {
 327  			ctag, _ = v.parseFieldTagsRecursive(tag, "", "", false)
 328  			v.tagCache.Set(tag, ctag)
 329  		}
 330  	}
 331  	return ctag
 332  }
 333