tracestate.go raw

   1  // Copyright The OpenTelemetry Authors
   2  // SPDX-License-Identifier: Apache-2.0
   3  
   4  package trace // import "go.opentelemetry.io/otel/trace"
   5  
   6  import (
   7  	"encoding/json"
   8  	"fmt"
   9  	"strings"
  10  )
  11  
  12  const (
  13  	maxListMembers = 32
  14  
  15  	listDelimiters  = ","
  16  	memberDelimiter = "="
  17  
  18  	errInvalidKey    errorConst = "invalid tracestate key"
  19  	errInvalidValue  errorConst = "invalid tracestate value"
  20  	errInvalidMember errorConst = "invalid tracestate list-member"
  21  	errMemberNumber  errorConst = "too many list-members in tracestate"
  22  	errDuplicate     errorConst = "duplicate list-member in tracestate"
  23  )
  24  
  25  type member struct {
  26  	Key   string
  27  	Value string
  28  }
  29  
  30  // according to (chr = %x20 / (nblk-char = %x21-2B / %x2D-3C / %x3E-7E) )
  31  // means (chr = %x20-2B / %x2D-3C / %x3E-7E) .
  32  func checkValueChar(v byte) bool {
  33  	return v >= '\x20' && v <= '\x7e' && v != '\x2c' && v != '\x3d'
  34  }
  35  
  36  // according to (nblk-chr = %x21-2B / %x2D-3C / %x3E-7E) .
  37  func checkValueLast(v byte) bool {
  38  	return v >= '\x21' && v <= '\x7e' && v != '\x2c' && v != '\x3d'
  39  }
  40  
  41  // based on the W3C Trace Context specification
  42  //
  43  //	value    = (0*255(chr)) nblk-chr
  44  //	nblk-chr = %x21-2B / %x2D-3C / %x3E-7E
  45  //	chr      = %x20 / nblk-chr
  46  //
  47  // see https://www.w3.org/TR/trace-context-1/#value
  48  func checkValue(val string) bool {
  49  	n := len(val)
  50  	if n == 0 || n > 256 {
  51  		return false
  52  	}
  53  	for i := 0; i < n-1; i++ {
  54  		if !checkValueChar(val[i]) {
  55  			return false
  56  		}
  57  	}
  58  	return checkValueLast(val[n-1])
  59  }
  60  
  61  func checkKeyRemain(key string) bool {
  62  	// ( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )
  63  	for _, v := range key {
  64  		if isAlphaNum(byte(v)) {
  65  			continue
  66  		}
  67  		switch v {
  68  		case '_', '-', '*', '/':
  69  			continue
  70  		}
  71  		return false
  72  	}
  73  	return true
  74  }
  75  
  76  // according to
  77  //
  78  //	simple-key = lcalpha (0*255( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ))
  79  //	system-id = lcalpha (0*13( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ))
  80  //
  81  // param n is remain part length, should be 255 in simple-key or 13 in system-id.
  82  func checkKeyPart(key string, n int) bool {
  83  	if key == "" {
  84  		return false
  85  	}
  86  	first := key[0] // key's first char
  87  	ret := len(key[1:]) <= n
  88  	ret = ret && first >= 'a' && first <= 'z'
  89  	return ret && checkKeyRemain(key[1:])
  90  }
  91  
  92  func isAlphaNum(c byte) bool {
  93  	if c >= 'a' && c <= 'z' {
  94  		return true
  95  	}
  96  	return c >= '0' && c <= '9'
  97  }
  98  
  99  // according to
 100  //
 101  //	tenant-id = ( lcalpha / DIGIT ) 0*240( lcalpha / DIGIT / "_" / "-"/ "*" / "/" )
 102  //
 103  // param n is remain part length, should be 240 exactly.
 104  func checkKeyTenant(key string, n int) bool {
 105  	if key == "" {
 106  		return false
 107  	}
 108  	return isAlphaNum(key[0]) && len(key[1:]) <= n && checkKeyRemain(key[1:])
 109  }
 110  
 111  // based on the W3C Trace Context specification
 112  //
 113  //	key = simple-key / multi-tenant-key
 114  //	simple-key = lcalpha (0*255( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ))
 115  //	multi-tenant-key = tenant-id "@" system-id
 116  //	tenant-id = ( lcalpha / DIGIT ) (0*240( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ))
 117  //	system-id = lcalpha (0*13( lcalpha / DIGIT / "_" / "-"/ "*" / "/" ))
 118  //	lcalpha    = %x61-7A ; a-z
 119  //
 120  // see https://www.w3.org/TR/trace-context-1/#tracestate-header.
 121  func checkKey(key string) bool {
 122  	tenant, system, ok := strings.Cut(key, "@")
 123  	if !ok {
 124  		return checkKeyPart(key, 255)
 125  	}
 126  	return checkKeyTenant(tenant, 240) && checkKeyPart(system, 13)
 127  }
 128  
 129  func newMember(key, value string) (member, error) {
 130  	if !checkKey(key) {
 131  		return member{}, errInvalidKey
 132  	}
 133  	if !checkValue(value) {
 134  		return member{}, errInvalidValue
 135  	}
 136  	return member{Key: key, Value: value}, nil
 137  }
 138  
 139  func parseMember(m string) (member, error) {
 140  	key, val, ok := strings.Cut(m, memberDelimiter)
 141  	if !ok {
 142  		return member{}, fmt.Errorf("%w: %s", errInvalidMember, m)
 143  	}
 144  	key = strings.TrimLeft(key, " \t")
 145  	val = strings.TrimRight(val, " \t")
 146  	result, e := newMember(key, val)
 147  	if e != nil {
 148  		return member{}, fmt.Errorf("%w: %s", errInvalidMember, m)
 149  	}
 150  	return result, nil
 151  }
 152  
 153  // String encodes member into a string compliant with the W3C Trace Context
 154  // specification.
 155  func (m member) String() string {
 156  	return m.Key + "=" + m.Value
 157  }
 158  
 159  // TraceState provides additional vendor-specific trace identification
 160  // information across different distributed tracing systems. It represents an
 161  // immutable list consisting of key/value pairs, each pair is referred to as a
 162  // list-member.
 163  //
 164  // TraceState conforms to the W3C Trace Context specification
 165  // (https://www.w3.org/TR/trace-context-1). All operations that create or copy
 166  // a TraceState do so by validating all input and will only produce TraceState
 167  // that conform to the specification. Specifically, this means that all
 168  // list-member's key/value pairs are valid, no duplicate list-members exist,
 169  // and the maximum number of list-members (32) is not exceeded.
 170  type TraceState struct { //nolint:revive // revive complains about stutter of `trace.TraceState`
 171  	// list is the members in order.
 172  	list []member
 173  }
 174  
 175  var _ json.Marshaler = TraceState{}
 176  
 177  // ParseTraceState attempts to decode a TraceState from the passed
 178  // string. It returns an error if the input is invalid according to the W3C
 179  // Trace Context specification.
 180  func ParseTraceState(ts string) (TraceState, error) {
 181  	if ts == "" {
 182  		return TraceState{}, nil
 183  	}
 184  
 185  	wrapErr := func(err error) error {
 186  		return fmt.Errorf("failed to parse tracestate: %w", err)
 187  	}
 188  
 189  	var members []member
 190  	found := make(map[string]struct{})
 191  	for ts != "" {
 192  		var memberStr string
 193  		memberStr, ts, _ = strings.Cut(ts, listDelimiters)
 194  		if memberStr == "" {
 195  			continue
 196  		}
 197  
 198  		m, err := parseMember(memberStr)
 199  		if err != nil {
 200  			return TraceState{}, wrapErr(err)
 201  		}
 202  
 203  		if _, ok := found[m.Key]; ok {
 204  			return TraceState{}, wrapErr(errDuplicate)
 205  		}
 206  		found[m.Key] = struct{}{}
 207  
 208  		members = append(members, m)
 209  		if n := len(members); n > maxListMembers {
 210  			return TraceState{}, wrapErr(errMemberNumber)
 211  		}
 212  	}
 213  
 214  	return TraceState{list: members}, nil
 215  }
 216  
 217  // MarshalJSON marshals the TraceState into JSON.
 218  func (ts TraceState) MarshalJSON() ([]byte, error) {
 219  	return json.Marshal(ts.String())
 220  }
 221  
 222  // String encodes the TraceState into a string compliant with the W3C
 223  // Trace Context specification. The returned string will be invalid if the
 224  // TraceState contains any invalid members.
 225  func (ts TraceState) String() string {
 226  	if len(ts.list) == 0 {
 227  		return ""
 228  	}
 229  	var n int
 230  	n += len(ts.list)     // member delimiters: '='
 231  	n += len(ts.list) - 1 // list delimiters: ','
 232  	for _, mem := range ts.list {
 233  		n += len(mem.Key)
 234  		n += len(mem.Value)
 235  	}
 236  
 237  	var sb strings.Builder
 238  	sb.Grow(n)
 239  	_, _ = sb.WriteString(ts.list[0].Key)
 240  	_ = sb.WriteByte('=')
 241  	_, _ = sb.WriteString(ts.list[0].Value)
 242  	for i := 1; i < len(ts.list); i++ {
 243  		_ = sb.WriteByte(listDelimiters[0])
 244  		_, _ = sb.WriteString(ts.list[i].Key)
 245  		_ = sb.WriteByte('=')
 246  		_, _ = sb.WriteString(ts.list[i].Value)
 247  	}
 248  	return sb.String()
 249  }
 250  
 251  // Get returns the value paired with key from the corresponding TraceState
 252  // list-member if it exists, otherwise an empty string is returned.
 253  func (ts TraceState) Get(key string) string {
 254  	for _, member := range ts.list {
 255  		if member.Key == key {
 256  			return member.Value
 257  		}
 258  	}
 259  
 260  	return ""
 261  }
 262  
 263  // Walk walks all key value pairs in the TraceState by calling f
 264  // Iteration stops if f returns false.
 265  func (ts TraceState) Walk(f func(key, value string) bool) {
 266  	for _, m := range ts.list {
 267  		if !f(m.Key, m.Value) {
 268  			break
 269  		}
 270  	}
 271  }
 272  
 273  // Insert adds a new list-member defined by the key/value pair to the
 274  // TraceState. If a list-member already exists for the given key, that
 275  // list-member's value is updated. The new or updated list-member is always
 276  // moved to the beginning of the TraceState as specified by the W3C Trace
 277  // Context specification.
 278  //
 279  // If key or value are invalid according to the W3C Trace Context
 280  // specification an error is returned with the original TraceState.
 281  //
 282  // If adding a new list-member means the TraceState would have more members
 283  // then is allowed, the new list-member will be inserted and the right-most
 284  // list-member will be dropped in the returned TraceState.
 285  func (ts TraceState) Insert(key, value string) (TraceState, error) {
 286  	m, err := newMember(key, value)
 287  	if err != nil {
 288  		return ts, err
 289  	}
 290  	n := len(ts.list)
 291  	found := n
 292  	for i := range ts.list {
 293  		if ts.list[i].Key == key {
 294  			found = i
 295  		}
 296  	}
 297  	cTS := TraceState{}
 298  	if found == n && n < maxListMembers {
 299  		cTS.list = make([]member, n+1)
 300  	} else {
 301  		cTS.list = make([]member, n)
 302  	}
 303  	cTS.list[0] = m
 304  	// When the number of members exceeds capacity, drop the "right-most".
 305  	copy(cTS.list[1:], ts.list[0:found])
 306  	if found < n {
 307  		copy(cTS.list[1+found:], ts.list[found+1:])
 308  	}
 309  	return cTS, nil
 310  }
 311  
 312  // Delete returns a copy of the TraceState with the list-member identified by
 313  // key removed.
 314  func (ts TraceState) Delete(key string) TraceState {
 315  	members := make([]member, ts.Len())
 316  	copy(members, ts.list)
 317  	for i, member := range ts.list {
 318  		if member.Key == key {
 319  			members = append(members[:i], members[i+1:]...)
 320  			// TraceState should contain no duplicate members.
 321  			break
 322  		}
 323  	}
 324  	return TraceState{list: members}
 325  }
 326  
 327  // Len returns the number of list-members in the TraceState.
 328  func (ts TraceState) Len() int {
 329  	return len(ts.list)
 330  }
 331