decoder.go raw

   1  package xmlrpc
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/xml"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"reflect"
  10  	"strconv"
  11  	"strings"
  12  	"time"
  13  )
  14  
  15  const (
  16  	iso8601        = "20060102T15:04:05"
  17  	iso8601Z       = "20060102T15:04:05Z07:00"
  18  	iso8601Hyphen  = "2006-01-02T15:04:05"
  19  	iso8601HyphenZ = "2006-01-02T15:04:05Z07:00"
  20  )
  21  
  22  var (
  23  	// CharsetReader is a function to generate reader which converts a non UTF-8
  24  	// charset into UTF-8.
  25  	CharsetReader func(string, io.Reader) (io.Reader, error)
  26  
  27  	timeLayouts     = []string{iso8601, iso8601Z, iso8601Hyphen, iso8601HyphenZ}
  28  	invalidXmlError = errors.New("invalid xml")
  29  )
  30  
  31  type TypeMismatchError string
  32  
  33  func (e TypeMismatchError) Error() string { return string(e) }
  34  
  35  type decoder struct {
  36  	*xml.Decoder
  37  }
  38  
  39  func unmarshal(data []byte, v interface{}) (err error) {
  40  	dec := &decoder{xml.NewDecoder(bytes.NewBuffer(data))}
  41  
  42  	if CharsetReader != nil {
  43  		dec.CharsetReader = CharsetReader
  44  	}
  45  
  46  	var tok xml.Token
  47  	for {
  48  		if tok, err = dec.Token(); err != nil {
  49  			return err
  50  		}
  51  
  52  		if t, ok := tok.(xml.StartElement); ok {
  53  			if t.Name.Local == "value" {
  54  				val := reflect.ValueOf(v)
  55  				if val.Kind() != reflect.Ptr {
  56  					return errors.New("non-pointer value passed to unmarshal")
  57  				}
  58  				if err = dec.decodeValue(val.Elem()); err != nil {
  59  					return err
  60  				}
  61  
  62  				break
  63  			}
  64  		}
  65  	}
  66  
  67  	// read until end of document
  68  	err = dec.Skip()
  69  	if err != nil && err != io.EOF {
  70  		return err
  71  	}
  72  
  73  	return nil
  74  }
  75  
  76  func (dec *decoder) decodeValue(val reflect.Value) error {
  77  	var tok xml.Token
  78  	var err error
  79  
  80  	if val.Kind() == reflect.Ptr {
  81  		if val.IsNil() {
  82  			val.Set(reflect.New(val.Type().Elem()))
  83  		}
  84  		val = val.Elem()
  85  	}
  86  
  87  	var typeName string
  88  	for {
  89  		if tok, err = dec.Token(); err != nil {
  90  			return err
  91  		}
  92  
  93  		if t, ok := tok.(xml.EndElement); ok {
  94  			if t.Name.Local == "value" {
  95  				return nil
  96  			} else {
  97  				return invalidXmlError
  98  			}
  99  		}
 100  
 101  		if t, ok := tok.(xml.StartElement); ok {
 102  			typeName = t.Name.Local
 103  			break
 104  		}
 105  
 106  		// Treat value data without type identifier as string
 107  		if t, ok := tok.(xml.CharData); ok {
 108  			if value := strings.TrimSpace(string(t)); value != "" {
 109  				if err = checkType(val, reflect.String); err != nil {
 110  					return err
 111  				}
 112  
 113  				val.SetString(value)
 114  				return nil
 115  			}
 116  		}
 117  	}
 118  
 119  	switch typeName {
 120  	case "struct":
 121  		ismap := false
 122  		pmap := val
 123  		valType := val.Type()
 124  
 125  		if err = checkType(val, reflect.Struct); err != nil {
 126  			if checkType(val, reflect.Map) == nil {
 127  				if valType.Key().Kind() != reflect.String {
 128  					return fmt.Errorf("only maps with string key type can be unmarshalled")
 129  				}
 130  				ismap = true
 131  			} else if checkType(val, reflect.Interface) == nil && val.IsNil() {
 132  				var dummy map[string]interface{}
 133  				valType = reflect.TypeOf(dummy)
 134  				pmap = reflect.New(valType).Elem()
 135  				val.Set(pmap)
 136  				ismap = true
 137  			} else {
 138  				return err
 139  			}
 140  		}
 141  
 142  		var fields map[string]reflect.Value
 143  
 144  		if !ismap {
 145  			fields = make(map[string]reflect.Value)
 146  
 147  			for i := 0; i < valType.NumField(); i++ {
 148  				field := valType.Field(i)
 149  				fieldVal := val.FieldByName(field.Name)
 150  
 151  				if fieldVal.CanSet() {
 152  					name := field.Tag.Get("xmlrpc")
 153  					name = strings.TrimSuffix(name, ",omitempty")
 154  					if name == "-" {
 155  						continue
 156  					}
 157  					if name == "" {
 158  						name = field.Name
 159  					}
 160  					fields[name] = fieldVal
 161  				}
 162  			}
 163  		} else {
 164  			// Create initial empty map
 165  			pmap.Set(reflect.MakeMap(valType))
 166  		}
 167  
 168  		// Process struct members.
 169  	StructLoop:
 170  		for {
 171  			if tok, err = dec.Token(); err != nil {
 172  				return err
 173  			}
 174  			switch t := tok.(type) {
 175  			case xml.StartElement:
 176  				if t.Name.Local != "member" {
 177  					return invalidXmlError
 178  				}
 179  
 180  				tagName, fieldName, err := dec.readTag()
 181  				if err != nil {
 182  					return err
 183  				}
 184  				if tagName != "name" {
 185  					return invalidXmlError
 186  				}
 187  
 188  				var fv reflect.Value
 189  				ok := true
 190  
 191  				if !ismap {
 192  					fv, ok = fields[string(fieldName)]
 193  				} else {
 194  					fv = reflect.New(valType.Elem())
 195  				}
 196  
 197  				if ok {
 198  					for {
 199  						if tok, err = dec.Token(); err != nil {
 200  							return err
 201  						}
 202  						if t, ok := tok.(xml.StartElement); ok && t.Name.Local == "value" {
 203  							if err = dec.decodeValue(fv); err != nil {
 204  								return err
 205  							}
 206  
 207  							// </value>
 208  							if err = dec.Skip(); err != nil {
 209  								return err
 210  							}
 211  
 212  							break
 213  						}
 214  					}
 215  				}
 216  
 217  				// </member>
 218  				if err = dec.Skip(); err != nil {
 219  					return err
 220  				}
 221  
 222  				if ismap {
 223  					pmap.SetMapIndex(reflect.ValueOf(string(fieldName)), reflect.Indirect(fv))
 224  					val.Set(pmap)
 225  				}
 226  			case xml.EndElement:
 227  				break StructLoop
 228  			}
 229  		}
 230  	case "array":
 231  		slice := val
 232  		if checkType(val, reflect.Interface) == nil && val.IsNil() {
 233  			slice = reflect.ValueOf([]interface{}{})
 234  		} else if err = checkType(val, reflect.Slice); err != nil {
 235  			return err
 236  		}
 237  
 238  	ArrayLoop:
 239  		for {
 240  			if tok, err = dec.Token(); err != nil {
 241  				return err
 242  			}
 243  
 244  			switch t := tok.(type) {
 245  			case xml.StartElement:
 246  				var index int
 247  				if t.Name.Local != "data" {
 248  					return invalidXmlError
 249  				}
 250  			DataLoop:
 251  				for {
 252  					if tok, err = dec.Token(); err != nil {
 253  						return err
 254  					}
 255  
 256  					switch tt := tok.(type) {
 257  					case xml.StartElement:
 258  						if tt.Name.Local != "value" {
 259  							return invalidXmlError
 260  						}
 261  
 262  						if index < slice.Len() {
 263  							v := slice.Index(index)
 264  							if v.Kind() == reflect.Interface {
 265  								v = v.Elem()
 266  							}
 267  							if v.Kind() != reflect.Ptr {
 268  								return errors.New("error: cannot write to non-pointer array element")
 269  							}
 270  							if err = dec.decodeValue(v); err != nil {
 271  								return err
 272  							}
 273  						} else {
 274  							v := reflect.New(slice.Type().Elem())
 275  							if err = dec.decodeValue(v); err != nil {
 276  								return err
 277  							}
 278  							slice = reflect.Append(slice, v.Elem())
 279  						}
 280  
 281  						// </value>
 282  						if err = dec.Skip(); err != nil {
 283  							return err
 284  						}
 285  						index++
 286  					case xml.EndElement:
 287  						val.Set(slice)
 288  						break DataLoop
 289  					}
 290  				}
 291  			case xml.EndElement:
 292  				break ArrayLoop
 293  			}
 294  		}
 295  	default:
 296  		if tok, err = dec.Token(); err != nil {
 297  			return err
 298  		}
 299  
 300  		var data []byte
 301  
 302  		switch t := tok.(type) {
 303  		case xml.EndElement:
 304  			return nil
 305  		case xml.CharData:
 306  			data = []byte(t.Copy())
 307  		default:
 308  			return invalidXmlError
 309  		}
 310  
 311  		switch typeName {
 312  		case "int", "i4", "i8":
 313  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 314  				i, err := strconv.ParseInt(string(data), 10, 64)
 315  				if err != nil {
 316  					return err
 317  				}
 318  
 319  				pi := reflect.New(reflect.TypeOf(i)).Elem()
 320  				pi.SetInt(i)
 321  				val.Set(pi)
 322  			} else if err = checkType(val, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64); err != nil {
 323  				return err
 324  			} else {
 325  				i, err := strconv.ParseInt(string(data), 10, val.Type().Bits())
 326  				if err != nil {
 327  					return err
 328  				}
 329  
 330  				val.SetInt(i)
 331  			}
 332  		case "string", "base64":
 333  			str := string(data)
 334  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 335  				pstr := reflect.New(reflect.TypeOf(str)).Elem()
 336  				pstr.SetString(str)
 337  				val.Set(pstr)
 338  			} else if err = checkType(val, reflect.String); err != nil {
 339  				return err
 340  			} else {
 341  				val.SetString(str)
 342  			}
 343  		case "dateTime.iso8601":
 344  			var t time.Time
 345  			var err error
 346  
 347  			for _, layout := range timeLayouts {
 348  				t, err = time.Parse(layout, string(data))
 349  				if err == nil {
 350  					break
 351  				}
 352  			}
 353  			if err != nil {
 354  				return err
 355  			}
 356  
 357  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 358  				ptime := reflect.New(reflect.TypeOf(t)).Elem()
 359  				ptime.Set(reflect.ValueOf(t))
 360  				val.Set(ptime)
 361  			} else if _, ok := val.Interface().(time.Time); !ok {
 362  				return TypeMismatchError(fmt.Sprintf("error: type mismatch error - can't decode %v to time", val.Kind()))
 363  			} else {
 364  				val.Set(reflect.ValueOf(t))
 365  			}
 366  		case "boolean":
 367  			v, err := strconv.ParseBool(string(data))
 368  			if err != nil {
 369  				return err
 370  			}
 371  
 372  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 373  				pv := reflect.New(reflect.TypeOf(v)).Elem()
 374  				pv.SetBool(v)
 375  				val.Set(pv)
 376  			} else if err = checkType(val, reflect.Bool); err != nil {
 377  				return err
 378  			} else {
 379  				val.SetBool(v)
 380  			}
 381  		case "double":
 382  			if checkType(val, reflect.Interface) == nil && val.IsNil() {
 383  				i, err := strconv.ParseFloat(string(data), 64)
 384  				if err != nil {
 385  					return err
 386  				}
 387  
 388  				pdouble := reflect.New(reflect.TypeOf(i)).Elem()
 389  				pdouble.SetFloat(i)
 390  				val.Set(pdouble)
 391  			} else if err = checkType(val, reflect.Float32, reflect.Float64); err != nil {
 392  				return err
 393  			} else {
 394  				i, err := strconv.ParseFloat(string(data), val.Type().Bits())
 395  				if err != nil {
 396  					return err
 397  				}
 398  
 399  				val.SetFloat(i)
 400  			}
 401  		default:
 402  			return errors.New("unsupported type")
 403  		}
 404  
 405  		// </type>
 406  		if err = dec.Skip(); err != nil {
 407  			return err
 408  		}
 409  	}
 410  
 411  	return nil
 412  }
 413  
 414  func (dec *decoder) readTag() (string, []byte, error) {
 415  	var tok xml.Token
 416  	var err error
 417  
 418  	var name string
 419  	for {
 420  		if tok, err = dec.Token(); err != nil {
 421  			return "", nil, err
 422  		}
 423  
 424  		if t, ok := tok.(xml.StartElement); ok {
 425  			name = t.Name.Local
 426  			break
 427  		}
 428  	}
 429  
 430  	value, err := dec.readCharData()
 431  	if err != nil {
 432  		return "", nil, err
 433  	}
 434  
 435  	return name, value, dec.Skip()
 436  }
 437  
 438  func (dec *decoder) readCharData() ([]byte, error) {
 439  	var tok xml.Token
 440  	var err error
 441  
 442  	if tok, err = dec.Token(); err != nil {
 443  		return nil, err
 444  	}
 445  
 446  	if t, ok := tok.(xml.CharData); ok {
 447  		return []byte(t.Copy()), nil
 448  	} else {
 449  		return nil, invalidXmlError
 450  	}
 451  }
 452  
 453  func checkType(val reflect.Value, kinds ...reflect.Kind) error {
 454  	if len(kinds) == 0 {
 455  		return nil
 456  	}
 457  
 458  	if val.Kind() == reflect.Ptr {
 459  		val = val.Elem()
 460  	}
 461  
 462  	match := false
 463  
 464  	for _, kind := range kinds {
 465  		if val.Kind() == kind {
 466  			match = true
 467  			break
 468  		}
 469  	}
 470  
 471  	if !match {
 472  		return TypeMismatchError(fmt.Sprintf("error: type mismatch - can't unmarshal %v to %v",
 473  			val.Kind(), kinds[0]))
 474  	}
 475  
 476  	return nil
 477  }
 478