extrapolation.go raw

   1  // SPDX-License-Identifier: Unlicense OR MIT
   2  
   3  package fling
   4  
   5  import (
   6  	"math"
   7  	"strconv"
   8  	"strings"
   9  	"time"
  10  )
  11  
// Extrapolation computes a 1-dimensional velocity estimate
// for a set of timestamped points using the least squares
// fit of a 2nd order polynomial. The same method is used
// by Android.
//
// The zero value is ready to use: Sample points the sample buffer at the
// embedded cache on first use, so no constructor is needed. At most
// historySize samples are kept; once full, each new sample overwrites
// the oldest one.
//
// NOTE(review): after the first Sample call, samples aliases cache, so
// copying an Extrapolation by value leaves the copy's samples pointing at
// the original's cache. Presumably callers keep it in place or behind a
// pointer — confirm.
type Extrapolation struct {
	// idx is the position in samples where the next sample will be
	// written; the newest sample is at idx-1, wrapping around.
	idx int
	// samples is a circular buffer of recorded samples backed by cache.
	// lastValue is the value most recently passed to Sample; SampleDelta
	// adds its delta to it.
	samples   []sample
	lastValue float32
	// Pre-allocated cache for samples.
	cache [historySize]sample

	// Scratch storage reused by Estimate for the relative values and
	// times of the samples it selects, so those slices need no
	// allocation.
	values [historySize]float32
	times  [historySize]float32
}
  29  
// sample is one value v passed to Extrapolation.Sample together with its
// timestamp t.
type sample struct {
	t time.Duration
	v float32
}
  34  
// matrix is a dense rows×cols matrix of float32 values. Element
// (row, col) is stored at data[row*cols+col].
//
// The code in this file treats each storage row as one column vector of
// the mathematical matrix (a "column-major" view), which is why col(c)
// returns storage row c.
type matrix struct {
	rows, cols int
	data       []float32
}
  39  
// Estimate is the result of Extrapolation.Estimate. The zero Estimate is
// returned when no estimate could be made.
//
// NOTE(review): Estimate fits (newest value - sample value), so Velocity
// has the opposite sign of the rate of change of the sampled values,
// while Distance (newest - oldest) does not. Confirm callers rely on
// this convention.
type Estimate struct {
	// Velocity is the linear coefficient of the fitted polynomial, i.e.
	// its slope at the time of the newest sample, in value units per
	// second.
	Velocity float32
	// Distance is the value of the newest sample minus the value of the
	// oldest sample used for the fit.
	Distance float32
}
  44  
// coefficients holds the coefficients of a fitted polynomial in
// increasing order of power: c[0] + c[1]*x + c[2]*x^2 for degree 2.
type coefficients [degree + 1]float32
  46  
const (
	// degree is the degree of the polynomial fitted by polyFit.
	degree = 2
	// historySize is the maximum number of samples an Extrapolation keeps.
	historySize = 20
	// maxAge: Estimate stops walking back through the history at the
	// first sample that is this much or more older than the newest one.
	maxAge = 100 * time.Millisecond
	// maxSampleGap: Estimate also stops at the first sample that is this
	// much or more older than the next newer selected sample.
	maxSampleGap = 40 * time.Millisecond
)
  53  
  54  // SampleDelta adds a relative sample to the estimation.
  55  func (e *Extrapolation) SampleDelta(t time.Duration, delta float32) {
  56  	val := delta + e.lastValue
  57  	e.Sample(t, val)
  58  }
  59  
  60  // Sample adds an absolute sample to the estimation.
  61  func (e *Extrapolation) Sample(t time.Duration, val float32) {
  62  	e.lastValue = val
  63  	if e.samples == nil {
  64  		e.samples = e.cache[:0]
  65  	}
  66  	s := sample{
  67  		t: t,
  68  		v: val,
  69  	}
  70  	if e.idx == len(e.samples) && e.idx < cap(e.samples) {
  71  		e.samples = append(e.samples, s)
  72  	} else {
  73  		e.samples[e.idx] = s
  74  	}
  75  	e.idx++
  76  	if e.idx == cap(e.samples) {
  77  		e.idx = 0
  78  	}
  79  }
  80  
  81  // Velocity returns an estimate of the implied velocity and
  82  // distance for the points sampled, or zero if the estimation method
  83  // failed.
  84  func (e *Extrapolation) Estimate() Estimate {
  85  	if len(e.samples) == 0 {
  86  		return Estimate{}
  87  	}
  88  	values := e.values[:0]
  89  	times := e.times[:0]
  90  	first := e.get(0)
  91  	t := first.t
  92  	// Walk backwards collecting samples.
  93  	for i := 0; i < len(e.samples); i++ {
  94  		p := e.get(-i)
  95  		age := first.t - p.t
  96  		if age >= maxAge || t-p.t >= maxSampleGap {
  97  			// If the samples are too old or
  98  			// too much time passed between samples
  99  			// assume they're not part of the fling.
 100  			break
 101  		}
 102  		t = p.t
 103  		values = append(values, first.v-p.v)
 104  		times = append(times, float32((-age).Seconds()))
 105  	}
 106  	coef, ok := polyFit(times, values)
 107  	if !ok {
 108  		return Estimate{}
 109  	}
 110  	dist := values[len(values)-1] - values[0]
 111  	return Estimate{
 112  		Velocity: coef[1],
 113  		Distance: dist,
 114  	}
 115  }
 116  
 117  func (e *Extrapolation) get(i int) sample {
 118  	idx := (e.idx + i - 1 + len(e.samples)) % len(e.samples)
 119  	return e.samples[idx]
 120  }
 121  
 122  // fit computes the least squares polynomial fit for
 123  // the set of points in X, Y. If the fitting fails
 124  // because of contradicting or insufficient data,
 125  // fit returns false.
 126  func polyFit(X, Y []float32) (coefficients, bool) {
 127  	if len(X) != len(Y) {
 128  		panic("X and Y lengths differ")
 129  	}
 130  	if len(X) <= degree {
 131  		// Not enough points to fit a curve.
 132  		return coefficients{}, false
 133  	}
 134  
 135  	// Use a method similar to Android's VelocityTracker.cpp:
 136  	// https://android.googlesource.com/platform/frameworks/base/+/56a2301/libs/androidfw/VelocityTracker.cpp
 137  	// where all weights are 1.
 138  
 139  	// First, expand the X vector to the matrix A in column-major order.
 140  	A := newMatrix(degree+1, len(X))
 141  	for i, x := range X {
 142  		A.set(0, i, 1)
 143  		for j := 1; j < A.rows; j++ {
 144  			A.set(j, i, A.get(j-1, i)*x)
 145  		}
 146  	}
 147  
 148  	Q, Rt, ok := decomposeQR(A)
 149  	if !ok {
 150  		return coefficients{}, false
 151  	}
 152  	// Solve R*B = Qt*Y for B, which is then the polynomial coefficients.
 153  	// Since R is upper triangular, we can proceed from bottom right to
 154  	// upper left.
 155  	// https://en.wikipedia.org/wiki/Non-linear_least_squares
 156  	var B coefficients
 157  	for i := Q.rows - 1; i >= 0; i-- {
 158  		B[i] = dot(Q.col(i), Y)
 159  		for j := Q.rows - 1; j > i; j-- {
 160  			B[i] -= Rt.get(i, j) * B[j]
 161  		}
 162  		B[i] /= Rt.get(i, i)
 163  	}
 164  	return B, true
 165  }
 166  
// decomposeQR computes and returns Q, Rt where Q*transpose(Rt) = A, if
// possible. R is guaranteed to be upper triangular and only the square
// part of Rt is returned.
//
// All three matrices follow the column-major convention used in this
// file: storage row i holds the i'th column vector.
//   - Q has the same shape as A. Each of its columns is orthogonalized
//     against the previous ones and normalized to unit length.
//   - Rt is A.rows×A.rows with Rt.get(i, j) = <q_i, a_j> for j >= i.
//     Entries with j < i are never written and stay zero.
//
// If a column's norm after removing the projections is below 0.000001,
// the data is treated as degenerate and decomposeQR returns
// nil, nil, false.
func decomposeQR(A *matrix) (*matrix, *matrix, bool) {
	// Gram-Schmidt QR decompose A where Q*R = A.
	// https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
	Q := newMatrix(A.rows, A.cols)  // Column-major.
	Rt := newMatrix(A.rows, A.rows) // R transposed, row-major.
	for i := 0; i < Q.rows; i++ {
		// Copy A column.
		for j := 0; j < Q.cols; j++ {
			Q.set(i, j, A.get(i, j))
		}
		// Subtract projections. Note that in the projection
		//
		// proju a = <u, a>/<u, u> u
		//
		// the normalized column e replaces u, where <e, e> = 1:
		//
		// proje a = <e, a>/<e, e> e = <e, a> e
		//
		// d is computed against the partially orthogonalized column i,
		// not the original column of A (the modified Gram-Schmidt
		// ordering).
		for j := 0; j < i; j++ {
			d := dot(Q.col(j), Q.col(i))
			for k := 0; k < Q.cols; k++ {
				Q.set(i, k, Q.get(i, k)-d*Q.get(j, k))
			}
		}
		// Normalize Q columns.
		n := norm(Q.col(i))
		if n < 0.000001 {
			// Degenerate data, no solution.
			return nil, nil, false
		}
		invNorm := 1 / n
		for j := 0; j < Q.cols; j++ {
			Q.set(i, j, Q.get(i, j)*invNorm)
		}
		// Update Rt: entry (i, j) is the component of A's column j along
		// the new unit column q_i.
		for j := i; j < Rt.cols; j++ {
			Rt.set(i, j, dot(Q.col(i), A.col(j)))
		}
	}
	return Q, Rt, true
}
 210  
 211  func norm(V []float32) float32 {
 212  	var n float32
 213  	for _, v := range V {
 214  		n += v * v
 215  	}
 216  	return float32(math.Sqrt(float64(n)))
 217  }
 218  
 219  func dot(V1, V2 []float32) float32 {
 220  	var d float32
 221  	for i, v1 := range V1 {
 222  		d += v1 * V2[i]
 223  	}
 224  	return d
 225  }
 226  
 227  func newMatrix(rows, cols int) *matrix {
 228  	return &matrix{
 229  		rows: rows,
 230  		cols: cols,
 231  		data: make([]float32, rows*cols),
 232  	}
 233  }
 234  
 235  func (m *matrix) set(row, col int, v float32) {
 236  	if row < 0 || row >= m.rows {
 237  		panic("row out of range")
 238  	}
 239  	if col < 0 || col >= m.cols {
 240  		panic("col out of range")
 241  	}
 242  	m.data[row*m.cols+col] = v
 243  }
 244  
 245  func (m *matrix) get(row, col int) float32 {
 246  	if row < 0 || row >= m.rows {
 247  		panic("row out of range")
 248  	}
 249  	if col < 0 || col >= m.cols {
 250  		panic("col out of range")
 251  	}
 252  	return m.data[row*m.cols+col]
 253  }
 254  
 255  func (m *matrix) col(c int) []float32 {
 256  	return m.data[c*m.cols : (c+1)*m.cols]
 257  }
 258  
 259  func (m *matrix) approxEqual(m2 *matrix) bool {
 260  	if m.rows != m2.rows || m.cols != m2.cols {
 261  		return false
 262  	}
 263  	const epsilon = 0.00001
 264  	for row := 0; row < m.rows; row++ {
 265  		for col := 0; col < m.cols; col++ {
 266  			d := m2.get(row, col) - m.get(row, col)
 267  			if d < -epsilon || d > epsilon {
 268  				return false
 269  			}
 270  		}
 271  	}
 272  	return true
 273  }
 274  
 275  func (m *matrix) transpose() *matrix {
 276  	t := &matrix{
 277  		rows: m.cols,
 278  		cols: m.rows,
 279  		data: make([]float32, len(m.data)),
 280  	}
 281  	for i := 0; i < m.rows; i++ {
 282  		for j := 0; j < m.cols; j++ {
 283  			t.set(j, i, m.get(i, j))
 284  		}
 285  	}
 286  	return t
 287  }
 288  
 289  func (m *matrix) mul(m2 *matrix) *matrix {
 290  	if m.rows != m2.cols {
 291  		panic("mismatched matrices")
 292  	}
 293  	mm := &matrix{
 294  		rows: m.rows,
 295  		cols: m2.cols,
 296  		data: make([]float32, m.rows*m2.cols),
 297  	}
 298  	for i := 0; i < mm.rows; i++ {
 299  		for j := 0; j < mm.cols; j++ {
 300  			var v float32
 301  			for k := 0; k < m.rows; k++ {
 302  				v += m.get(k, j) * m2.get(i, k)
 303  			}
 304  			mm.set(i, j, v)
 305  		}
 306  	}
 307  	return mm
 308  }
 309  
 310  func (m *matrix) String() string {
 311  	var b strings.Builder
 312  	for i := 0; i < m.rows; i++ {
 313  		for j := 0; j < m.cols; j++ {
 314  			v := m.get(i, j)
 315  			b.WriteString(strconv.FormatFloat(float64(v), 'g', -1, 32))
 316  			b.WriteString(", ")
 317  		}
 318  		b.WriteString("\n")
 319  	}
 320  	return b.String()
 321  }
 322  
 323  func (c coefficients) approxEqual(c2 coefficients) bool {
 324  	const epsilon = 0.00001
 325  	for i, v := range c {
 326  		d := v - c2[i]
 327  		if d < -epsilon || d > epsilon {
 328  			return false
 329  		}
 330  	}
 331  	return true
 332  }
 333