fma.mx raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package math
   6  
   7  import "math/bits"
   8  
   9  func zero(x uint64) uint64 {
  10  	if x == 0 {
  11  		return 1
  12  	}
  13  	return 0
  14  	// branchless:
  15  	// return ((x>>1 | x&1) - 1) >> 63
  16  }
  17  
  18  func nonzero(x uint64) uint64 {
  19  	if x != 0 {
  20  		return 1
  21  	}
  22  	return 0
  23  	// branchless:
  24  	// return 1 - ((x>>1|x&1)-1)>>63
  25  }
  26  
  27  func shl(u1, u2 uint64, n uint) (r1, r2 uint64) {
  28  	r1 = u1<<n | u2>>(64-n) | u2<<(n-64)
  29  	r2 = u2 << n
  30  	return
  31  }
  32  
  33  func shr(u1, u2 uint64, n uint) (r1, r2 uint64) {
  34  	r2 = u2>>n | u1<<(64-n) | u1>>(n-64)
  35  	r1 = u1 >> n
  36  	return
  37  }
  38  
  39  // shrcompress compresses the bottom n+1 bits of the two-word
  40  // value into a single bit. the result is equal to the value
  41  // shifted to the right by n, except the result's 0th bit is
  42  // set to the bitwise OR of the bottom n+1 bits.
  43  func shrcompress(u1, u2 uint64, n uint) (r1, r2 uint64) {
  44  	// TODO: Performance here is really sensitive to the
  45  	// order/placement of these branches. n == 0 is common
  46  	// enough to be in the fast path. Perhaps more measurement
  47  	// needs to be done to find the optimal order/placement?
  48  	switch {
  49  	case n == 0:
  50  		return u1, u2
  51  	case n == 64:
  52  		return 0, u1 | nonzero(u2)
  53  	case n >= 128:
  54  		return 0, nonzero(u1 | u2)
  55  	case n < 64:
  56  		r1, r2 = shr(u1, u2, n)
  57  		r2 |= nonzero(u2 & (1<<n - 1))
  58  	case n < 128:
  59  		r1, r2 = shr(u1, u2, n)
  60  		r2 |= nonzero(u1&(1<<(n-64)-1) | u2)
  61  	}
  62  	return
  63  }
  64  
  65  func lz(u1, u2 uint64) (l int32) {
  66  	l = int32(bits.LeadingZeros64(u1))
  67  	if l == 64 {
  68  		l += int32(bits.LeadingZeros64(u2))
  69  	}
  70  	return l
  71  }
  72  
  73  // split splits b into sign, biased exponent, and mantissa.
  74  // It adds the implicit 1 bit to the mantissa for normal values,
  75  // and normalizes subnormal values.
  76  func split(b uint64) (sign uint32, exp int32, mantissa uint64) {
  77  	sign = uint32(b >> 63)
  78  	exp = int32(b>>52) & mask
  79  	mantissa = b & fracMask
  80  
  81  	if exp == 0 {
  82  		// Normalize value if subnormal.
  83  		shift := uint(bits.LeadingZeros64(mantissa) - 11)
  84  		mantissa <<= shift
  85  		exp = 1 - int32(shift)
  86  	} else {
  87  		// Add implicit 1 bit
  88  		mantissa |= 1 << 52
  89  	}
  90  	return
  91  }
  92  
// FMA returns x * y + z, computed with only one rounding.
// (That is, FMA returns the fused multiply-add of x, y, and z.)
//
// The general path works entirely in integer arithmetic, in four steps:
//   - Form the exact product x*y as a 128-bit significand (pm1:pm2, high
//     word first).
//   - Shift z's significand to the product's exponent.
//   - Add or subtract the two.
//   - Round the result once to nearest, with ties to even.
func FMA(x, y, z float64) float64 {
	bx, by, bz := Float64bits(x), Float64bits(y), Float64bits(z)

	// Inf or NaN or zero involved. At most one rounding will occur.
	// NOTE(review): the uvinf mask tests assume uvinf is the +Inf bit
	// pattern (all exponent bits set), so they match both Inf and NaN.
	if x == 0.0 || y == 0.0 || bx&uvinf == uvinf || by&uvinf == uvinf {
		return x*y + z
	}
	// Handle z == 0.0 separately.
	// Adding zero usually does not change the original value.
	// However, there is an exception with negative zero. (e.g. (-0) + (+0) = (+0))
	// This applies when x * y is negative and underflows.
	if z == 0.0 {
		return x * y
	}
	// Handle non-finite z separately. Evaluating x*y+z where
	// x and y are finite, but z is infinite, should always result in z.
	// Under the same uvinf assumption as above, a NaN z also matches here
	// and is returned unchanged.
	if bz&uvinf == uvinf {
		return z
	}

	// Inputs are (sub)normal.
	// Split x, y, z into sign, exponent, mantissa.
	// split puts the leading 1 of each significand at bit 52 and
	// normalizes subnormals, so all three share one layout.
	xs, xe, xm := split(bx)
	ys, ye, ym := split(by)
	zs, ze, zm := split(bz)

	// Compute product p = x*y as sign, exponent, two-word mantissa.
	// Start with exponent. "is normal" bit isn't subtracted yet.
	pe := xe + ye - bias + 1

	// pm1:pm2 is the double-word mantissa for the product p.
	// Shift left to leave top bit in product. Effectively
	// shifts the 106-bit product to the left by 21.
	// xm<<10 has its leading 1 at bit 62 and ym<<11 at bit 63. The
	// 128-bit product's leading 1 therefore lands at bit 61 or 62 of
	// pm1. z is placed with its leading 1 at bit 62 of zm1.
	pm1, pm2 := bits.Mul64(xm<<10, ym<<11)
	zm1, zm2 := zm<<10, uint64(0)
	ps := xs ^ ys // product sign

	// normalize to 62nd bit
	// If bit 62 of pm1 is clear, the leading 1 is at bit 61. Shift the
	// product left by one and lower its exponent to compensate.
	is62zero := uint((^pm1 >> 62) & 1)
	pm1, pm2 = shl(pm1, pm2, is62zero)
	pe -= int32(is62zero)

	// Swap addition operands so |p| >= |z|
	// After this pe >= ze, so the alignment shift below is non-negative.
	if pe < ze || pe == ze && pm1 < zm1 {
		ps, pe, pm1, pm2, zs, ze, zm1, zm2 = zs, ze, zm1, zm2, ps, pe, pm1, pm2
	}

	// Special case: if p == -z the result is always +0 since neither operand is zero.
	if ps != zs && pe == ze && pm1 == zm1 && pm2 == zm2 {
		return 0
	}

	// Align significands
	// shrcompress shifts z right by the exponent difference and ORs any
	// shifted-out bits into bit 0 (sticky), so the final rounding still
	// sees them.
	zm1, zm2 = shrcompress(zm1, zm2, uint(pe-ze))

	// Compute resulting significands, normalizing if necessary.
	// Both branches leave the leading 1 of m at bit 62, with the
	// discarded low bits folded into m's bit 0.
	var m, c uint64
	if ps == zs {
		// Adding (pm1:pm2) + (zm1:zm2)
		pm2, c = bits.Add64(pm2, zm2, 0)
		pm1, _ = bits.Add64(pm1, zm1, c)
		// The sum's leading 1 is at bit 62, or at bit 63 if the
		// addition carried. Lower pe by one when there was no carry.
		pe -= int32(^pm1 >> 63)
		// Move the high word into m: shift by 64, or by 65 on carry.
		// The low word folds into the sticky bit.
		pm1, m = shrcompress(pm1, pm2, uint(64+pm1>>63))
	} else {
		// Subtracting (pm1:pm2) - (zm1:zm2)
		// TODO: should we special-case cancellation?
		pm2, c = bits.Sub64(pm2, zm2, 0)
		pm1, _ = bits.Sub64(pm1, zm1, c)
		// Both operands had bit 63 clear, so nz >= 1. Shifting left by
		// nz-1 puts the leading 1 of the difference at bit 62 of m.
		nz := lz(pm1, pm2)
		pe -= nz
		m, pm2 = shl(pm1, pm2, uint(nz-1))
		// Whatever remains in the low word becomes the sticky bit.
		m |= nonzero(pm2)
	}

	// Round and break ties to even
	// Layout of m:
	//   - Bits 62..10 become the 53-bit result significand.
	//   - Bits 9..0 are guard/sticky bits; bit 9 is worth half an ulp.
	// The first check catches exponents beyond the largest finite one,
	// and the largest exponent when rounding up carries into bit 63.
	if pe > 1022+bias || pe == 1022+bias && (m+1<<9)>>63 == 1 {
		// rounded value overflows exponent range
		return Float64frombits(uint64(ps)<<63 | uvinf)
	}
	if pe < 0 {
		// The result is subnormal. Shift m right by -pe, ORing the
		// shifted-out bits into bit 0 (sticky), and use exponent
		// field 0. For n >= 64, Go defines m>>n == 0 and 1<<n-1 as all
		// ones, so m collapses to just its sticky bit.
		n := uint(-pe)
		m = m>>n | nonzero(m&(1<<n-1))
		pe = 0
	}
	// Add half an ulp and drop the 10 low bits. If those bits were an
	// exact tie (== 1<<9), clear the lowest result bit to round to even.
	m = ((m + 1<<9) >> 10) & ^zero((m&(1<<10-1))^1<<9)
	// A zero significand gets exponent field 0, i.e. a signed zero.
	pe &= -int32(nonzero(m))
	// Assemble with + rather than |. The leading 1 at bit 52 of m, or a
	// rounding carry into bit 53, must increment the exponent field.
	return Float64frombits(uint64(ps)<<63 + uint64(pe)<<52 + m)
}
 183