nip4.go raw

   1  package encryption
   2  
   3  import (
   4  	"bytes"
   5  	"crypto/aes"
   6  	"crypto/cipher"
   7  	"encoding/base64"
   8  
   9  	"next.orly.dev/pkg/lol/chk"
  10  	"next.orly.dev/pkg/lol/errorf"
  11  	"lukechampine.com/frand"
  12  )
  13  
  14  // EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by
  15  // ComputeSharedSecret.
  16  //
  17  // Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector).
  18  func EncryptNip4(msg, key []byte) (ct []byte, err error) {
  19  	// block size is 16 bytes
  20  	iv := make([]byte, 16)
  21  	if _, err = frand.Read(iv); chk.E(err) {
  22  		err = errorf.E("error creating initialization vector: %w", err)
  23  		return
  24  	}
  25  	// automatically picks aes-256 based on key length (32 bytes)
  26  	var block cipher.Block
  27  	if block, err = aes.NewCipher(key); chk.E(err) {
  28  		err = errorf.E("error creating block cipher: %w", err)
  29  		return
  30  	}
  31  	mode := cipher.NewCBCEncrypter(block, iv)
  32  	plaintext := []byte(msg)
  33  	// add padding
  34  	base := len(plaintext)
  35  	// this will be a number between 1 and 16 (inclusive), never 0
  36  	bs := block.BlockSize()
  37  	padding := bs - base%bs
  38  	// encode the padding in all the padding bytes themselves
  39  	padText := bytes.Repeat([]byte{byte(padding)}, padding)
  40  	paddedMsgBytes := append(plaintext, padText...)
  41  	ciphertext := make([]byte, len(paddedMsgBytes))
  42  	mode.CryptBlocks(ciphertext, paddedMsgBytes)
  43  	return []byte(base64.StdEncoding.EncodeToString(ciphertext) + "?iv=" +
  44  		base64.StdEncoding.EncodeToString(iv)), nil
  45  }
  46  
  47  // DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message ->
  48  // EncryptNip4(message, key).
  49  func DecryptNip4(content, key []byte) (msg []byte, err error) {
  50  	parts := bytes.Split(content, []byte("?iv="))
  51  	if len(parts) < 2 {
  52  		return nil, errorf.E(
  53  			"error parsing encrypted message: no initialization vector",
  54  		)
  55  	}
  56  	ciphertextBuf := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
  57  	var ciphertextLen int
  58  	if ciphertextLen, err = base64.StdEncoding.Decode(ciphertextBuf, parts[0]); chk.E(err) {
  59  		err = errorf.E("error decoding ciphertext from base64: %w", err)
  60  		return
  61  	}
  62  	ciphertext := ciphertextBuf[:ciphertextLen]
  63  
  64  	ivBuf := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
  65  	var ivLen int
  66  	if ivLen, err = base64.StdEncoding.Decode(ivBuf, parts[1]); chk.E(err) {
  67  		err = errorf.E("error decoding iv from base64: %w", err)
  68  		return
  69  	}
  70  	iv := ivBuf[:ivLen]
  71  	if len(iv) != 16 {
  72  		err = errorf.E("invalid IV length: %d, expected 16", len(iv))
  73  		return
  74  	}
  75  	var block cipher.Block
  76  	if block, err = aes.NewCipher(key); chk.E(err) {
  77  		err = errorf.E("error creating block cipher: %w", err)
  78  		return
  79  	}
  80  	mode := cipher.NewCBCDecrypter(block, iv)
  81  	msg = make([]byte, len(ciphertext))
  82  	mode.CryptBlocks(msg, ciphertext)
  83  	// remove padding
  84  	var (
  85  		plaintextLen = len(msg)
  86  	)
  87  	if plaintextLen > 0 {
  88  		// PKCS#7 padding: padding byte value is the number of padding bytes,
  89  		// valid range is [1, 16] for AES-256-CBC with 16-byte blocks.
  90  		padding := int(msg[plaintextLen-1])
  91  		if padding < 1 || padding > 16 || padding > plaintextLen {
  92  			err = errorf.E("invalid padding amount: %d", padding)
  93  			return
  94  		}
  95  		msg = msg[0 : plaintextLen-padding]
  96  	}
  97  	return msg, nil
  98  }
  99