nip44.go raw

   1  package encryption
   2  
   3  import (
   4  	"bytes"
   5  	"crypto/hmac"
   6  	"crypto/rand"
   7  	"crypto/sha256"
   8  	"encoding/base64"
   9  	"encoding/binary"
  10  	"errors"
  11  	"io"
  12  	"math"
  13  
  14  	"next.orly.dev/pkg/nostr/crypto/ec/secp256k1"
  15  	"golang.org/x/crypto/chacha20"
  16  	"golang.org/x/crypto/hkdf"
  17  	"next.orly.dev/pkg/lol/errorf"
  18  )
  19  
  20  var (
  21  	MinPlaintextSize = 0x0001 // 1b msg => padded to 32b
  22  	MaxPlaintextSize = 0xffff // 65535 (64kb-1) => padded to 64kb
  23  )
  24  
  25  type EncryptOptions struct {
  26  	Salt    []byte
  27  	Version int
  28  }
  29  
  30  func Encrypt(
  31  	conversationKey []byte, plaintext []byte, options *EncryptOptions,
  32  ) (ciphertext string, err error) {
  33  	var (
  34  		version   int = 2
  35  		salt      []byte
  36  		enc       []byte
  37  		nonce     []byte
  38  		auth      []byte
  39  		padded    []byte
  40  		encrypted []byte
  41  		hmac_     []byte
  42  		concat    []byte
  43  	)
  44  	if options != nil && options.Version != 0 {
  45  		version = options.Version
  46  	}
  47  	if options != nil && options.Salt != nil {
  48  		salt = options.Salt
  49  	} else {
  50  		if salt, err = randomBytes(32); err != nil {
  51  			return
  52  		}
  53  	}
  54  	if version != 2 {
  55  		err = errorf.E("unknown version %d", version)
  56  		return
  57  	}
  58  	if len(salt) != 32 {
  59  		err = errorf.E("salt must be 32 bytes")
  60  		return
  61  	}
  62  	if enc, nonce, auth, err = MessageKeys(conversationKey, salt); err != nil {
  63  		return
  64  	}
  65  	if padded, err = pad(plaintext); err != nil {
  66  		return
  67  	}
  68  	if encrypted, err = chacha20_(enc, nonce, padded); err != nil {
  69  		return
  70  	}
  71  	if hmac_, err = sha256Hmac(auth, encrypted, salt); err != nil {
  72  		return
  73  	}
  74  	concat = append(concat, []byte{byte(version)}...)
  75  	concat = append(concat, salt...)
  76  	concat = append(concat, encrypted...)
  77  	concat = append(concat, hmac_...)
  78  	ciphertext = base64.StdEncoding.EncodeToString(concat)
  79  	return
  80  }
  81  
  82  func Decrypt(conversationKey []byte, ciphertext string) (
  83  	plaintext string, err error,
  84  ) {
  85  	var (
  86  		version     int = 2
  87  		decoded     []byte
  88  		cLen        int
  89  		dLen        int
  90  		salt        []byte
  91  		ciphertext_ []byte
  92  		hmac        []byte
  93  		hmac_       []byte
  94  		enc         []byte
  95  		nonce       []byte
  96  		auth        []byte
  97  		padded      []byte
  98  		unpaddedLen uint16
  99  		unpadded    []byte
 100  	)
 101  	cLen = len(ciphertext)
 102  	if cLen < 132 || cLen > 87472 {
 103  		err = errorf.E("invalid payload length: %d", cLen)
 104  		return
 105  	}
 106  	if ciphertext[0:1] == "#" {
 107  		err = errorf.E("unknown version")
 108  		return
 109  	}
 110  	if decoded, err = base64.StdEncoding.DecodeString(ciphertext); err != nil {
 111  		err = errorf.E("invalid base64")
 112  		return
 113  	}
 114  	if version = int(decoded[0]); version != 2 {
 115  		err = errorf.E("unknown version %d", version)
 116  		return
 117  	}
 118  	dLen = len(decoded)
 119  	if dLen < 99 || dLen > 65603 {
 120  		err = errorf.E("invalid data length: %d", dLen)
 121  		return
 122  	}
 123  	salt, ciphertext_, hmac_ = decoded[1:33], decoded[33:dLen-32], decoded[dLen-32:]
 124  	if enc, nonce, auth, err = MessageKeys(conversationKey, salt); err != nil {
 125  		return
 126  	}
 127  	if hmac, err = sha256Hmac(auth, ciphertext_, salt); err != nil {
 128  		return
 129  	}
 130  	if !bytes.Equal(hmac_, hmac) {
 131  		err = errorf.E("invalid hmac")
 132  		return
 133  	}
 134  	if padded, err = chacha20_(enc, nonce, ciphertext_); err != nil {
 135  		return
 136  	}
 137  	unpaddedLen = binary.BigEndian.Uint16(padded[0:2])
 138  	if unpaddedLen < uint16(MinPlaintextSize) || unpaddedLen > uint16(MaxPlaintextSize) || len(padded) != 2+calcPadding(int(unpaddedLen)) {
 139  		err = errorf.E("invalid padding")
 140  		return
 141  	}
 142  	unpadded = padded[2 : unpaddedLen+2]
 143  	if len(unpadded) == 0 || len(unpadded) != int(unpaddedLen) {
 144  		err = errorf.E("invalid padding")
 145  		return
 146  	}
 147  	plaintext = string(unpadded)
 148  	return
 149  }
 150  
 151  func GenerateConversationKey(
 152  	sendPrivkey []byte, recvPubkey []byte,
 153  ) (conversationKey []byte, err error) {
 154  	// Parse the private key
 155  	var privKey secp256k1.SecretKey
 156  	if overflow := privKey.Key.SetByteSlice(sendPrivkey); overflow {
 157  		err = errorf.E(
 158  			"invalid private key: x coordinate %x is not on the secp256k1 curve",
 159  			sendPrivkey,
 160  		)
 161  		return
 162  	}
 163  
 164  	// Check if private key is zero
 165  	if privKey.Key.IsZero() {
 166  		err = errorf.E(
 167  			"invalid private key: x coordinate %x is not on the secp256k1 curve",
 168  			sendPrivkey,
 169  		)
 170  		return
 171  	}
 172  
 173  	// Parse the public key
 174  	// If it's 32 bytes, prepend format byte for compressed format (0x02 for even y)
 175  	// If it's already 33 bytes, use as-is
 176  	var pubKeyBytes []byte
 177  	if len(recvPubkey) == 32 {
 178  		// Nostr-style 32-byte public key - prepend compressed format byte
 179  		pubKeyBytes = make([]byte, 33)
 180  		pubKeyBytes[0] = secp256k1.PubKeyFormatCompressedEven
 181  		copy(pubKeyBytes[1:], recvPubkey)
 182  	} else if len(recvPubkey) == 33 {
 183  		// Already in compressed format
 184  		pubKeyBytes = recvPubkey
 185  	} else {
 186  		err = errorf.E(
 187  			"invalid public key length: %d (expected 32 or 33 bytes)",
 188  			len(recvPubkey),
 189  		)
 190  		return
 191  	}
 192  
 193  	pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
 194  	if err != nil {
 195  		return
 196  	}
 197  
 198  	// Compute ECDH shared secret (returns only x-coordinate, 32 bytes)
 199  	shared := secp256k1.GenerateSharedSecret(&privKey, pubKey)
 200  
 201  	// Apply HKDF-Extract with salt "nip44-v2"
 202  	conversationKey = hkdf.Extract(sha256.New, shared, []byte("nip44-v2"))
 203  	return
 204  }
 205  
 206  func chacha20_(key []byte, nonce []byte, message []byte) ([]byte, error) {
 207  	var (
 208  		cipher *chacha20.Cipher
 209  		dst    = make([]byte, len(message))
 210  		err    error
 211  	)
 212  	if cipher, err = chacha20.NewUnauthenticatedCipher(key, nonce); err != nil {
 213  		return nil, err
 214  	}
 215  	cipher.XORKeyStream(dst, message)
 216  	return dst, nil
 217  }
 218  
 219  func randomBytes(n int) ([]byte, error) {
 220  	buf := make([]byte, n)
 221  	if _, err := rand.Read(buf); err != nil {
 222  		return nil, err
 223  	}
 224  	return buf, nil
 225  }
 226  
 227  func sha256Hmac(key []byte, ciphertext []byte, aad []byte) ([]byte, error) {
 228  	if len(aad) != 32 {
 229  		return nil, errors.New("aad data must be 32 bytes")
 230  	}
 231  	h := hmac.New(sha256.New, key)
 232  	h.Write(aad)
 233  	h.Write(ciphertext)
 234  	return h.Sum(nil), nil
 235  }
 236  
 237  func MessageKeys(conversationKey []byte, salt []byte) (
 238  	[]byte, []byte, []byte, error,
 239  ) {
 240  	var (
 241  		r     io.Reader
 242  		enc   []byte = make([]byte, 32)
 243  		nonce []byte = make([]byte, 12)
 244  		auth  []byte = make([]byte, 32)
 245  		err   error
 246  	)
 247  	if len(conversationKey) != 32 {
 248  		return nil, nil, nil, errors.New("conversation key must be 32 bytes")
 249  	}
 250  	if len(salt) != 32 {
 251  		return nil, nil, nil, errors.New("salt must be 32 bytes")
 252  	}
 253  	r = hkdf.Expand(sha256.New, conversationKey, salt)
 254  	if _, err = io.ReadFull(r, enc); err != nil {
 255  		return nil, nil, nil, err
 256  	}
 257  	if _, err = io.ReadFull(r, nonce); err != nil {
 258  		return nil, nil, nil, err
 259  	}
 260  	if _, err = io.ReadFull(r, auth); err != nil {
 261  		return nil, nil, nil, err
 262  	}
 263  	return enc, nonce, auth, nil
 264  }
 265  
 266  func pad(s []byte) ([]byte, error) {
 267  	var (
 268  		sb      []byte
 269  		sbLen   int
 270  		padding int
 271  		result  []byte
 272  	)
 273  	sb = s
 274  	sbLen = len(sb)
 275  	if sbLen < 1 || sbLen > MaxPlaintextSize {
 276  		return nil, errors.New("plaintext should be between 1b and 64kB")
 277  	}
 278  	padding = calcPadding(sbLen)
 279  	result = make([]byte, 2)
 280  	binary.BigEndian.PutUint16(result, uint16(sbLen))
 281  	result = append(result, sb...)
 282  	result = append(result, make([]byte, padding-sbLen)...)
 283  	return result, nil
 284  }
 285  
 286  func calcPadding(sLen int) int {
 287  	var (
 288  		nextPower int
 289  		chunk     int
 290  	)
 291  	if sLen <= 32 {
 292  		return 32
 293  	}
 294  	nextPower = 1 << int(math.Floor(math.Log2(float64(sLen-1)))+1)
 295  	chunk = int(math.Max(32, float64(nextPower/8)))
 296  	return chunk * int(math.Floor(float64((sLen-1)/chunk))+1)
 297  }
 298