decode.go raw

   1  package toml
   2  
   3  import (
   4  	"bytes"
   5  	"encoding"
   6  	"encoding/json"
   7  	"fmt"
   8  	"io"
   9  	"io/fs"
  10  	"math"
  11  	"os"
  12  	"reflect"
  13  	"strconv"
  14  	"strings"
  15  	"time"
  16  )
  17  
  18  // Unmarshaler is the interface implemented by objects that can unmarshal a
  19  // TOML description of themselves.
  20  type Unmarshaler interface {
  21  	UnmarshalTOML(any) error
  22  }
  23  
  24  // Unmarshal decodes the contents of data in TOML format into a pointer v.
  25  //
  26  // See [Decoder] for a description of the decoding process.
  27  func Unmarshal(data []byte, v any) error {
  28  	_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
  29  	return err
  30  }
  31  
  32  // Decode the TOML data in to the pointer v.
  33  //
  34  // See [Decoder] for a description of the decoding process.
  35  func Decode(data string, v any) (MetaData, error) {
  36  	return NewDecoder(strings.NewReader(data)).Decode(v)
  37  }
  38  
  39  // DecodeFile reads the contents of a file and decodes it with [Decode].
  40  func DecodeFile(path string, v any) (MetaData, error) {
  41  	fp, err := os.Open(path)
  42  	if err != nil {
  43  		return MetaData{}, err
  44  	}
  45  	defer fp.Close()
  46  	return NewDecoder(fp).Decode(v)
  47  }
  48  
  49  // DecodeFS reads the contents of a file from [fs.FS] and decodes it with
  50  // [Decode].
  51  func DecodeFS(fsys fs.FS, path string, v any) (MetaData, error) {
  52  	fp, err := fsys.Open(path)
  53  	if err != nil {
  54  		return MetaData{}, err
  55  	}
  56  	defer fp.Close()
  57  	return NewDecoder(fp).Decode(v)
  58  }
  59  
// Primitive is a TOML value that hasn't been decoded into a Go value.
//
// This type can be used for any value, which will cause decoding to be delayed.
// You can use [PrimitiveDecode] to "manually" decode these values.
//
// NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it.
//
// NOTE: Primitive values are still parsed, so using them will only avoid the
// overhead of reflection. They can be useful when you don't know the exact type
// of TOML data until runtime.
type Primitive struct {
	undecoded any // raw value from the parser, stored by unify when the target is a Primitive
	context   Key // private copy of the key path at capture time; restored as md.context by PrimitiveDecode
}
  75  
  76  // The significand precision for float32 and float64 is 24 and 53 bits; this is
  77  // the range a natural number can be stored in a float without loss of data.
  78  const (
  79  	maxSafeFloat32Int = 16777215                // 2^24-1
  80  	maxSafeFloat64Int = int64(9007199254740991) // 2^53-1
  81  )
  82  
// Decoder decodes TOML data.
//
// TOML tables correspond to Go structs or maps; they can be used
// interchangeably, but structs offer better type safety. Maps must have a
// string (or interface) key type.
//
// TOML table arrays correspond to either a slice of structs or a slice of maps.
// A TOML array may also be decoded into a fixed-size Go array, but only when
// both have exactly the same length.
//
// TOML datetimes correspond to [time.Time]. Local datetimes are parsed in the
// local timezone.
//
// [time.Duration] types are treated as nanoseconds if the TOML value is an
// integer, or they're parsed with time.ParseDuration() if they're strings.
//
// [json.Number] accepts TOML integers and floats, stored as their text form.
//
// All other TOML types (float, string, int, bool and array) correspond to the
// obvious Go types. Numbers that do not fit the destination type are reported
// as errors rather than truncated.
//
// An exception to the above rules is if a type implements the TextUnmarshaler
// interface, in which case any primitive TOML value (floats, strings, integers,
// booleans, datetimes) will be converted to a []byte and given to the value's
// UnmarshalText method. See the Unmarshaler example for a demonstration with
// email addresses. A type implementing [Unmarshaler] is checked before
// TextUnmarshaler: its UnmarshalTOML method receives the parsed TOML value
// as-is, and every key beneath that value is treated as decoded.
//
// A [Primitive] destination is not decoded; it captures the value so it can be
// decoded later with [MetaData.PrimitiveDecode].
//
// # Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields
// that don't match the key name exactly (see the example). A case insensitive
// match to struct names will be tried if an exact match can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there may
// exist TOML values that cannot be placed into your representation, and there
// may be parts of your representation that do not correspond to TOML values.
// This loose mapping can be made stricter by using the IsDefined and/or
// Undecoded methods on the MetaData returned.
//
// This decoder does not handle cyclic types. Decode will not terminate if a
// cyclic type is passed.
type Decoder struct {
	r io.Reader // source of the TOML document; Decode reads it to EOF
}
 123  
 124  // NewDecoder creates a new Decoder.
 125  func NewDecoder(r io.Reader) *Decoder {
 126  	return &Decoder{r: r}
 127  }
 128  
 129  var (
 130  	unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
 131  	unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
 132  	primitiveType = reflect.TypeOf((*Primitive)(nil)).Elem()
 133  )
 134  
 135  // Decode TOML data in to the pointer `v`.
 136  func (dec *Decoder) Decode(v any) (MetaData, error) {
 137  	rv := reflect.ValueOf(v)
 138  	if rv.Kind() != reflect.Ptr {
 139  		s := "%q"
 140  		if reflect.TypeOf(v) == nil {
 141  			s = "%v"
 142  		}
 143  
 144  		return MetaData{}, fmt.Errorf("toml: cannot decode to non-pointer "+s, reflect.TypeOf(v))
 145  	}
 146  	if rv.IsNil() {
 147  		return MetaData{}, fmt.Errorf("toml: cannot decode to nil value of %q", reflect.TypeOf(v))
 148  	}
 149  
 150  	// Check if this is a supported type: struct, map, any, or something that
 151  	// implements UnmarshalTOML or UnmarshalText.
 152  	rv = indirect(rv)
 153  	rt := rv.Type()
 154  	if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
 155  		!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
 156  		!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
 157  		return MetaData{}, fmt.Errorf("toml: cannot decode to type %s", rt)
 158  	}
 159  
 160  	// TODO: parser should read from io.Reader? Or at the very least, make it
 161  	// read from []byte rather than string
 162  	data, err := io.ReadAll(dec.r)
 163  	if err != nil {
 164  		return MetaData{}, err
 165  	}
 166  
 167  	p, err := parse(string(data))
 168  	if err != nil {
 169  		return MetaData{}, err
 170  	}
 171  
 172  	md := MetaData{
 173  		mapping: p.mapping,
 174  		keyInfo: p.keyInfo,
 175  		keys:    p.ordered,
 176  		decoded: make(map[string]struct{}, len(p.ordered)),
 177  		context: nil,
 178  		data:    data,
 179  	}
 180  	return md, md.unify(p.mapping, rv)
 181  }
 182  
 183  // PrimitiveDecode is just like the other Decode* functions, except it decodes a
 184  // TOML value that has already been parsed. Valid primitive values can *only* be
 185  // obtained from values filled by the decoder functions, including this method.
 186  // (i.e., v may contain more [Primitive] values.)
 187  //
 188  // Meta data for primitive values is included in the meta data returned by the
 189  // Decode* functions with one exception: keys returned by the Undecoded method
 190  // will only reflect keys that were decoded. Namely, any keys hidden behind a
 191  // Primitive will be considered undecoded. Executing this method will update the
 192  // undecoded keys in the meta data. (See the example.)
 193  func (md *MetaData) PrimitiveDecode(primValue Primitive, v any) error {
 194  	md.context = primValue.context
 195  	defer func() { md.context = nil }()
 196  	return md.unify(primValue.undecoded, rvalue(v))
 197  }
 198  
 199  // markDecodedRecursive is a helper to mark any key under the given tmap as
 200  // decoded, recursing as needed
 201  func markDecodedRecursive(md *MetaData, tmap map[string]any) {
 202  	for key := range tmap {
 203  		md.decoded[md.context.add(key).String()] = struct{}{}
 204  		if tmap, ok := tmap[key].(map[string]any); ok {
 205  			md.context = append(md.context, key)
 206  			markDecodedRecursive(md, tmap)
 207  			md.context = md.context[0 : len(md.context)-1]
 208  		}
 209  		if tarr, ok := tmap[key].([]map[string]any); ok {
 210  			for _, elm := range tarr {
 211  				md.context = append(md.context, key)
 212  				markDecodedRecursive(md, elm)
 213  				md.context = md.context[0 : len(md.context)-1]
 214  			}
 215  		}
 216  	}
 217  }
 218  
 219  // unify performs a sort of type unification based on the structure of `rv`,
 220  // which is the client representation.
 221  //
 222  // Any type mismatch produces an error. Finding a type that we don't know
 223  // how to handle produces an unsupported type error.
 224  func (md *MetaData) unify(data any, rv reflect.Value) error {
 225  	// Special case. Look for a `Primitive` value.
 226  	// TODO: #76 would make this superfluous after implemented.
 227  	if rv.Type() == primitiveType {
 228  		// Save the undecoded data and the key context into the primitive
 229  		// value.
 230  		context := make(Key, len(md.context))
 231  		copy(context, md.context)
 232  		rv.Set(reflect.ValueOf(Primitive{
 233  			undecoded: data,
 234  			context:   context,
 235  		}))
 236  		return nil
 237  	}
 238  
 239  	rvi := rv.Interface()
 240  	if v, ok := rvi.(Unmarshaler); ok {
 241  		err := v.UnmarshalTOML(data)
 242  		if err != nil {
 243  			return md.parseErr(err)
 244  		}
 245  		// Assume the Unmarshaler decoded everything, so mark all keys under
 246  		// this table as decoded.
 247  		if tmap, ok := data.(map[string]any); ok {
 248  			markDecodedRecursive(md, tmap)
 249  		}
 250  		if aot, ok := data.([]map[string]any); ok {
 251  			for _, tmap := range aot {
 252  				markDecodedRecursive(md, tmap)
 253  			}
 254  		}
 255  		return nil
 256  	}
 257  	if v, ok := rvi.(encoding.TextUnmarshaler); ok {
 258  		return md.unifyText(data, v)
 259  	}
 260  
 261  	// TODO:
 262  	// The behavior here is incorrect whenever a Go type satisfies the
 263  	// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
 264  	// array. In particular, the unmarshaler should only be applied to primitive
 265  	// TOML values. But at this point, it will be applied to all kinds of values
 266  	// and produce an incorrect error whenever those values are hashes or arrays
 267  	// (including arrays of tables).
 268  
 269  	k := rv.Kind()
 270  
 271  	if k >= reflect.Int && k <= reflect.Uint64 {
 272  		return md.unifyInt(data, rv)
 273  	}
 274  	switch k {
 275  	case reflect.Struct:
 276  		return md.unifyStruct(data, rv)
 277  	case reflect.Map:
 278  		return md.unifyMap(data, rv)
 279  	case reflect.Array:
 280  		return md.unifyArray(data, rv)
 281  	case reflect.Slice:
 282  		return md.unifySlice(data, rv)
 283  	case reflect.String:
 284  		return md.unifyString(data, rv)
 285  	case reflect.Bool:
 286  		return md.unifyBool(data, rv)
 287  	case reflect.Interface:
 288  		if rv.NumMethod() > 0 { /// Only empty interfaces are supported.
 289  			return md.e("unsupported type %s", rv.Type())
 290  		}
 291  		return md.unifyAnything(data, rv)
 292  	case reflect.Float32, reflect.Float64:
 293  		return md.unifyFloat64(data, rv)
 294  	}
 295  	return md.e("unsupported type %s", rv.Kind())
 296  }
 297  
 298  func (md *MetaData) unifyStruct(mapping any, rv reflect.Value) error {
 299  	tmap, ok := mapping.(map[string]any)
 300  	if !ok {
 301  		if mapping == nil {
 302  			return nil
 303  		}
 304  		return md.e("type mismatch for %s: expected table but found %s", rv.Type().String(), fmtType(mapping))
 305  	}
 306  
 307  	for key, datum := range tmap {
 308  		var f *field
 309  		fields := cachedTypeFields(rv.Type())
 310  		for i := range fields {
 311  			ff := &fields[i]
 312  			if ff.name == key {
 313  				f = ff
 314  				break
 315  			}
 316  			if f == nil && strings.EqualFold(ff.name, key) {
 317  				f = ff
 318  			}
 319  		}
 320  		if f != nil {
 321  			subv := rv
 322  			for _, i := range f.index {
 323  				subv = indirect(subv.Field(i))
 324  			}
 325  
 326  			if isUnifiable(subv) {
 327  				md.decoded[md.context.add(key).String()] = struct{}{}
 328  				md.context = append(md.context, key)
 329  
 330  				err := md.unify(datum, subv)
 331  				if err != nil {
 332  					return err
 333  				}
 334  				md.context = md.context[0 : len(md.context)-1]
 335  			} else if f.name != "" {
 336  				return md.e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
 337  			}
 338  		}
 339  	}
 340  	return nil
 341  }
 342  
 343  func (md *MetaData) unifyMap(mapping any, rv reflect.Value) error {
 344  	keyType := rv.Type().Key().Kind()
 345  	if keyType != reflect.String && keyType != reflect.Interface {
 346  		return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
 347  			keyType, rv.Type())
 348  	}
 349  
 350  	tmap, ok := mapping.(map[string]any)
 351  	if !ok {
 352  		if tmap == nil {
 353  			return nil
 354  		}
 355  		return md.badtype("map", mapping)
 356  	}
 357  	if rv.IsNil() {
 358  		rv.Set(reflect.MakeMap(rv.Type()))
 359  	}
 360  	for k, v := range tmap {
 361  		md.decoded[md.context.add(k).String()] = struct{}{}
 362  		md.context = append(md.context, k)
 363  
 364  		rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
 365  
 366  		err := md.unify(v, indirect(rvval))
 367  		if err != nil {
 368  			return err
 369  		}
 370  		md.context = md.context[0 : len(md.context)-1]
 371  
 372  		rvkey := indirect(reflect.New(rv.Type().Key()))
 373  
 374  		switch keyType {
 375  		case reflect.Interface:
 376  			rvkey.Set(reflect.ValueOf(k))
 377  		case reflect.String:
 378  			rvkey.SetString(k)
 379  		}
 380  
 381  		rv.SetMapIndex(rvkey, rvval)
 382  	}
 383  	return nil
 384  }
 385  
 386  func (md *MetaData) unifyArray(data any, rv reflect.Value) error {
 387  	datav := reflect.ValueOf(data)
 388  	if datav.Kind() != reflect.Slice {
 389  		if !datav.IsValid() {
 390  			return nil
 391  		}
 392  		return md.badtype("slice", data)
 393  	}
 394  	if l := datav.Len(); l != rv.Len() {
 395  		return md.e("expected array length %d; got TOML array of length %d", rv.Len(), l)
 396  	}
 397  	return md.unifySliceArray(datav, rv)
 398  }
 399  
 400  func (md *MetaData) unifySlice(data any, rv reflect.Value) error {
 401  	datav := reflect.ValueOf(data)
 402  	if datav.Kind() != reflect.Slice {
 403  		if !datav.IsValid() {
 404  			return nil
 405  		}
 406  		return md.badtype("slice", data)
 407  	}
 408  	n := datav.Len()
 409  	if rv.IsNil() || rv.Cap() < n {
 410  		rv.Set(reflect.MakeSlice(rv.Type(), n, n))
 411  	}
 412  	rv.SetLen(n)
 413  	return md.unifySliceArray(datav, rv)
 414  }
 415  
 416  func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
 417  	l := data.Len()
 418  	for i := 0; i < l; i++ {
 419  		err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i)))
 420  		if err != nil {
 421  			return err
 422  		}
 423  	}
 424  	return nil
 425  }
 426  
 427  func (md *MetaData) unifyString(data any, rv reflect.Value) error {
 428  	_, ok := rv.Interface().(json.Number)
 429  	if ok {
 430  		if i, ok := data.(int64); ok {
 431  			rv.SetString(strconv.FormatInt(i, 10))
 432  		} else if f, ok := data.(float64); ok {
 433  			rv.SetString(strconv.FormatFloat(f, 'g', -1, 64))
 434  		} else {
 435  			return md.badtype("string", data)
 436  		}
 437  		return nil
 438  	}
 439  
 440  	if s, ok := data.(string); ok {
 441  		rv.SetString(s)
 442  		return nil
 443  	}
 444  	return md.badtype("string", data)
 445  }
 446  
 447  func (md *MetaData) unifyFloat64(data any, rv reflect.Value) error {
 448  	rvk := rv.Kind()
 449  
 450  	if num, ok := data.(float64); ok {
 451  		switch rvk {
 452  		case reflect.Float32:
 453  			if num < -math.MaxFloat32 || num > math.MaxFloat32 {
 454  				return md.parseErr(errParseRange{i: num, size: rvk.String()})
 455  			}
 456  			fallthrough
 457  		case reflect.Float64:
 458  			rv.SetFloat(num)
 459  		default:
 460  			panic("bug")
 461  		}
 462  		return nil
 463  	}
 464  
 465  	if num, ok := data.(int64); ok {
 466  		if (rvk == reflect.Float32 && (num < -maxSafeFloat32Int || num > maxSafeFloat32Int)) ||
 467  			(rvk == reflect.Float64 && (num < -maxSafeFloat64Int || num > maxSafeFloat64Int)) {
 468  			return md.parseErr(errUnsafeFloat{i: num, size: rvk.String()})
 469  		}
 470  		rv.SetFloat(float64(num))
 471  		return nil
 472  	}
 473  
 474  	return md.badtype("float", data)
 475  }
 476  
 477  func (md *MetaData) unifyInt(data any, rv reflect.Value) error {
 478  	_, ok := rv.Interface().(time.Duration)
 479  	if ok {
 480  		// Parse as string duration, and fall back to regular integer parsing
 481  		// (as nanosecond) if this is not a string.
 482  		if s, ok := data.(string); ok {
 483  			dur, err := time.ParseDuration(s)
 484  			if err != nil {
 485  				return md.parseErr(errParseDuration{s})
 486  			}
 487  			rv.SetInt(int64(dur))
 488  			return nil
 489  		}
 490  	}
 491  
 492  	num, ok := data.(int64)
 493  	if !ok {
 494  		return md.badtype("integer", data)
 495  	}
 496  
 497  	rvk := rv.Kind()
 498  	switch {
 499  	case rvk >= reflect.Int && rvk <= reflect.Int64:
 500  		if (rvk == reflect.Int8 && (num < math.MinInt8 || num > math.MaxInt8)) ||
 501  			(rvk == reflect.Int16 && (num < math.MinInt16 || num > math.MaxInt16)) ||
 502  			(rvk == reflect.Int32 && (num < math.MinInt32 || num > math.MaxInt32)) {
 503  			return md.parseErr(errParseRange{i: num, size: rvk.String()})
 504  		}
 505  		rv.SetInt(num)
 506  	case rvk >= reflect.Uint && rvk <= reflect.Uint64:
 507  		unum := uint64(num)
 508  		if rvk == reflect.Uint8 && (num < 0 || unum > math.MaxUint8) ||
 509  			rvk == reflect.Uint16 && (num < 0 || unum > math.MaxUint16) ||
 510  			rvk == reflect.Uint32 && (num < 0 || unum > math.MaxUint32) {
 511  			return md.parseErr(errParseRange{i: num, size: rvk.String()})
 512  		}
 513  		rv.SetUint(unum)
 514  	default:
 515  		panic("unreachable")
 516  	}
 517  	return nil
 518  }
 519  
 520  func (md *MetaData) unifyBool(data any, rv reflect.Value) error {
 521  	if b, ok := data.(bool); ok {
 522  		rv.SetBool(b)
 523  		return nil
 524  	}
 525  	return md.badtype("boolean", data)
 526  }
 527  
 528  func (md *MetaData) unifyAnything(data any, rv reflect.Value) error {
 529  	rv.Set(reflect.ValueOf(data))
 530  	return nil
 531  }
 532  
 533  func (md *MetaData) unifyText(data any, v encoding.TextUnmarshaler) error {
 534  	var s string
 535  	switch sdata := data.(type) {
 536  	case Marshaler:
 537  		text, err := sdata.MarshalTOML()
 538  		if err != nil {
 539  			return err
 540  		}
 541  		s = string(text)
 542  	case encoding.TextMarshaler:
 543  		text, err := sdata.MarshalText()
 544  		if err != nil {
 545  			return err
 546  		}
 547  		s = string(text)
 548  	case fmt.Stringer:
 549  		s = sdata.String()
 550  	case string:
 551  		s = sdata
 552  	case bool:
 553  		s = fmt.Sprintf("%v", sdata)
 554  	case int64:
 555  		s = fmt.Sprintf("%d", sdata)
 556  	case float64:
 557  		s = fmt.Sprintf("%f", sdata)
 558  	default:
 559  		return md.badtype("primitive (string-like)", data)
 560  	}
 561  	if err := v.UnmarshalText([]byte(s)); err != nil {
 562  		return md.parseErr(err)
 563  	}
 564  	return nil
 565  }
 566  
 567  func (md *MetaData) badtype(dst string, data any) error {
 568  	return md.e("incompatible types: TOML value has type %s; destination has type %s", fmtType(data), dst)
 569  }
 570  
 571  func (md *MetaData) parseErr(err error) error {
 572  	k := md.context.String()
 573  	d := string(md.data)
 574  	return ParseError{
 575  		Message:  err.Error(),
 576  		err:      err,
 577  		LastKey:  k,
 578  		Position: md.keyInfo[k].pos.withCol(d),
 579  		Line:     md.keyInfo[k].pos.Line,
 580  		input:    d,
 581  	}
 582  }
 583  
 584  func (md *MetaData) e(format string, args ...any) error {
 585  	f := "toml: "
 586  	if len(md.context) > 0 {
 587  		f = fmt.Sprintf("toml: (last key %q): ", md.context)
 588  		p := md.keyInfo[md.context.String()].pos
 589  		if p.Line > 0 {
 590  			f = fmt.Sprintf("toml: line %d (last key %q): ", p.Line, md.context)
 591  		}
 592  	}
 593  	return fmt.Errorf(f+format, args...)
 594  }
 595  
 596  // rvalue returns a reflect.Value of `v`. All pointers are resolved.
 597  func rvalue(v any) reflect.Value {
 598  	return indirect(reflect.ValueOf(v))
 599  }
 600  
 601  // indirect returns the value pointed to by a pointer.
 602  //
 603  // Pointers are followed until the value is not a pointer. New values are
 604  // allocated for each nil pointer.
 605  //
 606  // An exception to this rule is if the value satisfies an interface of interest
 607  // to us (like encoding.TextUnmarshaler).
 608  func indirect(v reflect.Value) reflect.Value {
 609  	if v.Kind() != reflect.Ptr {
 610  		if v.CanSet() {
 611  			pv := v.Addr()
 612  			pvi := pv.Interface()
 613  			if _, ok := pvi.(encoding.TextUnmarshaler); ok {
 614  				return pv
 615  			}
 616  			if _, ok := pvi.(Unmarshaler); ok {
 617  				return pv
 618  			}
 619  		}
 620  		return v
 621  	}
 622  	if v.IsNil() {
 623  		v.Set(reflect.New(v.Type().Elem()))
 624  	}
 625  	return indirect(reflect.Indirect(v))
 626  }
 627  
 628  func isUnifiable(rv reflect.Value) bool {
 629  	if rv.CanSet() {
 630  		return true
 631  	}
 632  	rvi := rv.Interface()
 633  	if _, ok := rvi.(encoding.TextUnmarshaler); ok {
 634  		return true
 635  	}
 636  	if _, ok := rvi.(Unmarshaler); ok {
 637  		return true
 638  	}
 639  	return false
 640  }
 641  
 642  // fmt %T with "interface {}" replaced with "any", which is far more readable.
 643  func fmtType(t any) string {
 644  	return strings.ReplaceAll(fmt.Sprintf("%T", t), "interface {}", "any")
 645  }
 646