1 // Copyright 2013 Google Inc. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 15 package pretty
16 17 import (
18 "encoding"
19 "fmt"
20 "reflect"
21 "sort"
22 )
// isZeroVal reports whether val is deeply equal (per reflect.DeepEqual) to
// the zero value of its type. Values that cannot be converted back to an
// interface, such as unexported struct fields, are never reported as zero.
func isZeroVal(val reflect.Value) bool {
	if val.CanInterface() {
		zero := reflect.Zero(val.Type())
		return reflect.DeepEqual(val.Interface(), zero.Interface())
	}
	return false
}
// pointerTracker is a helper for tracking pointer chasing to detect cycles.
//
// It keeps two independent pieces of state: which addresses are being
// followed on the current recursion path (addrs, maintained by track and
// untrack), and which addresses have been assigned a reference ID (ids and
// lastID, maintained by keep). The zero value is ready to use; both maps are
// allocated lazily.
type pointerTracker struct {
	// addrs[address] = number of times that address is currently being
	// followed on the active path. untrack deletes an entry once its count
	// drops to zero, so presence in the map means "on the current path".
	addrs map[uintptr]int

	// lastID is the most recently allocated ID. keep increments it before
	// assigning, so the first ID handed out is 1.
	lastID int
	// ids[address] = ID assigned to that address by keep. Entries are never
	// removed, so an ID persists after the path that created it unwinds.
	ids map[uintptr]int
}
39 40 // track tracks following a reference (pointer, slice, map, etc). Every call to
41 // track should be paired with a call to untrack.
42 func (p *pointerTracker) track(ptr uintptr) {
43 if p.addrs == nil {
44 p.addrs = make(map[uintptr]int)
45 }
46 p.addrs[ptr]++
47 }
48 49 // untrack registers that we have backtracked over the reference to the pointer.
50 func (p *pointerTracker) untrack(ptr uintptr) {
51 p.addrs[ptr]--
52 if p.addrs[ptr] == 0 {
53 delete(p.addrs, ptr)
54 }
55 }
56 57 // seen returns whether the pointer was previously seen along this path.
58 func (p *pointerTracker) seen(ptr uintptr) bool {
59 _, ok := p.addrs[ptr]
60 return ok
61 }
62 63 // keep allocates an ID for the given address and returns it.
64 func (p *pointerTracker) keep(ptr uintptr) int {
65 if p.ids == nil {
66 p.ids = make(map[uintptr]int)
67 }
68 if _, ok := p.ids[ptr]; !ok {
69 p.lastID++
70 p.ids[ptr] = p.lastID
71 }
72 return p.ids[ptr]
73 }
74 75 // id returns the ID for the given address.
76 func (p *pointerTracker) id(ptr uintptr) (int, bool) {
77 if p.ids == nil {
78 p.ids = make(map[uintptr]int)
79 }
80 id, ok := p.ids[ptr]
81 return id, ok
82 }
// reflector adds local state to the recursive reflection logic.
//
// It embeds the rendering options (*Config; val2node reads fields such as
// Formatter, PrintStringers, PrintTextMarshalers, IncludeUnexported and
// SkipZeroFields through it) and an optional *pointerTracker. follow treats a
// nil pointerTracker as "cycle tracking disabled".
type reflector struct {
	*Config
	*pointerTracker
}
89 90 // follow handles following a possiblly-recursive reference to the given value
91 // from the given ptr address.
92 func (r *reflector) follow(ptr uintptr, val reflect.Value) node {
93 if r.pointerTracker == nil {
94 // Tracking disabled
95 return r.val2node(val)
96 }
97 98 // If a parent already followed this, emit a reference marker
99 if r.seen(ptr) {
100 id := r.keep(ptr)
101 return ref{id}
102 }
103 104 // Track the pointer we're following while on this recursive branch
105 r.track(ptr)
106 defer r.untrack(ptr)
107 n := r.val2node(val)
108 109 // If the recursion used this ptr, wrap it with a target marker
110 if id, ok := r.id(ptr); ok {
111 return target{id, n}
112 }
113 114 // Otherwise, return the node unadulterated
115 return n
116 }
117 118 func (r *reflector) val2node(val reflect.Value) node {
119 if !val.IsValid() {
120 return rawVal("nil")
121 }
122 123 if val.CanInterface() {
124 v := val.Interface()
125 if formatter, ok := r.Formatter[val.Type()]; ok {
126 if formatter != nil {
127 res := reflect.ValueOf(formatter).Call([]reflect.Value{val})
128 return rawVal(res[0].Interface().(string))
129 }
130 } else {
131 if s, ok := v.(fmt.Stringer); ok && r.PrintStringers {
132 return stringVal(s.String())
133 }
134 if t, ok := v.(encoding.TextMarshaler); ok && r.PrintTextMarshalers {
135 if raw, err := t.MarshalText(); err == nil { // if NOT an error
136 return stringVal(string(raw))
137 }
138 }
139 }
140 }
141 142 switch kind := val.Kind(); kind {
143 case reflect.Ptr:
144 if val.IsNil() {
145 return rawVal("nil")
146 }
147 return r.follow(val.Pointer(), val.Elem())
148 case reflect.Interface:
149 if val.IsNil() {
150 return rawVal("nil")
151 }
152 return r.val2node(val.Elem())
153 case reflect.String:
154 return stringVal(val.String())
155 case reflect.Slice:
156 n := list{}
157 length := val.Len()
158 ptr := val.Pointer()
159 for i := 0; i < length; i++ {
160 n = append(n, r.follow(ptr, val.Index(i)))
161 }
162 return n
163 case reflect.Array:
164 n := list{}
165 length := val.Len()
166 for i := 0; i < length; i++ {
167 n = append(n, r.val2node(val.Index(i)))
168 }
169 return n
170 case reflect.Map:
171 // Extract the keys and sort them for stable iteration
172 keys := val.MapKeys()
173 pairs := make([]mapPair, 0, len(keys))
174 for _, key := range keys {
175 pairs = append(pairs, mapPair{
176 key: new(formatter).compactString(r.val2node(key)), // can't be cyclic
177 value: val.MapIndex(key),
178 })
179 }
180 sort.Sort(byKey(pairs))
181 182 // Process the keys into the final representation
183 ptr, n := val.Pointer(), keyvals{}
184 for _, pair := range pairs {
185 n = append(n, keyval{
186 key: pair.key,
187 val: r.follow(ptr, pair.value),
188 })
189 }
190 return n
191 case reflect.Struct:
192 n := keyvals{}
193 typ := val.Type()
194 fields := typ.NumField()
195 for i := 0; i < fields; i++ {
196 sf := typ.Field(i)
197 if !r.IncludeUnexported && sf.PkgPath != "" {
198 continue
199 }
200 field := val.Field(i)
201 if r.SkipZeroFields && isZeroVal(field) {
202 continue
203 }
204 n = append(n, keyval{sf.Name, r.val2node(field)})
205 }
206 return n
207 case reflect.Bool:
208 if val.Bool() {
209 return rawVal("true")
210 }
211 return rawVal("false")
212 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
213 return rawVal(fmt.Sprintf("%d", val.Int()))
214 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
215 return rawVal(fmt.Sprintf("%d", val.Uint()))
216 case reflect.Uintptr:
217 return rawVal(fmt.Sprintf("0x%X", val.Uint()))
218 case reflect.Float32, reflect.Float64:
219 return rawVal(fmt.Sprintf("%v", val.Float()))
220 case reflect.Complex64, reflect.Complex128:
221 return rawVal(fmt.Sprintf("%v", val.Complex()))
222 }
223 224 // Fall back to the default %#v if we can
225 if val.CanInterface() {
226 return rawVal(fmt.Sprintf("%#v", val.Interface()))
227 }
228 229 return rawVal(val.String())
230 }
// mapPair holds one map entry while a map is being rendered: its key, already
// rendered to a compact string, and the original reflected value.
type mapPair struct {
	key   string
	value reflect.Value
}

// byKey implements sort.Interface, ordering map entries by their rendered key
// using plain (bytewise) string comparison.
type byKey []mapPair

func (b byKey) Len() int {
	return len(b)
}

func (b byKey) Swap(i, j int) {
	b[j], b[i] = b[i], b[j]
}

func (b byKey) Less(i, j int) bool {
	return b[i].key < b[j].key
}
242