reflect.go raw

   1  // Copyright 2013 Google Inc.  All rights reserved.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //     http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package pretty
  16  
  17  import (
  18  	"encoding"
  19  	"fmt"
  20  	"reflect"
  21  	"sort"
  22  )
  23  
  24  func isZeroVal(val reflect.Value) bool {
  25  	if !val.CanInterface() {
  26  		return false
  27  	}
  28  	z := reflect.Zero(val.Type()).Interface()
  29  	return reflect.DeepEqual(val.Interface(), z)
  30  }
  31  
// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// It holds two independent pieces of state:
//   - addrs counts, per address, how many times that address is currently
//     being followed (track increments, untrack decrements and deletes the
//     entry once the count reaches zero);
//   - ids maps addresses that keep has been asked to label to small integer
//     IDs, with lastID holding the most recently allocated ID. None of the
//     methods of this type remove entries from ids.
//
// The zero value is usable: track and keep allocate their maps on first use,
// and lookups on the nil maps are safe.
type pointerTracker struct {
	addrs map[uintptr]int // addrs[address] = seen count

	lastID int
	ids    map[uintptr]int // ids[address] = id
}
  39  
  40  // track tracks following a reference (pointer, slice, map, etc).  Every call to
  41  // track should be paired with a call to untrack.
  42  func (p *pointerTracker) track(ptr uintptr) {
  43  	if p.addrs == nil {
  44  		p.addrs = make(map[uintptr]int)
  45  	}
  46  	p.addrs[ptr]++
  47  }
  48  
  49  // untrack registers that we have backtracked over the reference to the pointer.
  50  func (p *pointerTracker) untrack(ptr uintptr) {
  51  	p.addrs[ptr]--
  52  	if p.addrs[ptr] == 0 {
  53  		delete(p.addrs, ptr)
  54  	}
  55  }
  56  
  57  // seen returns whether the pointer was previously seen along this path.
  58  func (p *pointerTracker) seen(ptr uintptr) bool {
  59  	_, ok := p.addrs[ptr]
  60  	return ok
  61  }
  62  
  63  // keep allocates an ID for the given address and returns it.
  64  func (p *pointerTracker) keep(ptr uintptr) int {
  65  	if p.ids == nil {
  66  		p.ids = make(map[uintptr]int)
  67  	}
  68  	if _, ok := p.ids[ptr]; !ok {
  69  		p.lastID++
  70  		p.ids[ptr] = p.lastID
  71  	}
  72  	return p.ids[ptr]
  73  }
  74  
  75  // id returns the ID for the given address.
  76  func (p *pointerTracker) id(ptr uintptr) (int, bool) {
  77  	if p.ids == nil {
  78  		p.ids = make(map[uintptr]int)
  79  	}
  80  	id, ok := p.ids[ptr]
  81  	return id, ok
  82  }
  83  
// reflector adds local state to the recursive reflection logic.
//
// The embedded *Config supplies the options that val2node consults
// (Formatter, PrintStringers, PrintTextMarshalers, IncludeUnexported,
// SkipZeroFields). The embedded *pointerTracker is used by follow for cycle
// detection; leaving it nil disables cycle tracking.
type reflector struct {
	*Config
	*pointerTracker
}
  89  
  90  // follow handles following a possiblly-recursive reference to the given value
  91  // from the given ptr address.
  92  func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
  93  	if r.pointerTracker == nil {
  94  		// Tracking disabled
  95  		return r.val2node(val)
  96  	}
  97  
  98  	// If a parent already followed this, emit a reference marker
  99  	if r.seen(ptr) {
 100  		id := r.keep(ptr)
 101  		return ref{id}
 102  	}
 103  
 104  	// Track the pointer we're following while on this recursive branch
 105  	r.track(ptr)
 106  	defer r.untrack(ptr)
 107  	n := r.val2node(val)
 108  
 109  	// If the recursion used this ptr, wrap it with a target marker
 110  	if id, ok := r.id(ptr); ok {
 111  		return target{id, n}
 112  	}
 113  
 114  	// Otherwise, return the node unadulterated
 115  	return n
 116  }
 117  
 118  func (r *reflector) val2node(val reflect.Value) node {
 119  	if !val.IsValid() {
 120  		return rawVal("nil")
 121  	}
 122  
 123  	if val.CanInterface() {
 124  		v := val.Interface()
 125  		if formatter, ok := r.Formatter[val.Type()]; ok {
 126  			if formatter != nil {
 127  				res := reflect.ValueOf(formatter).Call([]reflect.Value{val})
 128  				return rawVal(res[0].Interface().(string))
 129  			}
 130  		} else {
 131  			if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
 132  				return stringVal(s.String())
 133  			}
 134  			if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
 135  				if raw, err := t.MarshalText(); err == nil { // if NOT an error
 136  					return stringVal(string(raw))
 137  				}
 138  			}
 139  		}
 140  	}
 141  
 142  	switch kind := val.Kind(); kind {
 143  	case reflect.Ptr:
 144  		if val.IsNil() {
 145  			return rawVal("nil")
 146  		}
 147  		return r.follow(val.Pointer(), val.Elem())
 148  	case reflect.Interface:
 149  		if val.IsNil() {
 150  			return rawVal("nil")
 151  		}
 152  		return r.val2node(val.Elem())
 153  	case reflect.String:
 154  		return stringVal(val.String())
 155  	case reflect.Slice:
 156  		n := list{}
 157  		length := val.Len()
 158  		ptr := val.Pointer()
 159  		for i := 0; i < length; i++ {
 160  			n = append(n, r.follow(ptr, val.Index(i)))
 161  		}
 162  		return n
 163  	case reflect.Array:
 164  		n := list{}
 165  		length := val.Len()
 166  		for i := 0; i < length; i++ {
 167  			n = append(n, r.val2node(val.Index(i)))
 168  		}
 169  		return n
 170  	case reflect.Map:
 171  		// Extract the keys and sort them for stable iteration
 172  		keys := val.MapKeys()
 173  		pairs := make([]mapPair, 0, len(keys))
 174  		for _, key := range keys {
 175  			pairs = append(pairs, mapPair{
 176  				key:   new(formatter).compactString(r.val2node(key)), // can't be cyclic
 177  				value: val.MapIndex(key),
 178  			})
 179  		}
 180  		sort.Sort(byKey(pairs))
 181  
 182  		// Process the keys into the final representation
 183  		ptr, n := val.Pointer(), keyvals{}
 184  		for _, pair := range pairs {
 185  			n = append(n, keyval{
 186  				key: pair.key,
 187  				val: r.follow(ptr, pair.value),
 188  			})
 189  		}
 190  		return n
 191  	case reflect.Struct:
 192  		n := keyvals{}
 193  		typ := val.Type()
 194  		fields := typ.NumField()
 195  		for i := 0; i < fields; i++ {
 196  			sf := typ.Field(i)
 197  			if !r.IncludeUnexported && sf.PkgPath != "" {
 198  				continue
 199  			}
 200  			field := val.Field(i)
 201  			if r.SkipZeroFields && isZeroVal(field) {
 202  				continue
 203  			}
 204  			n = append(n, keyval{sf.Name, r.val2node(field)})
 205  		}
 206  		return n
 207  	case reflect.Bool:
 208  		if val.Bool() {
 209  			return rawVal("true")
 210  		}
 211  		return rawVal("false")
 212  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 213  		return rawVal(fmt.Sprintf("%d", val.Int()))
 214  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 215  		return rawVal(fmt.Sprintf("%d", val.Uint()))
 216  	case reflect.Uintptr:
 217  		return rawVal(fmt.Sprintf("0x%X", val.Uint()))
 218  	case reflect.Float32, reflect.Float64:
 219  		return rawVal(fmt.Sprintf("%v", val.Float()))
 220  	case reflect.Complex64, reflect.Complex128:
 221  		return rawVal(fmt.Sprintf("%v", val.Complex()))
 222  	}
 223  
 224  	// Fall back to the default %#v if we can
 225  	if val.CanInterface() {
 226  		return rawVal(fmt.Sprintf("%#v", val.Interface()))
 227  	}
 228  
 229  	return rawVal(val.String())
 230  }
 231  
// mapPair is one map entry captured while rendering a map. key is the entry's
// key already rendered to a compact string (this is what byKey sorts on), and
// value is the corresponding map value, which is rendered afterwards.
type mapPair struct {
	key   string
	value reflect.Value
}
 236  
 237  type byKey []mapPair
 238  
 239  func (v byKey) Len() int           { return len(v) }
 240  func (v byKey) Swap(i, j int)      { v[i], v[j] = v[j], v[i] }
 241  func (v byKey) Less(i, j int) bool { return v[i].key < v[j].key }
 242