decoder.go raw

   1  package xmlrpc
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/xml"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"reflect"
  10  	"regexp"
  11  	"strconv"
  12  	"strings"
  13  	"time"
  14  
  15  	"golang.org/x/text/encoding/charmap"
  16  )
  17  
  18  const (
  19  	iso8601         = "20060102T15:04:05"
  20  	iso8601hyphen   = "2006-01-02T15:04:05Z"
  21  	iso8601hyphenTZ = "2006-01-02T15:04:05-07:00"
  22  	iso8601Z        = "20060102T15:04:05Z07:00"
  23  	iso8601Hyphen   = "2006-01-02T15:04:05"
  24  	iso8601HyphenZ  = "2006-01-02T15:04:05Z07:00"
  25  )
  26  
  27  var (
  28  	// CharsetReader is a function to generate reader which converts a non UTF-8
  29  	// charset into UTF-8.
  30  	CharsetReader func(string, io.Reader) (io.Reader, error)
  31  
  32  	timeLayouts     = []string{iso8601, iso8601hyphen, iso8601hyphenTZ, iso8601Z, iso8601Hyphen, iso8601HyphenZ}
  33  	invalidXmlError = errors.New("invalid xml")
  34  
  35  	// This Regex exists to detect repsponses that contain an array. Which is required because the SoftLayer API
  36  	// will say it is returning an array, but actually return a struct if there is only one element.
  37  	topArrayRE = regexp.MustCompile(`^<\?xml version="1.0" encoding=".+"\?>\s*<params>\s*<param>\s*<value>\s*<array>`)
  38  )
  39  
  40  type TypeMismatchError string
  41  
  42  func (e TypeMismatchError) Error() string { return string(e) }
  43  
  44  type decoder struct {
  45  	*xml.Decoder
  46  }
  47  
  48  func unmarshal(data []byte, v interface{}) (err error) {
  49  	dec := &decoder{xml.NewDecoder(bytes.NewBuffer(data))}
  50  
  51  	if CharsetReader != nil {
  52  		dec.CharsetReader = CharsetReader
  53  	} else {
  54  		dec.CharsetReader = defaultCharsetReader
  55  	}
  56  
  57  	var tok xml.Token
  58  	for {
  59  		if tok, err = dec.Token(); err != nil {
  60  			return err
  61  		}
  62  
  63  		if t, ok := tok.(xml.StartElement); ok {
  64  			if t.Name.Local == "value" {
  65  				val := reflect.ValueOf(v)
  66  				if val.Kind() != reflect.Ptr {
  67  					return errors.New("non-pointer value passed to unmarshal")
  68  				}
  69  
  70  				val = val.Elem()
  71  				// Some APIs that normally return a collection, omit the []'s when
  72  				// the API returns a single value.
  73  				if val.Kind() == reflect.Slice && !topArrayRE.MatchString(string(data)) {
  74  					val.Set(reflect.MakeSlice(val.Type(), 1, 1))
  75  					val = val.Index(0)
  76  				}
  77  
  78  				if err = dec.decodeValue(val); err != nil {
  79  					return err
  80  				}
  81  
  82  				break
  83  			}
  84  		}
  85  	}
  86  
  87  	// read until end of document
  88  	err = dec.Skip()
  89  	if err != nil && err != io.EOF {
  90  		return err
  91  	}
  92  
  93  	return nil
  94  }
  95  
  96  func (dec *decoder) decodeValue(val reflect.Value) error {
  97  	var tok xml.Token
  98  	var err error
  99  
 100  	if val.Kind() == reflect.Ptr {
 101  		if val.IsNil() {
 102  			val.Set(reflect.New(val.Type().Elem()))
 103  		}
 104  		val = val.Elem()
 105  	}
 106  
 107  	var typeName string
 108  	for {
 109  		if tok, err = dec.Token(); err != nil {
 110  			return err
 111  		}
 112  
 113  		if t, ok := tok.(xml.EndElement); ok {
 114  			if t.Name.Local == "value" {
 115  				return nil
 116  			} else {
 117  				return invalidXmlError
 118  			}
 119  		}
 120  
 121  		if t, ok := tok.(xml.StartElement); ok {
 122  			typeName = t.Name.Local
 123  			break
 124  		}
 125  
 126  		// Treat value data without type identifier as string
 127  		if t, ok := tok.(xml.CharData); ok {
 128  			if value := strings.TrimSpace(string(t)); value != "" {
 129  				if err = checkType(val, reflect.String); err != nil {
 130  					return err
 131  				}
 132  
 133  				val.SetString(value)
 134  				return nil
 135  			}
 136  		}
 137  	}
 138  
 139  	switch typeName {
 140  	case "struct":
 141  		ismap := false
 142  		pmap := val
 143  		valType := val.Type()
 144  
 145  		if err = checkType(val, reflect.Struct); err != nil {
 146  			if checkType(val, reflect.Map) == nil {
 147  				if valType.Key().Kind() != reflect.String {
 148  					return fmt.Errorf("only maps with string key type can be unmarshalled")
 149  				}
 150  				ismap = true
 151  			} else if checkType(val, reflect.Interface) == nil && val.IsNil() {
 152  				var dummy map[string]interface{}
 153  				valType = reflect.TypeOf(dummy)
 154  				pmap = reflect.New(valType).Elem()
 155  				val.Set(pmap)
 156  				ismap = true
 157  			} else {
 158  				return err
 159  			}
 160  		}
 161  
 162  		var fields map[string]reflect.Value
 163  
 164  		if !ismap {
 165  			fields = make(map[string]reflect.Value)
 166  			buildStructFieldMap(&fields, val)
 167  		} else {
 168  			// Create initial empty map
 169  			pmap.Set(reflect.MakeMap(valType))
 170  		}
 171  
 172  		// Process struct members.
 173  	StructLoop:
 174  		for {
 175  			if tok, err = dec.Token(); err != nil {
 176  				return err
 177  			}
 178  			switch t := tok.(type) {
 179  			case xml.StartElement:
 180  				if t.Name.Local != "member" {
 181  					return invalidXmlError
 182  				}
 183  
 184  				tagName, fieldName, err := dec.readTag()
 185  				if err != nil {
 186  					return err
 187  				}
 188  				if tagName != "name" {
 189  					return invalidXmlError
 190  				}
 191  
 192  				var fv reflect.Value
 193  				ok := true
 194  
 195  				if !ismap {
 196  					fv, ok = fields[string(fieldName)]
 197  				} else {
 198  					fv = reflect.New(valType.Elem())
 199  				}
 200  
 201  				if ok {
 202  					for {
 203  						if tok, err = dec.Token(); err != nil {
 204  							return err
 205  						}
 206  						if t, ok := tok.(xml.StartElement); ok && t.Name.Local == "value" {
 207  							if err = dec.decodeValue(fv); err != nil {
 208  								return err
 209  							}
 210  
 211  							// </value>
 212  							if err = dec.Skip(); err != nil {
 213  								return err
 214  							}
 215  
 216  							break
 217  						}
 218  					}
 219  				}
 220  
 221  				// </member>
 222  				if err = dec.Skip(); err != nil {
 223  					return err
 224  				}
 225  
 226  				if ismap {
 227  					pmap.SetMapIndex(reflect.ValueOf(string(fieldName)), reflect.Indirect(fv))
 228  					val.Set(pmap)
 229  				}
 230  			case xml.EndElement:
 231  				break StructLoop
 232  			}
 233  		}
 234  	case "array":
 235  		slice := val
 236  		if checkType(val, reflect.Interface) == nil && val.IsNil() {
 237  			slice = reflect.ValueOf([]interface{}{})
 238  		} else if err = checkType(val, reflect.Slice); err != nil {
 239  			// Check to see if we have an unexpected array when we expect
 240  			// a struct. Adjust by expecting an array of the struct type
 241  			// and see if things still work.
 242  			// https://github.com/renier/xmlrpc/pull/2
 243  			if val.Kind() == reflect.Struct {
 244  				slice = reflect.New(reflect.SliceOf(reflect.TypeOf(val.Interface()))).Elem()
 245  				val = slice
 246  			} else {
 247  				return err
 248  			}
 249  		}
 250  
 251  	ArrayLoop:
 252  		for {
 253  			if tok, err = dec.Token(); err != nil {
 254  				return err
 255  			}
 256  
 257  			switch t := tok.(type) {
 258  			case xml.StartElement:
 259  				var index int
 260  				if t.Name.Local != "data" {
 261  					return invalidXmlError
 262  				}
 263  			DataLoop:
 264  				for {
 265  					if tok, err = dec.Token(); err != nil {
 266  						return err
 267  					}
 268  
 269  					switch tt := tok.(type) {
 270  					case xml.StartElement:
 271  						if tt.Name.Local != "value" {
 272  							return invalidXmlError
 273  						}
 274  
 275  						// Incase the incoming val is already defined.
 276  						if index < slice.Len() {
 277  							v := slice.Index(index)
 278  							if v.Kind() == reflect.Interface {
 279  								v = v.Elem()
 280  							}
 281  							if v.Kind() != reflect.Ptr {
 282  								return errors.New("error: cannot write to non-pointer array element")
 283  							}
 284  							if err = dec.decodeValue(v); err != nil {
 285  								return err
 286  							}
 287  						} else {
 288  							v := reflect.New(slice.Type().Elem())
 289  							if err = dec.decodeValue(v); err != nil {
 290  								return err
 291  							}
 292  							slice = reflect.Append(slice, v.Elem())
 293  						}
 294  
 295  						// </value>
 296  						if err = dec.Skip(); err != nil {
 297  							return err
 298  						}
 299  						index++
 300  					case xml.EndElement:
 301  						val.Set(slice)
 302  						break DataLoop
 303  					}
 304  				}
 305  			case xml.EndElement:
 306  				break ArrayLoop
 307  			}
 308  		}
 309  	default:
 310  		if tok, err = dec.Token(); err != nil {
 311  			return err
 312  		}
 313  
 314  		var data []byte
 315  
 316  		switch t := tok.(type) {
 317  		case xml.EndElement:
 318  			return nil
 319  		case xml.CharData:
 320  			data = []byte(t.Copy())
 321  		default:
 322  			return invalidXmlError
 323  		}
 324  
 325  	ParseValue:
 326  		switch typeName {
 327  		case "int", "i4", "i8":
 328  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 329  				i, err := strconv.ParseInt(string(data), 10, 64)
 330  				if err != nil {
 331  					return err
 332  				}
 333  
 334  				pi := reflect.New(reflect.TypeOf(i)).Elem()
 335  				pi.SetInt(i)
 336  				val.Set(pi)
 337  			} else if err = checkType(val, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64); err != nil {
 338  				return err
 339  			} else {
 340  				k := val.Kind()
 341  				isInt := k == reflect.Int || k == reflect.Int8 || k == reflect.Int16 || k == reflect.Int32 || k == reflect.Int64
 342  
 343  				if isInt {
 344  					i, err := strconv.ParseInt(string(data), 10, val.Type().Bits())
 345  					if err != nil {
 346  						return err
 347  					}
 348  
 349  					val.SetInt(i)
 350  				} else {
 351  					i, err := strconv.ParseUint(string(data), 10, val.Type().Bits())
 352  					if err != nil {
 353  						return err
 354  					}
 355  
 356  					val.SetUint(i)
 357  				}
 358  			}
 359  		case "string", "base64":
 360  			str := string(data)
 361  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 362  				pstr := reflect.New(reflect.TypeOf(str)).Elem()
 363  				pstr.SetString(str)
 364  				val.Set(pstr)
 365  			} else if err = checkType(val, reflect.String); err != nil {
 366  				valName := val.Type().Name()
 367  				if valName == "" {
 368  					valName = reflect.Indirect(val).Type().Name()
 369  				}
 370  
 371  				if valName == "Time" {
 372  					timeField := val.FieldByName(valName)
 373  					if timeField.IsValid() {
 374  						val = timeField
 375  					}
 376  					typeName = "dateTime.iso8601"
 377  					goto ParseValue
 378  				} else if strings.HasPrefix(strings.ToLower(valName), "float") {
 379  					typeName = "double"
 380  					goto ParseValue
 381  				}
 382  				return err
 383  			} else {
 384  				val.SetString(str)
 385  			}
 386  		case "dateTime.iso8601":
 387  			var t time.Time
 388  			var err error
 389  			for _, df := range timeLayouts {
 390  				t, err = time.Parse(df, string(data))
 391  
 392  				if err == nil {
 393  					break
 394  				}
 395  			}
 396  			if err != nil {
 397  				return err
 398  			}
 399  
 400  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 401  				ptime := reflect.New(reflect.TypeOf(t)).Elem()
 402  				ptime.Set(reflect.ValueOf(t))
 403  				val.Set(ptime)
 404  			} else if !reflect.TypeOf((time.Time)(t)).ConvertibleTo(val.Type()) {
 405  				return TypeMismatchError(
 406  					fmt.Sprintf(
 407  						"error: type mismatch error - can't decode %v (%s.%s) to time",
 408  						val.Kind(),
 409  						val.Type().PkgPath(),
 410  						val.Type().Name(),
 411  					),
 412  				)
 413  			} else {
 414  				val.Set(reflect.ValueOf(t).Convert(val.Type()))
 415  			}
 416  		case "boolean":
 417  			v, err := strconv.ParseBool(string(data))
 418  			if err != nil {
 419  				return err
 420  			}
 421  
 422  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 423  				pv := reflect.New(reflect.TypeOf(v)).Elem()
 424  				pv.SetBool(v)
 425  				val.Set(pv)
 426  			} else if err = checkType(val, reflect.Bool); err != nil {
 427  				return err
 428  			} else {
 429  				val.SetBool(v)
 430  			}
 431  		case "double":
 432  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 433  				i, err := strconv.ParseFloat(string(data), 64)
 434  				if err != nil {
 435  					return err
 436  				}
 437  
 438  				pdouble := reflect.New(reflect.TypeOf(i)).Elem()
 439  				pdouble.SetFloat(i)
 440  				val.Set(pdouble)
 441  			} else if err = checkType(val, reflect.Float32, reflect.Float64); err != nil {
 442  				return err
 443  			} else {
 444  				i, err := strconv.ParseFloat(string(data), val.Type().Bits())
 445  				if err != nil {
 446  					return err
 447  				}
 448  
 449  				val.SetFloat(i)
 450  			}
 451  		default:
 452  			return errors.New("unsupported type")
 453  		}
 454  
 455  		// </type>
 456  		if err = dec.Skip(); err != nil {
 457  			return err
 458  		}
 459  	}
 460  
 461  	return nil
 462  }
 463  
 464  func (dec *decoder) readTag() (string, []byte, error) {
 465  	var tok xml.Token
 466  	var err error
 467  
 468  	var name string
 469  	for {
 470  		if tok, err = dec.Token(); err != nil {
 471  			return "", nil, err
 472  		}
 473  
 474  		if t, ok := tok.(xml.StartElement); ok {
 475  			name = t.Name.Local
 476  			break
 477  		}
 478  	}
 479  
 480  	value, err := dec.readCharData()
 481  	if err != nil {
 482  		return "", nil, err
 483  	}
 484  
 485  	return name, value, dec.Skip()
 486  }
 487  
 488  func (dec *decoder) readCharData() ([]byte, error) {
 489  	var tok xml.Token
 490  	var err error
 491  
 492  	if tok, err = dec.Token(); err != nil {
 493  		return nil, err
 494  	}
 495  
 496  	if t, ok := tok.(xml.CharData); ok {
 497  		return []byte(t.Copy()), nil
 498  	} else {
 499  		return nil, invalidXmlError
 500  	}
 501  }
 502  
 503  func checkType(val reflect.Value, kinds ...reflect.Kind) error {
 504  	if len(kinds) == 0 {
 505  		return nil
 506  	}
 507  
 508  	if val.Kind() == reflect.Ptr {
 509  		val = val.Elem()
 510  	}
 511  
 512  	match := false
 513  
 514  	for _, kind := range kinds {
 515  		if val.Kind() == kind {
 516  			match = true
 517  			break
 518  		}
 519  	}
 520  
 521  	if !match {
 522  		return TypeMismatchError(fmt.Sprintf("error: type mismatch - can't unmarshal %v to %v",
 523  			val.Kind(), kinds[0]))
 524  	}
 525  
 526  	return nil
 527  }
 528  
 529  func buildStructFieldMap(fieldMap *map[string]reflect.Value, val reflect.Value) {
 530  	valType := val.Type()
 531  	valFieldNum := valType.NumField()
 532  	for i := 0; i < valFieldNum; i++ {
 533  		field := valType.Field(i)
 534  		fieldVal := val.FieldByName(field.Name)
 535  
 536  		if field.Anonymous {
 537  			// Drill down into embedded structs
 538  			buildStructFieldMap(fieldMap, fieldVal)
 539  			continue
 540  		}
 541  
 542  		if fieldVal.CanSet() {
 543  			if fn := field.Tag.Get("xmlrpc"); fn != "" {
 544  				fn = strings.Split(fn, ",")[0]
 545  				(*fieldMap)[fn] = fieldVal
 546  			} else {
 547  				(*fieldMap)[field.Name] = fieldVal
 548  			}
 549  		}
 550  	}
 551  }
 552  
 553  // http://stackoverflow.com/a/34712322/3160958
 554  // https://groups.google.com/forum/#!topic/golang-nuts/VudK_05B62k
 555  func defaultCharsetReader(charset string, input io.Reader) (io.Reader, error) {
 556  	if charset == "iso-8859-1" || charset == "ISO-8859-1" {
 557  		return charmap.ISO8859_1.NewDecoder().Reader(input), nil
 558  	} else if strings.HasPrefix(charset, "utf") || strings.HasPrefix(charset, "UTF") {
 559  		return input, nil
 560  	}
 561  
 562  	return nil, fmt.Errorf("Unknown charset: %s", charset)
 563  }
 564