nip44.go raw
1 package encryption
2
3 import (
4 "bytes"
5 "crypto/hmac"
6 "crypto/rand"
7 "crypto/sha256"
8 "encoding/base64"
9 "encoding/binary"
10 "errors"
11 "io"
12 "math"
13
14 "next.orly.dev/pkg/nostr/crypto/ec/secp256k1"
15 "golang.org/x/crypto/chacha20"
16 "golang.org/x/crypto/hkdf"
17 "next.orly.dev/pkg/lol/errorf"
18 )
19
// Plaintext length bounds, in bytes, for NIP-44 version 2 messages.
var (
	// MinPlaintextSize is the shortest accepted plaintext; a 1-byte message
	// is padded to 32 bytes.
	MinPlaintextSize = 1
	// MaxPlaintextSize is the longest accepted plaintext (65535 bytes, i.e.
	// 64 KiB - 1); it is padded to 64 KiB.
	MaxPlaintextSize = 1<<16 - 1
)
24
// EncryptOptions overrides the defaults used by Encrypt. A nil
// *EncryptOptions, or zero-valued fields, keep the defaults.
type EncryptOptions struct {
	// Salt, when non-nil, replaces the freshly generated random salt. It must
	// be exactly 32 bytes long.
	Salt []byte
	// Version selects the payload version. Zero means the default (2), which
	// is also the only version Encrypt accepts.
	Version int
}
29
// Encrypt encrypts plaintext under a NIP-44 conversation key and returns the
// base64.StdEncoding encoding of
//
//	version(1) || salt(32) || chacha20(padded plaintext) || hmac-sha256(32)
//
// options may be nil. A zero options.Version selects version 2, the only
// version accepted. A nil options.Salt means a fresh random 32-byte salt is
// drawn; a supplied salt must be exactly 32 bytes. The ChaCha20 key and nonce
// are derived deterministically from (conversationKey, salt) by MessageKeys,
// so a caller-supplied salt must never be reused with the same conversation
// key.
//
// plaintext must be between 1 and MaxPlaintextSize bytes; pad rejects
// anything else.
func Encrypt(
	conversationKey []byte, plaintext []byte, options *EncryptOptions,
) (ciphertext string, err error) {
	version := 2
	var (
		salt      []byte
		enc       []byte
		nonce     []byte
		auth      []byte
		padded    []byte
		encrypted []byte
		mac       []byte
	)
	if options != nil {
		if options.Version != 0 {
			version = options.Version
		}
		salt = options.Salt
	}
	// Reject an unsupported version before spending entropy on a salt.
	if version != 2 {
		err = errorf.E("unknown version %d", version)
		return
	}
	if salt == nil {
		if salt, err = randomBytes(32); err != nil {
			return
		}
	}
	if len(salt) != 32 {
		err = errorf.E("salt must be 32 bytes")
		return
	}
	if enc, nonce, auth, err = MessageKeys(conversationKey, salt); err != nil {
		return
	}
	if padded, err = pad(plaintext); err != nil {
		return
	}
	if encrypted, err = chacha20_(enc, nonce, padded); err != nil {
		return
	}
	// The MAC covers salt || ciphertext (salt is passed as the AAD).
	if mac, err = sha256Hmac(auth, encrypted, salt); err != nil {
		return
	}
	// Build the payload in a single, exactly sized buffer.
	payload := make([]byte, 0, 1+len(salt)+len(encrypted)+len(mac))
	payload = append(payload, byte(version))
	payload = append(payload, salt...)
	payload = append(payload, encrypted...)
	payload = append(payload, mac...)
	ciphertext = base64.StdEncoding.EncodeToString(payload)
	return
}
81
// Decrypt reverses Encrypt. It decodes a base64 NIP-44 version 2 payload,
// authenticates it with the conversation key, decrypts it, strips the padding
// and returns the original plaintext.
//
// Checks run in this order:
//   - '#' version marker
//   - encoded length (132..87472)
//   - base64 decoding
//   - decoded length (99..65603)
//   - version byte
//   - MAC
//   - padding
//
// On any failure plaintext is "" and err is non-nil.
func Decrypt(conversationKey []byte, ciphertext string) (
	plaintext string, err error,
) {
	var (
		version     int
		decoded     []byte
		cLen        int
		dLen        int
		salt        []byte
		encrypted   []byte
		mac         []byte // MAC carried in the payload
		expectedMac []byte // MAC recomputed from the conversation key
		enc         []byte
		nonce       []byte
		auth        []byte
		padded      []byte
		unpaddedLen int
		unpadded    []byte
	)
	cLen = len(ciphertext)
	// A leading '#' marks a non-base64 encoding that this code cannot read.
	if cLen > 0 && ciphertext[0] == '#' {
		err = errorf.E("unknown version")
		return
	}
	if cLen < 132 || cLen > 87472 {
		err = errorf.E("invalid payload length: %d", cLen)
		return
	}
	if decoded, err = base64.StdEncoding.DecodeString(ciphertext); err != nil {
		err = errorf.E("invalid base64")
		return
	}
	// Validate the decoded size before indexing into it. StdEncoding skips
	// '\r' and '\n', so input that passed the length check above can still
	// decode to far fewer bytes (even zero), and decoded[0] would panic.
	dLen = len(decoded)
	if dLen < 99 || dLen > 65603 {
		err = errorf.E("invalid data length: %d", dLen)
		return
	}
	if version = int(decoded[0]); version != 2 {
		err = errorf.E("unknown version %d", version)
		return
	}
	// Layout: version(1) || salt(32) || encrypted(dLen-65) || mac(32).
	salt, encrypted, mac = decoded[1:33], decoded[33:dLen-32], decoded[dLen-32:]
	if enc, nonce, auth, err = MessageKeys(conversationKey, salt); err != nil {
		return
	}
	if expectedMac, err = sha256Hmac(auth, encrypted, salt); err != nil {
		return
	}
	// Constant-time comparison: bytes.Equal is not constant-time and could
	// leak through timing how much of a forged MAC matched.
	if !hmac.Equal(expectedMac, mac) {
		err = errorf.E("invalid hmac")
		return
	}
	if padded, err = chacha20_(enc, nonce, encrypted); err != nil {
		return
	}
	// Widen to int before doing arithmetic. In uint16, unpaddedLen+2 wraps to
	// 0 or 1 for 65534- and 65535-byte plaintexts, and the slice expression
	// below would panic on a valid maximum-size message.
	unpaddedLen = int(binary.BigEndian.Uint16(padded[0:2]))
	if unpaddedLen < MinPlaintextSize || unpaddedLen > MaxPlaintextSize ||
		len(padded) != 2+calcPadding(unpaddedLen) {
		err = errorf.E("invalid padding")
		return
	}
	unpadded = padded[2 : 2+unpaddedLen]
	if len(unpadded) == 0 || len(unpadded) != unpaddedLen {
		err = errorf.E("invalid padding")
		return
	}
	plaintext = string(unpadded)
	return
}

// Decrypt compares MACs with hmac.Equal, so nothing else in this file may
// reference the bytes package. This keeps the existing import in use; remove
// this line together with the import once nothing else needs it.
var _ = bytes.Equal
150
// GenerateConversationKey derives the NIP-44 v2 conversation key shared by the
// owner of sendPrivkey and the owner of recvPubkey.
//
// sendPrivkey must be a 32-byte secp256k1 secret key that is non-zero and in
// range. recvPubkey may be either a 32-byte x-only key, which is lifted by
// prefixing the even-y compressed format byte, or a 33-byte compressed public
// key.
//
// The result is HKDF-Extract(SHA-256, salt="nip44-v2") applied to the ECDH
// shared secret from secp256k1.GenerateSharedSecret. Error messages never
// include the secret key bytes.
func GenerateConversationKey(
	sendPrivkey []byte, recvPubkey []byte,
) (conversationKey []byte, err error) {
	// NIP-44 secret keys are exactly 32 bytes. Reject other lengths rather
	// than depending on how SetByteSlice treats short or long input.
	if len(sendPrivkey) != 32 {
		err = errorf.E(
			"invalid private key length: %d (expected 32 bytes)",
			len(sendPrivkey),
		)
		return
	}

	// Parse the private key. Never echo sendPrivkey in errors: it is secret.
	var privKey secp256k1.SecretKey
	if overflow := privKey.Key.SetByteSlice(sendPrivkey); overflow {
		err = errorf.E("invalid private key: out of range for secp256k1")
		return
	}
	if privKey.Key.IsZero() {
		err = errorf.E("invalid private key: zero")
		return
	}

	// Normalize the public key to the 33-byte compressed encoding.
	var pubKeyBytes []byte
	switch len(recvPubkey) {
	case 32:
		// Nostr-style x-only key: assume the even-y point.
		pubKeyBytes = make([]byte, 33)
		pubKeyBytes[0] = secp256k1.PubKeyFormatCompressedEven
		copy(pubKeyBytes[1:], recvPubkey)
	case 33:
		// Already compressed.
		pubKeyBytes = recvPubkey
	default:
		err = errorf.E(
			"invalid public key length: %d (expected 32 or 33 bytes)",
			len(recvPubkey),
		)
		return
	}

	pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
	if err != nil {
		return
	}

	// ECDH shared secret (x-coordinate only, 32 bytes).
	shared := secp256k1.GenerateSharedSecret(&privKey, pubKey)

	// HKDF-Extract with the fixed NIP-44 v2 salt.
	conversationKey = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
	return
}
205
206 func chacha20_(key []byte, nonce []byte, message []byte) ([]byte, error) {
207 var (
208 cipher *chacha20.Cipher
209 dst = make([]byte, len(message))
210 err error
211 )
212 if cipher, err = chacha20.NewUnauthenticatedCipher(key, nonce); err != nil {
213 return nil, err
214 }
215 cipher.XORKeyStream(dst, message)
216 return dst, nil
217 }
218
// randomBytes returns n bytes read from crypto/rand, or the read error.
func randomBytes(n int) ([]byte, error) {
	out := make([]byte, n)
	_, err := rand.Read(out)
	if err != nil {
		return nil, err
	}
	return out, nil
}
226
// sha256Hmac returns HMAC-SHA256 under key over aad || ciphertext. The aad
// (the message salt at both call sites) must be exactly 32 bytes; otherwise
// an error is returned and no MAC is computed.
func sha256Hmac(key []byte, ciphertext []byte, aad []byte) ([]byte, error) {
	if len(aad) != 32 {
		return nil, errors.New("aad data must be 32 bytes")
	}
	mac := hmac.New(sha256.New, key)
	// Feed the AAD first, then the ciphertext; the order is part of the MAC.
	// hash.Hash.Write never returns an error, so its result is not checked.
	for _, part := range [][]byte{aad, ciphertext} {
		mac.Write(part)
	}
	return mac.Sum(nil), nil
}
236
237 func MessageKeys(conversationKey []byte, salt []byte) (
238 []byte, []byte, []byte, error,
239 ) {
240 var (
241 r io.Reader
242 enc []byte = make([]byte, 32)
243 nonce []byte = make([]byte, 12)
244 auth []byte = make([]byte, 32)
245 err error
246 )
247 if len(conversationKey) != 32 {
248 return nil, nil, nil, errors.New("conversation key must be 32 bytes")
249 }
250 if len(salt) != 32 {
251 return nil, nil, nil, errors.New("salt must be 32 bytes")
252 }
253 r = hkdf.Expand(sha256.New, conversationKey, salt)
254 if _, err = io.ReadFull(r, enc); err != nil {
255 return nil, nil, nil, err
256 }
257 if _, err = io.ReadFull(r, nonce); err != nil {
258 return nil, nil, nil, err
259 }
260 if _, err = io.ReadFull(r, auth); err != nil {
261 return nil, nil, nil, err
262 }
263 return enc, nonce, auth, nil
264 }
265
266 func pad(s []byte) ([]byte, error) {
267 var (
268 sb []byte
269 sbLen int
270 padding int
271 result []byte
272 )
273 sb = s
274 sbLen = len(sb)
275 if sbLen < 1 || sbLen > MaxPlaintextSize {
276 return nil, errors.New("plaintext should be between 1b and 64kB")
277 }
278 padding = calcPadding(sbLen)
279 result = make([]byte, 2)
280 binary.BigEndian.PutUint16(result, uint16(sbLen))
281 result = append(result, sb...)
282 result = append(result, make([]byte, padding-sbLen)...)
283 return result, nil
284 }
285
286 func calcPadding(sLen int) int {
287 var (
288 nextPower int
289 chunk int
290 )
291 if sLen <= 32 {
292 return 32
293 }
294 nextPower = 1 << int(math.Floor(math.Log2(float64(sLen-1)))+1)
295 chunk = int(math.Max(32, float64(nextPower/8)))
296 return chunk * int(math.Floor(float64((sLen-1)/chunk))+1)
297 }
298