nip4.go raw
1 package encryption
2
3 import (
4 "bytes"
5 "crypto/aes"
6 "crypto/cipher"
7 "encoding/base64"
8
9 "next.orly.dev/pkg/lol/chk"
10 "next.orly.dev/pkg/lol/errorf"
11 "lukechampine.com/frand"
12 )
13
14 // EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by
15 // ComputeSharedSecret.
16 //
17 // Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector).
18 func EncryptNip4(msg, key []byte) (ct []byte, err error) {
19 // block size is 16 bytes
20 iv := make([]byte, 16)
21 if _, err = frand.Read(iv); chk.E(err) {
22 err = errorf.E("error creating initialization vector: %w", err)
23 return
24 }
25 // automatically picks aes-256 based on key length (32 bytes)
26 var block cipher.Block
27 if block, err = aes.NewCipher(key); chk.E(err) {
28 err = errorf.E("error creating block cipher: %w", err)
29 return
30 }
31 mode := cipher.NewCBCEncrypter(block, iv)
32 plaintext := []byte(msg)
33 // add padding
34 base := len(plaintext)
35 // this will be a number between 1 and 16 (inclusive), never 0
36 bs := block.BlockSize()
37 padding := bs - base%bs
38 // encode the padding in all the padding bytes themselves
39 padText := bytes.Repeat([]byte{byte(padding)}, padding)
40 paddedMsgBytes := append(plaintext, padText...)
41 ciphertext := make([]byte, len(paddedMsgBytes))
42 mode.CryptBlocks(ciphertext, paddedMsgBytes)
43 return []byte(base64.StdEncoding.EncodeToString(ciphertext) + "?iv=" +
44 base64.StdEncoding.EncodeToString(iv)), nil
45 }
46
47 // DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message ->
48 // EncryptNip4(message, key).
49 func DecryptNip4(content, key []byte) (msg []byte, err error) {
50 parts := bytes.Split(content, []byte("?iv="))
51 if len(parts) < 2 {
52 return nil, errorf.E(
53 "error parsing encrypted message: no initialization vector",
54 )
55 }
56 ciphertextBuf := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
57 var ciphertextLen int
58 if ciphertextLen, err = base64.StdEncoding.Decode(ciphertextBuf, parts[0]); chk.E(err) {
59 err = errorf.E("error decoding ciphertext from base64: %w", err)
60 return
61 }
62 ciphertext := ciphertextBuf[:ciphertextLen]
63
64 ivBuf := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
65 var ivLen int
66 if ivLen, err = base64.StdEncoding.Decode(ivBuf, parts[1]); chk.E(err) {
67 err = errorf.E("error decoding iv from base64: %w", err)
68 return
69 }
70 iv := ivBuf[:ivLen]
71 if len(iv) != 16 {
72 err = errorf.E("invalid IV length: %d, expected 16", len(iv))
73 return
74 }
75 var block cipher.Block
76 if block, err = aes.NewCipher(key); chk.E(err) {
77 err = errorf.E("error creating block cipher: %w", err)
78 return
79 }
80 mode := cipher.NewCBCDecrypter(block, iv)
81 msg = make([]byte, len(ciphertext))
82 mode.CryptBlocks(msg, ciphertext)
83 // remove padding
84 var (
85 plaintextLen = len(msg)
86 )
87 if plaintextLen > 0 {
88 // PKCS#7 padding: padding byte value is the number of padding bytes,
89 // valid range is [1, 16] for AES-256-CBC with 16-byte blocks.
90 padding := int(msg[plaintextLen-1])
91 if padding < 1 || padding > 16 || padding > plaintextLen {
92 err = errorf.E("invalid padding amount: %d", padding)
93 return
94 }
95 msg = msg[0 : plaintextLen-padding]
96 }
97 return msg, nil
98 }
99