taproot.go raw

   1  // Package taproot provides a collection of tools for encoding bitcoin taproot
   2  // addresses.
   3  package taproot
   4  
   5  import (
   6  	"bytes"
   7  	"errors"
   8  	"fmt"
   9  
  10  	"next.orly.dev/pkg/nostr/crypto/ec/bech32"
  11  	"next.orly.dev/pkg/nostr/crypto/ec/chaincfg"
  12  	"next.orly.dev/pkg/nostr/utils"
  13  	"next.orly.dev/pkg/lol/chk"
  14  )
  15  
  16  // AddressSegWit is the base address type for all SegWit addresses.
  17  type AddressSegWit struct {
  18  	hrp            []byte
  19  	witnessVersion byte
  20  	witnessProgram []byte
  21  }
  22  
  23  // AddressTaproot is an Address for a pay-to-taproot (P2TR) output. See BIP 341
  24  // for further details.
  25  type AddressTaproot struct {
  26  	AddressSegWit
  27  }
  28  
  29  // NewAddressTaproot returns a new AddressTaproot.
  30  func NewAddressTaproot(
  31  	witnessProg []byte,
  32  	net *chaincfg.Params,
  33  ) (*AddressTaproot, error) {
  34  
  35  	return newAddressTaproot(net.Bech32HRPSegwit, witnessProg)
  36  }
  37  
  38  // newAddressWitnessScriptHash is an internal helper function to create an
  39  // AddressWitnessScriptHash with a known human-readable part, rather than
  40  // looking it up through its parameters.
  41  func newAddressTaproot(hrp []byte, witnessProg []byte) (
  42  	*AddressTaproot, error,
  43  ) {
  44  	// Check for valid program length for witness version 1, which is 32
  45  	// for P2TR.
  46  	if len(witnessProg) != 32 {
  47  		return nil, errors.New(
  48  			"witness program must be 32 bytes for " +
  49  				"p2tr",
  50  		)
  51  	}
  52  	addr := &AddressTaproot{
  53  		AddressSegWit{
  54  			hrp:            bytes.ToLower(hrp),
  55  			witnessVersion: 0x01,
  56  			witnessProgram: witnessProg,
  57  		},
  58  	}
  59  	return addr, nil
  60  }
  61  
  62  // decodeSegWitAddress parses a bech32 encoded segwit address string and
  63  // returns the witness version and witness program byte representation.
  64  func decodeSegWitAddress(address []byte) (byte, []byte, error) {
  65  	// Decode the bech32 encoded address.
  66  	_, data, bech32version, err := bech32.DecodeGeneric(address)
  67  	if chk.E(err) {
  68  		return 0, nil, err
  69  	}
  70  	// The first byte of the decoded address is the witness version, it must
  71  	// exist.
  72  	if len(data) < 1 {
  73  		return 0, nil, fmt.Errorf("no witness version")
  74  	}
  75  	// ...and be <= 16.
  76  	version := data[0]
  77  	if version > 16 {
  78  		return 0, nil, fmt.Errorf("invalid witness version: %v", version)
  79  	}
  80  	// The remaining characters of the address returned are grouped into
  81  	// words of 5 bits. In order to restore the original witness program
  82  	// bytes, we'll need to regroup into 8 bit words.
  83  	regrouped, err := bech32.ConvertBits(data[1:], 5, 8, false)
  84  	if chk.E(err) {
  85  		return 0, nil, err
  86  	}
  87  	// The regrouped data must be between 2 and 40 bytes.
  88  	if len(regrouped) < 2 || len(regrouped) > 40 {
  89  		return 0, nil, fmt.Errorf("invalid data length")
  90  	}
  91  	// For witness version 0, address MUST be exactly 20 or 32 bytes.
  92  	if version == 0 && len(regrouped) != 20 && len(regrouped) != 32 {
  93  		return 0, nil, fmt.Errorf(
  94  			"invalid data length for witness "+
  95  				"version 0: %v", len(regrouped),
  96  		)
  97  	}
  98  	// For witness version 0, the bech32 encoding must be used.
  99  	if version == 0 && bech32version != bech32.Version0 {
 100  		return 0, nil, fmt.Errorf(
 101  			"invalid checksum expected bech32 " +
 102  				"encoding for address with witness version 0",
 103  		)
 104  	}
 105  	// For witness version 1, the bech32m encoding must be used.
 106  	if version == 1 && bech32version != bech32.VersionM {
 107  		return 0, nil, fmt.Errorf(
 108  			"invalid checksum expected bech32m " +
 109  				"encoding for address with witness version 1",
 110  		)
 111  	}
 112  	return version, regrouped, nil
 113  }
 114  
 115  // encodeSegWitAddress creates a bech32 (or bech32m for SegWit v1) encoded
 116  // address string representation from witness version and witness program.
 117  func encodeSegWitAddress(
 118  	hrp []byte, witnessVersion byte,
 119  	witnessProgram []byte,
 120  ) ([]byte, error) {
 121  	// Group the address bytes into 5 bit groups, as this is what is used to
 122  	// encode each character in the address string.
 123  	converted, err := bech32.ConvertBits(witnessProgram, 8, 5, true)
 124  	if chk.E(err) {
 125  		return nil, err
 126  	}
 127  	// Concatenate the witness version and program, and encode the resulting
 128  	// bytes using bech32 encoding.
 129  	combined := make([]byte, len(converted)+1)
 130  	combined[0] = witnessVersion
 131  	copy(combined[1:], converted)
 132  	var bech []byte
 133  	switch witnessVersion {
 134  	case 0:
 135  		bech, err = bech32.Encode(hrp, combined)
 136  
 137  	case 1:
 138  		bech, err = bech32.EncodeM(hrp, combined)
 139  
 140  	default:
 141  		return nil, fmt.Errorf(
 142  			"unsupported witness version %d",
 143  			witnessVersion,
 144  		)
 145  	}
 146  	if chk.E(err) {
 147  		return nil, err
 148  	}
 149  	// Check validity by decoding the created address.
 150  	version, program, err := decodeSegWitAddress(bech)
 151  	if chk.E(err) {
 152  		return nil, fmt.Errorf("invalid segwit address: %v", err)
 153  	}
 154  	if version != witnessVersion || !utils.FastEqual(program, witnessProgram) {
 155  		return nil, fmt.Errorf("invalid segwit address")
 156  	}
 157  	return bech, nil
 158  }
 159  
 160  // EncodeAddress returns the bech32 (or bech32m for SegWit v1) string encoding
 161  // of an AddressSegWit.
 162  //
 163  // NOTE: This method is part of the Address interface.
 164  func (a *AddressSegWit) EncodeAddress() []byte {
 165  	str, err := encodeSegWitAddress(
 166  		a.hrp, a.witnessVersion, a.witnessProgram[:],
 167  	)
 168  	if chk.E(err) {
 169  		return nil
 170  	}
 171  	return str
 172  }
 173