nip44.mx raw

   1  package nip44
   2  
   3  import (
   4  	"smesh.lol/web/common/crypto/chacha20"
   5  	"smesh.lol/web/common/crypto/hkdf"
   6  	"smesh.lol/web/common/crypto/hmac"
   7  	"smesh.lol/web/common/helpers"
   8  	"smesh.lol/web/common/jsbridge/schnorr"
   9  )
  10  
  11  // ConversationKey derives the NIP-44 conversation key from ECDH shared secret.
  12  func ConversationKey(seckey, pubkey [32]byte) ([32]byte, bool) {
  13  	shared, ok := schnorr.ECDH(seckey[:], pubkey[:])
  14  	if !ok {
  15  		return [32]byte{}, false
  16  	}
  17  	return hkdf.Extract([]byte("nip44-v2"), shared), true
  18  }
  19  
  20  // Encrypt encrypts plaintext using NIP-44 v2.
  21  // nonce must be 32 random bytes.
  22  // Returns base64-encoded payload.
  23  func Encrypt(plaintext string, conversationKey [32]byte, nonce [32]byte) string {
  24  	// Derive message keys.
  25  	keys := hkdf.Expand(conversationKey[:], nonce[:], 76)
  26  	var chachaKey [32]byte
  27  	var chaChaNonce [12]byte
  28  	copy(chachaKey[:], keys[:32])
  29  	copy(chaChaNonce[:], keys[32:44])
  30  	hmacKey := keys[44:76]
  31  
  32  	// Pad plaintext.
  33  	padded := pad([]byte(plaintext))
  34  
  35  	// Encrypt.
  36  	ciphertext := chacha20.XOR(chachaKey, chaChaNonce, padded)
  37  
  38  	// MAC = HMAC-SHA256(hmacKey, nonce || ciphertext).
  39  	macInput := []byte{:32+len(ciphertext)}
  40  	copy(macInput, nonce[:])
  41  	copy(macInput[32:], ciphertext)
  42  	mac := hmac.Sum(hmacKey, macInput)
  43  
  44  	// Payload = version(1) || nonce(32) || ciphertext || mac(32).
  45  	payload := []byte{:0:1+32+len(ciphertext)+32}
  46  	payload = append(payload, 2) // version
  47  	payload = append(payload, nonce[:]...)
  48  	payload = append(payload, ciphertext...)
  49  	payload = append(payload, mac[:]...)
  50  	return helpers.Base64Encode(payload)
  51  }
  52  
  53  // Decrypt decrypts a NIP-44 v2 base64 payload.
  54  // Returns plaintext and true on success.
  55  func Decrypt(b64payload string, conversationKey [32]byte) (string, bool) {
  56  	raw := helpers.Base64Decode(b64payload)
  57  	if len(raw) < 99 { // 1 + 32 + 32 + 2 + 32 minimum
  58  		return "", false
  59  	}
  60  	if raw[0] != 2 {
  61  		return "", false
  62  	}
  63  
  64  	nonce := raw[1:33]
  65  	ciphertext := raw[33 : len(raw)-32]
  66  	mac := raw[len(raw)-32:]
  67  
  68  	// Derive message keys.
  69  	keys := hkdf.Expand(conversationKey[:], nonce, 76)
  70  	var chachaKey [32]byte
  71  	var chaChaNonce [12]byte
  72  	copy(chachaKey[:], keys[:32])
  73  	copy(chaChaNonce[:], keys[32:44])
  74  	hmacKey := keys[44:76]
  75  
  76  	// Verify MAC.
  77  	macInput := []byte{:32+len(ciphertext)}
  78  	copy(macInput, nonce)
  79  	copy(macInput[32:], ciphertext)
  80  	expectedMac := hmac.Sum(hmacKey, macInput)
  81  	ok := true
  82  	for i := range 32 {
  83  		if mac[i] != expectedMac[i] {
  84  			ok = false
  85  		}
  86  	}
  87  	if !ok {
  88  		return "", false
  89  	}
  90  
  91  	// Decrypt.
  92  	padded := chacha20.XOR(chachaKey, chaChaNonce, ciphertext)
  93  	return unpad(padded)
  94  }
  95  
  96  func calcPadding(length int) int {
  97  	if length <= 32 {
  98  		return 32
  99  	}
 100  	// Next power of 2.
 101  	bits := 0
 102  	v := length - 1
 103  	for v > 0 {
 104  		v >>= 1
 105  		bits++
 106  	}
 107  	nextPow := 1 << bits
 108  	chunk := max(nextPow/8, 32)
 109  	return chunk * ((length-1)/chunk + 1)
 110  }
 111  
 112  func pad(plaintext []byte) []byte {
 113  	n := len(plaintext)
 114  	paddedLen := calcPadding(n)
 115  	out := []byte{:2+paddedLen}
 116  	out[0] = byte(n >> 8)
 117  	out[1] = byte(n)
 118  	copy(out[2:], plaintext)
 119  	return out
 120  }
 121  
 122  func unpad(padded []byte) (string, bool) {
 123  	if len(padded) < 2 {
 124  		return "", false
 125  	}
 126  	n := int(padded[0])<<8 | int(padded[1])
 127  	if n < 1 || n > len(padded)-2 {
 128  		return "", false
 129  	}
 130  	// Verify zero padding.
 131  	for i := 2 + n; i < len(padded); i++ {
 132  		if padded[i] != 0 {
 133  			return "", false
 134  		}
 135  	}
 136  	return string(padded[2 : 2+n]), true
 137  }
 138