taproot.go raw
1 // Package taproot provides a collection of tools for encoding bitcoin taproot
2 // addresses.
3 package taproot
4
5 import (
6 "bytes"
7 "errors"
8 "fmt"
9
10 "next.orly.dev/pkg/nostr/crypto/ec/bech32"
11 "next.orly.dev/pkg/nostr/crypto/ec/chaincfg"
12 "next.orly.dev/pkg/nostr/utils"
13 "next.orly.dev/pkg/lol/chk"
14 )
15
type (
	// AddressSegWit is the base address type for all SegWit addresses. It
	// keeps the human-readable part, the witness version and the witness
	// program that together make up the encoded address string.
	AddressSegWit struct {
		hrp            []byte
		witnessVersion byte
		witnessProgram []byte
	}

	// AddressTaproot is an Address for a pay-to-taproot (P2TR) output. It
	// embeds AddressSegWit, so the SegWit fields and methods are promoted to
	// it. See BIP 341 for further details.
	AddressTaproot struct {
		AddressSegWit
	}
)
28
29 // NewAddressTaproot returns a new AddressTaproot.
30 func NewAddressTaproot(
31 witnessProg []byte,
32 net *chaincfg.Params,
33 ) (*AddressTaproot, error) {
34
35 return newAddressTaproot(net.Bech32HRPSegwit, witnessProg)
36 }
37
38 // newAddressWitnessScriptHash is an internal helper function to create an
39 // AddressWitnessScriptHash with a known human-readable part, rather than
40 // looking it up through its parameters.
41 func newAddressTaproot(hrp []byte, witnessProg []byte) (
42 *AddressTaproot, error,
43 ) {
44 // Check for valid program length for witness version 1, which is 32
45 // for P2TR.
46 if len(witnessProg) != 32 {
47 return nil, errors.New(
48 "witness program must be 32 bytes for " +
49 "p2tr",
50 )
51 }
52 addr := &AddressTaproot{
53 AddressSegWit{
54 hrp: bytes.ToLower(hrp),
55 witnessVersion: 0x01,
56 witnessProgram: witnessProg,
57 },
58 }
59 return addr, nil
60 }
61
62 // decodeSegWitAddress parses a bech32 encoded segwit address string and
63 // returns the witness version and witness program byte representation.
64 func decodeSegWitAddress(address []byte) (byte, []byte, error) {
65 // Decode the bech32 encoded address.
66 _, data, bech32version, err := bech32.DecodeGeneric(address)
67 if chk.E(err) {
68 return 0, nil, err
69 }
70 // The first byte of the decoded address is the witness version, it must
71 // exist.
72 if len(data) < 1 {
73 return 0, nil, fmt.Errorf("no witness version")
74 }
75 // ...and be <= 16.
76 version := data[0]
77 if version > 16 {
78 return 0, nil, fmt.Errorf("invalid witness version: %v", version)
79 }
80 // The remaining characters of the address returned are grouped into
81 // words of 5 bits. In order to restore the original witness program
82 // bytes, we'll need to regroup into 8 bit words.
83 regrouped, err := bech32.ConvertBits(data[1:], 5, 8, false)
84 if chk.E(err) {
85 return 0, nil, err
86 }
87 // The regrouped data must be between 2 and 40 bytes.
88 if len(regrouped) < 2 || len(regrouped) > 40 {
89 return 0, nil, fmt.Errorf("invalid data length")
90 }
91 // For witness version 0, address MUST be exactly 20 or 32 bytes.
92 if version == 0 && len(regrouped) != 20 && len(regrouped) != 32 {
93 return 0, nil, fmt.Errorf(
94 "invalid data length for witness "+
95 "version 0: %v", len(regrouped),
96 )
97 }
98 // For witness version 0, the bech32 encoding must be used.
99 if version == 0 && bech32version != bech32.Version0 {
100 return 0, nil, fmt.Errorf(
101 "invalid checksum expected bech32 " +
102 "encoding for address with witness version 0",
103 )
104 }
105 // For witness version 1, the bech32m encoding must be used.
106 if version == 1 && bech32version != bech32.VersionM {
107 return 0, nil, fmt.Errorf(
108 "invalid checksum expected bech32m " +
109 "encoding for address with witness version 1",
110 )
111 }
112 return version, regrouped, nil
113 }
114
115 // encodeSegWitAddress creates a bech32 (or bech32m for SegWit v1) encoded
116 // address string representation from witness version and witness program.
117 func encodeSegWitAddress(
118 hrp []byte, witnessVersion byte,
119 witnessProgram []byte,
120 ) ([]byte, error) {
121 // Group the address bytes into 5 bit groups, as this is what is used to
122 // encode each character in the address string.
123 converted, err := bech32.ConvertBits(witnessProgram, 8, 5, true)
124 if chk.E(err) {
125 return nil, err
126 }
127 // Concatenate the witness version and program, and encode the resulting
128 // bytes using bech32 encoding.
129 combined := make([]byte, len(converted)+1)
130 combined[0] = witnessVersion
131 copy(combined[1:], converted)
132 var bech []byte
133 switch witnessVersion {
134 case 0:
135 bech, err = bech32.Encode(hrp, combined)
136
137 case 1:
138 bech, err = bech32.EncodeM(hrp, combined)
139
140 default:
141 return nil, fmt.Errorf(
142 "unsupported witness version %d",
143 witnessVersion,
144 )
145 }
146 if chk.E(err) {
147 return nil, err
148 }
149 // Check validity by decoding the created address.
150 version, program, err := decodeSegWitAddress(bech)
151 if chk.E(err) {
152 return nil, fmt.Errorf("invalid segwit address: %v", err)
153 }
154 if version != witnessVersion || !utils.FastEqual(program, witnessProgram) {
155 return nil, fmt.Errorf("invalid segwit address")
156 }
157 return bech, nil
158 }
159
160 // EncodeAddress returns the bech32 (or bech32m for SegWit v1) string encoding
161 // of an AddressSegWit.
162 //
163 // NOTE: This method is part of the Address interface.
164 func (a *AddressSegWit) EncodeAddress() []byte {
165 str, err := encodeSegWitAddress(
166 a.hrp, a.witnessVersion, a.witnessProgram[:],
167 )
168 if chk.E(err) {
169 return nil
170 }
171 return str
172 }
173