state.go raw

   1  // Copyright 2018 The gVisor Authors.
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //     http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  // Package state provides functionality related to saving and loading object
  16  // graphs.  For most types, it provides a set of default saving / loading logic
  17  // that will be invoked automatically if custom logic is not defined.
  18  //
  19  //	Kind             Support
  20  //	----             -------
  21  //	Bool             default
  22  //	Int              default
  23  //	Int8             default
  24  //	Int16            default
  25  //	Int32            default
  26  //	Int64            default
  27  //	Uint             default
  28  //	Uint8            default
  29  //	Uint16           default
  30  //	Uint32           default
  31  //	Uint64           default
  32  //	Float32          default
  33  //	Float64          default
  34  //	Complex64        default
  35  //	Complex128       default
  36  //	Array            default
  37  //	Chan             custom
  38  //	Func             custom
  39  //	Interface        default
  40  //	Map              default
  41  //	Ptr              default
  42  //	Slice            default
  43  //	String           default
  44  //	Struct           custom (*) Unless zero-sized.
  45  //	UnsafePointer    custom
  46  //
  47  // See README.md for an overview of how encoding and decoding works.
  48  package state
  49  
  50  import (
  51  	"context"
  52  	"fmt"
  53  	"io"
  54  	"reflect"
  55  	"runtime"
  56  
  57  	"gvisor.dev/gvisor/pkg/state/wire"
  58  )
  59  
  60  // objectID is a unique identifier assigned to each object to be serialized.
  61  // Each instance of an object is considered separately, i.e. if there are two
  62  // objects of the same type in the object graph being serialized, they'll be
  63  // assigned unique objectIDs.
  64  type objectID uint32
  65  
  66  // typeID is the identifier for a type. Types are serialized and tracked
  67  // alongside objects in order to avoid the overhead of encoding field names in
  68  // all objects.
  69  type typeID uint32
  70  
  71  // ErrState is returned when an error is encountered during encode/decode.
  72  type ErrState struct {
  73  	// err is the underlying error.
  74  	err error
  75  
  76  	// trace is the stack trace.
  77  	trace string
  78  }
  79  
  80  // Error returns a sensible description of the state error.
  81  func (e *ErrState) Error() string {
  82  	return fmt.Sprintf("%v:\n%s", e.err, e.trace)
  83  }
  84  
  85  // Unwrap implements standard unwrapping.
  86  func (e *ErrState) Unwrap() error {
  87  	return e.err
  88  }
  89  
  90  // Save saves the given object state.
  91  func Save(ctx context.Context, w io.Writer, rootPtr any) (Stats, error) {
  92  	// Create the encoding state.
  93  	es := encodeState{
  94  		ctx:            ctx,
  95  		w:              wire.Writer{Writer: w},
  96  		types:          makeTypeEncodeDatabase(),
  97  		zeroValues:     make(map[reflect.Type]*objectEncodeState),
  98  		pending:        make(map[objectID]*objectEncodeState),
  99  		encodedStructs: make(map[reflect.Value]*wire.Struct),
 100  	}
 101  
 102  	// Perform the encoding.
 103  	err := safely(func() {
 104  		es.Save(reflect.ValueOf(rootPtr).Elem())
 105  	})
 106  	return es.stats, err
 107  }
 108  
 109  // Load loads a checkpoint.
 110  func Load(ctx context.Context, r io.Reader, rootPtr any) (Stats, error) {
 111  	// Create the decoding state.
 112  	ds := decodeState{
 113  		ctx:      ctx,
 114  		r:        wire.Reader{Reader: r},
 115  		types:    makeTypeDecodeDatabase(),
 116  		deferred: make(map[objectID]wire.Object),
 117  	}
 118  
 119  	// Attempt our decode.
 120  	err := safely(func() {
 121  		ds.Load(reflect.ValueOf(rootPtr).Elem())
 122  	})
 123  	return ds.stats, err
 124  }
 125  
// Sink is the handle passed to SaverLoader.StateSave implementations. They
// use it to save an object's fields.
type Sink struct {
	// internal does the actual encoding; Sink's methods delegate to it.
	internal objectEncoder
}
 130  
 131  // Save adds the given object to the map.
 132  //
 133  // You should pass always pointers to the object you are saving. For example:
 134  //
 135  //	type X struct {
 136  //		A int
 137  //		B *int
 138  //	}
 139  //
 140  //	func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
 141  //		return state.TypeInfo{
 142  //			Name:   "pkg.X",
 143  //			Fields: []string{
 144  //				"A",
 145  //				"B",
 146  //			},
 147  //		}
 148  //	}
 149  //
 150  //	func (x *X) StateSave(m Sink) {
 151  //		m.Save(0, &x.A) // Field is A.
 152  //		m.Save(1, &x.B) // Field is B.
 153  //	}
 154  //
 155  //	func (x *X) StateLoad(m Source) {
 156  //		m.Load(0, &x.A) // Field is A.
 157  //		m.Load(1, &x.B) // Field is B.
 158  //	}
 159  func (s Sink) Save(slot int, objPtr any) {
 160  	s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
 161  }
 162  
 163  // SaveValue adds the given object value to the map.
 164  //
 165  // This should be used for values where pointers are not available, or casts
 166  // are required during Save/Load.
 167  //
 168  // For example, if we want to cast external package type P.Foo to int64:
 169  //
 170  //	func (x *X) StateSave(m Sink) {
 171  //		m.SaveValue(0, "A", int64(x.A))
 172  //	}
 173  //
 174  //	func (x *X) StateLoad(m Source) {
 175  //		m.LoadValue(0, new(int64), func(x any) {
 176  //			x.A = P.Foo(x.(int64))
 177  //		})
 178  //	}
 179  func (s Sink) SaveValue(slot int, obj any) {
 180  	s.internal.save(slot, reflect.ValueOf(obj))
 181  }
 182  
// Context returns the context that was passed to the top-level Save call.
func (s Sink) Context() context.Context {
	return s.internal.es.ctx
}
 187  
 188  // Type is an interface that must be implemented by Struct objects. This allows
 189  // these objects to be serialized while minimizing runtime reflection required.
 190  //
 191  // All these methods can be automatically generated by the go_statify tool.
 192  type Type interface {
 193  	// StateTypeName returns the type's name.
 194  	//
 195  	// This is used for matching type information during encoding and
 196  	// decoding, as well as dynamic interface dispatch. This should be
 197  	// globally unique.
 198  	StateTypeName() string
 199  
 200  	// StateFields returns information about the type.
 201  	//
 202  	// Fields is the set of fields for the object. Calls to Sink.Save and
 203  	// Source.Load must be made in-order with respect to these fields.
 204  	//
 205  	// This will be called at most once per serialization.
 206  	StateFields() []string
 207  }
 208  
 209  // SaverLoader must be implemented by struct types.
 210  type SaverLoader interface {
 211  	// StateSave saves the state of the object to the given Map.
 212  	StateSave(Sink)
 213  
 214  	// StateLoad loads the state of the object.
 215  	StateLoad(context.Context, Source)
 216  }
 217  
// Source is the handle passed to SaverLoader.StateLoad implementations. They
// use it to load an object's fields.
type Source struct {
	// internal does the actual decoding; Source's methods delegate to it.
	internal objectDecoder
}
 222  
 223  // Load loads the given object passed as a pointer..
 224  //
 225  // See Sink.Save for an example.
 226  func (s Source) Load(slot int, objPtr any) {
 227  	s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
 228  }
 229  
 230  // LoadWait loads the given objects from the map, and marks it as requiring all
 231  // AfterLoad executions to complete prior to running this object's AfterLoad.
 232  //
 233  // See Sink.Save for an example.
 234  func (s Source) LoadWait(slot int, objPtr any) {
 235  	s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
 236  }
 237  
 238  // LoadValue loads the given object value from the map.
 239  //
 240  // See Sink.SaveValue for an example.
 241  func (s Source) LoadValue(slot int, objPtr any, fn func(any)) {
 242  	o := reflect.ValueOf(objPtr)
 243  	s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
 244  }
 245  
// AfterLoad schedules fn to run once all objects have been allocated and both
// their automated and customized load logic have executed. fn will not run
// until the AfterLoad logic (if any) of all of the current object's
// dependencies has run.
func (s Source) AfterLoad(fn func()) {
	s.internal.afterLoad(fn)
}
 253  
// Context returns the context that was passed to the top-level Load call.
func (s Source) Context() context.Context {
	return s.internal.ds.ctx
}
 258  
 259  // IsZeroValue checks if the given value is the zero value.
 260  //
 261  // This function is used by the stateify tool.
 262  func IsZeroValue(val any) bool {
 263  	return val == nil || reflect.ValueOf(val).Elem().IsZero()
 264  }
 265  
 266  // Failf is a wrapper around panic that should be used to generate errors that
 267  // can be caught during saving and loading.
 268  func Failf(fmtStr string, v ...any) {
 269  	panic(fmt.Errorf(fmtStr, v...))
 270  }
 271  
 272  // safely executes the given function, catching a panic and unpacking as an
 273  // error.
 274  //
 275  // The error flow through the state package uses panic and recover. There are
 276  // two important reasons for this:
 277  //
 278  // 1) Many of the reflection methods will already panic with invalid data or
 279  // violated assumptions. We would want to recover anyways here.
 280  //
 281  // 2) It allows us to eliminate boilerplate within Save() and Load() functions.
 282  // In nearly all cases, when the low-level serialization functions fail, you
 283  // will want the checkpoint to fail anyways. Plumbing errors through every
 284  // method doesn't add a lot of value. If there are specific error conditions
 285  // that you'd like to handle, you should add appropriate functionality to
 286  // objects themselves prior to calling Save() and Load().
 287  func safely(fn func()) (err error) {
 288  	defer func() {
 289  		if r := recover(); r != nil {
 290  			if es, ok := r.(*ErrState); ok {
 291  				err = es // Propagate.
 292  				return
 293  			}
 294  
 295  			// Build a new state error.
 296  			es := new(ErrState)
 297  			if e, ok := r.(error); ok {
 298  				es.err = e
 299  			} else {
 300  				es.err = fmt.Errorf("%v", r)
 301  			}
 302  
 303  			// Make a stack. We don't know how big it will be ahead
 304  			// of time, but want to make sure we get the whole
 305  			// thing. So we just do a stupid brute force approach.
 306  			var stack []byte
 307  			for sz := 1024; ; sz *= 2 {
 308  				stack = make([]byte, sz)
 309  				n := runtime.Stack(stack, false)
 310  				if n < sz {
 311  					es.trace = string(stack[:n])
 312  					break
 313  				}
 314  			}
 315  
 316  			// Set the error.
 317  			err = es
 318  		}
 319  	}()
 320  
 321  	// Execute the function.
 322  	fn()
 323  	return nil
 324  }
 325