decoder.go raw
1 package xmlrpc
2
3 import (
4 "bytes"
5 "encoding/xml"
6 "errors"
7 "fmt"
8 "io"
9 "reflect"
10 "strconv"
11 "strings"
12 "time"
13 )
14
// Layouts accepted for <dateTime.iso8601> payloads: the compact form and the
// hyphenated form, each with and without a trailing zone offset.
const (
	iso8601        = "20060102T15:04:05"
	iso8601Z       = "20060102T15:04:05Z07:00"
	iso8601Hyphen  = "2006-01-02T15:04:05"
	iso8601HyphenZ = "2006-01-02T15:04:05Z07:00"
)

// CharsetReader is a function to generate reader which converts a non UTF-8
// charset into UTF-8. When it is non-nil, unmarshal installs it on the XML
// decoder it creates.
var CharsetReader func(string, io.Reader) (io.Reader, error)

var (
	// timeLayouts holds the dateTime.iso8601 layouts in the order they are tried.
	timeLayouts = []string{
		iso8601,
		iso8601Z,
		iso8601Hyphen,
		iso8601HyphenZ,
	}

	// invalidXmlError reports input whose element structure is not the one the
	// decoder expects at that point.
	invalidXmlError = errors.New("invalid xml")
)
30
// decoder extends an encoding/xml Decoder with the XML-RPC decoding methods;
// embedding promotes its methods and fields (Token, Skip, CharsetReader).
type decoder struct {
	*xml.Decoder
}

// TypeMismatchError reports that a decoded XML-RPC value cannot be stored in
// the Go value supplied by the caller. The string is the complete message.
type TypeMismatchError string

// Error implements the error interface by returning the message unchanged.
func (e TypeMismatchError) Error() string {
	msg := string(e)
	return msg
}
38
39 func unmarshal(data []byte, v interface{}) (err error) {
40 dec := &decoder{xml.NewDecoder(bytes.NewBuffer(data))}
41
42 if CharsetReader != nil {
43 dec.CharsetReader = CharsetReader
44 }
45
46 var tok xml.Token
47 for {
48 if tok, err = dec.Token(); err != nil {
49 return err
50 }
51
52 if t, ok := tok.(xml.StartElement); ok {
53 if t.Name.Local == "value" {
54 val := reflect.ValueOf(v)
55 if val.Kind() != reflect.Ptr {
56 return errors.New("non-pointer value passed to unmarshal")
57 }
58 if err = dec.decodeValue(val.Elem()); err != nil {
59 return err
60 }
61
62 break
63 }
64 }
65 }
66
67 // read until end of document
68 err = dec.Skip()
69 if err != nil && err != io.EOF {
70 return err
71 }
72
73 return nil
74 }
75
76 func (dec *decoder) decodeValue(val reflect.Value) error {
77 var tok xml.Token
78 var err error
79
80 if val.Kind() == reflect.Ptr {
81 if val.IsNil() {
82 val.Set(reflect.New(val.Type().Elem()))
83 }
84 val = val.Elem()
85 }
86
87 var typeName string
88 for {
89 if tok, err = dec.Token(); err != nil {
90 return err
91 }
92
93 if t, ok := tok.(xml.EndElement); ok {
94 if t.Name.Local == "value" {
95 return nil
96 } else {
97 return invalidXmlError
98 }
99 }
100
101 if t, ok := tok.(xml.StartElement); ok {
102 typeName = t.Name.Local
103 break
104 }
105
106 // Treat value data without type identifier as string
107 if t, ok := tok.(xml.CharData); ok {
108 if value := strings.TrimSpace(string(t)); value != "" {
109 if err = checkType(val, reflect.String); err != nil {
110 return err
111 }
112
113 val.SetString(value)
114 return nil
115 }
116 }
117 }
118
119 switch typeName {
120 case "struct":
121 ismap := false
122 pmap := val
123 valType := val.Type()
124
125 if err = checkType(val, reflect.Struct); err != nil {
126 if checkType(val, reflect.Map) == nil {
127 if valType.Key().Kind() != reflect.String {
128 return fmt.Errorf("only maps with string key type can be unmarshalled")
129 }
130 ismap = true
131 } else if checkType(val, reflect.Interface) == nil && val.IsNil() {
132 var dummy map[string]interface{}
133 valType = reflect.TypeOf(dummy)
134 pmap = reflect.New(valType).Elem()
135 val.Set(pmap)
136 ismap = true
137 } else {
138 return err
139 }
140 }
141
142 var fields map[string]reflect.Value
143
144 if !ismap {
145 fields = make(map[string]reflect.Value)
146
147 for i := 0; i < valType.NumField(); i++ {
148 field := valType.Field(i)
149 fieldVal := val.FieldByName(field.Name)
150
151 if fieldVal.CanSet() {
152 name := field.Tag.Get("xmlrpc")
153 name = strings.TrimSuffix(name, ",omitempty")
154 if name == "-" {
155 continue
156 }
157 if name == "" {
158 name = field.Name
159 }
160 fields[name] = fieldVal
161 }
162 }
163 } else {
164 // Create initial empty map
165 pmap.Set(reflect.MakeMap(valType))
166 }
167
168 // Process struct members.
169 StructLoop:
170 for {
171 if tok, err = dec.Token(); err != nil {
172 return err
173 }
174 switch t := tok.(type) {
175 case xml.StartElement:
176 if t.Name.Local != "member" {
177 return invalidXmlError
178 }
179
180 tagName, fieldName, err := dec.readTag()
181 if err != nil {
182 return err
183 }
184 if tagName != "name" {
185 return invalidXmlError
186 }
187
188 var fv reflect.Value
189 ok := true
190
191 if !ismap {
192 fv, ok = fields[string(fieldName)]
193 } else {
194 fv = reflect.New(valType.Elem())
195 }
196
197 if ok {
198 for {
199 if tok, err = dec.Token(); err != nil {
200 return err
201 }
202 if t, ok := tok.(xml.StartElement); ok && t.Name.Local == "value" {
203 if err = dec.decodeValue(fv); err != nil {
204 return err
205 }
206
207 // </value>
208 if err = dec.Skip(); err != nil {
209 return err
210 }
211
212 break
213 }
214 }
215 }
216
217 // </member>
218 if err = dec.Skip(); err != nil {
219 return err
220 }
221
222 if ismap {
223 pmap.SetMapIndex(reflect.ValueOf(string(fieldName)), reflect.Indirect(fv))
224 val.Set(pmap)
225 }
226 case xml.EndElement:
227 break StructLoop
228 }
229 }
230 case "array":
231 slice := val
232 if checkType(val, reflect.Interface) == nil && val.IsNil() {
233 slice = reflect.ValueOf([]interface{}{})
234 } else if err = checkType(val, reflect.Slice); err != nil {
235 return err
236 }
237
238 ArrayLoop:
239 for {
240 if tok, err = dec.Token(); err != nil {
241 return err
242 }
243
244 switch t := tok.(type) {
245 case xml.StartElement:
246 var index int
247 if t.Name.Local != "data" {
248 return invalidXmlError
249 }
250 DataLoop:
251 for {
252 if tok, err = dec.Token(); err != nil {
253 return err
254 }
255
256 switch tt := tok.(type) {
257 case xml.StartElement:
258 if tt.Name.Local != "value" {
259 return invalidXmlError
260 }
261
262 if index < slice.Len() {
263 v := slice.Index(index)
264 if v.Kind() == reflect.Interface {
265 v = v.Elem()
266 }
267 if v.Kind() != reflect.Ptr {
268 return errors.New("error: cannot write to non-pointer array element")
269 }
270 if err = dec.decodeValue(v); err != nil {
271 return err
272 }
273 } else {
274 v := reflect.New(slice.Type().Elem())
275 if err = dec.decodeValue(v); err != nil {
276 return err
277 }
278 slice = reflect.Append(slice, v.Elem())
279 }
280
281 // </value>
282 if err = dec.Skip(); err != nil {
283 return err
284 }
285 index++
286 case xml.EndElement:
287 val.Set(slice)
288 break DataLoop
289 }
290 }
291 case xml.EndElement:
292 break ArrayLoop
293 }
294 }
295 default:
296 if tok, err = dec.Token(); err != nil {
297 return err
298 }
299
300 var data []byte
301
302 switch t := tok.(type) {
303 case xml.EndElement:
304 return nil
305 case xml.CharData:
306 data = []byte(t.Copy())
307 default:
308 return invalidXmlError
309 }
310
311 switch typeName {
312 case "int", "i4", "i8":
313 if checkType(val, reflect.Interface) == nil && val.IsNil() {
314 i, err := strconv.ParseInt(string(data), 10, 64)
315 if err != nil {
316 return err
317 }
318
319 pi := reflect.New(reflect.TypeOf(i)).Elem()
320 pi.SetInt(i)
321 val.Set(pi)
322 } else if err = checkType(val, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64); err != nil {
323 return err
324 } else {
325 i, err := strconv.ParseInt(string(data), 10, val.Type().Bits())
326 if err != nil {
327 return err
328 }
329
330 val.SetInt(i)
331 }
332 case "string", "base64":
333 str := string(data)
334 if checkType(val, reflect.Interface) == nil && val.IsNil() {
335 pstr := reflect.New(reflect.TypeOf(str)).Elem()
336 pstr.SetString(str)
337 val.Set(pstr)
338 } else if err = checkType(val, reflect.String); err != nil {
339 return err
340 } else {
341 val.SetString(str)
342 }
343 case "dateTime.iso8601":
344 var t time.Time
345 var err error
346
347 for _, layout := range timeLayouts {
348 t, err = time.Parse(layout, string(data))
349 if err == nil {
350 break
351 }
352 }
353 if err != nil {
354 return err
355 }
356
357 if checkType(val, reflect.Interface) == nil && val.IsNil() {
358 ptime := reflect.New(reflect.TypeOf(t)).Elem()
359 ptime.Set(reflect.ValueOf(t))
360 val.Set(ptime)
361 } else if _, ok := val.Interface().(time.Time); !ok {
362 return TypeMismatchError(fmt.Sprintf("error: type mismatch error - can't decode %v to time", val.Kind()))
363 } else {
364 val.Set(reflect.ValueOf(t))
365 }
366 case "boolean":
367 v, err := strconv.ParseBool(string(data))
368 if err != nil {
369 return err
370 }
371
372 if checkType(val, reflect.Interface) == nil && val.IsNil() {
373 pv := reflect.New(reflect.TypeOf(v)).Elem()
374 pv.SetBool(v)
375 val.Set(pv)
376 } else if err = checkType(val, reflect.Bool); err != nil {
377 return err
378 } else {
379 val.SetBool(v)
380 }
381 case "double":
382 if checkType(val, reflect.Interface) == nil && val.IsNil() {
383 i, err := strconv.ParseFloat(string(data), 64)
384 if err != nil {
385 return err
386 }
387
388 pdouble := reflect.New(reflect.TypeOf(i)).Elem()
389 pdouble.SetFloat(i)
390 val.Set(pdouble)
391 } else if err = checkType(val, reflect.Float32, reflect.Float64); err != nil {
392 return err
393 } else {
394 i, err := strconv.ParseFloat(string(data), val.Type().Bits())
395 if err != nil {
396 return err
397 }
398
399 val.SetFloat(i)
400 }
401 default:
402 return errors.New("unsupported type")
403 }
404
405 // </type>
406 if err = dec.Skip(); err != nil {
407 return err
408 }
409 }
410
411 return nil
412 }
413
414 func (dec *decoder) readTag() (string, []byte, error) {
415 var tok xml.Token
416 var err error
417
418 var name string
419 for {
420 if tok, err = dec.Token(); err != nil {
421 return "", nil, err
422 }
423
424 if t, ok := tok.(xml.StartElement); ok {
425 name = t.Name.Local
426 break
427 }
428 }
429
430 value, err := dec.readCharData()
431 if err != nil {
432 return "", nil, err
433 }
434
435 return name, value, dec.Skip()
436 }
437
438 func (dec *decoder) readCharData() ([]byte, error) {
439 var tok xml.Token
440 var err error
441
442 if tok, err = dec.Token(); err != nil {
443 return nil, err
444 }
445
446 if t, ok := tok.(xml.CharData); ok {
447 return []byte(t.Copy()), nil
448 } else {
449 return nil, invalidXmlError
450 }
451 }
452
453 func checkType(val reflect.Value, kinds ...reflect.Kind) error {
454 if len(kinds) == 0 {
455 return nil
456 }
457
458 if val.Kind() == reflect.Ptr {
459 val = val.Elem()
460 }
461
462 match := false
463
464 for _, kind := range kinds {
465 if val.Kind() == kind {
466 match = true
467 break
468 }
469 }
470
471 if !match {
472 return TypeMismatchError(fmt.Sprintf("error: type mismatch - can't unmarshal %v to %v",
473 val.Kind(), kinds[0]))
474 }
475
476 return nil
477 }
478