summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLevi Durfee <levi.durfee@gmail.com>2026-01-06 18:08:41 -0500
committerLevi Durfee <levi.durfee@gmail.com>2026-01-06 18:08:50 -0500
commit6630c8cb513941b4bb2f0a24f23143665f4b9476 (patch)
treee97c64c61a218b2866f5d3668a3869c978f6d722
parentca629087012f6651131ea99805286423aa21c5f8 (diff)
Update
-rw-r--r--main.go132
1 files changed, 103 insertions, 29 deletions
diff --git a/main.go b/main.go
index 385c042..e93d2a3 100644
--- a/main.go
+++ b/main.go
@@ -14,81 +14,155 @@ import (
"github.com/joho/godotenv"
)
+type (
+ KEK []byte
+ DEK []byte
+ WrappedDEK []byte
+ Ciphertext []byte
+)
+
+var (
+ aadWrapDEK = []byte("wrap:dek:v1")
+ aadDataMsg = []byte("data:msg:v1")
+ errBadKeyLn = errors.New("invalid key length: must be 16, 24, or 32 bytes")
+)
+
func main() {
- err := godotenv.Load()
- if err != nil {
+ if err := godotenv.Load(); err != nil {
log.Fatal("Error loading .env file")
}
- base64kek := os.Getenv("SECRET_KEY")
+ kek, err := NewKEKFromEnvB64("SECRET_KEY")
+ if err != nil {
+ log.Fatal(err)
+ }
- kek, err := base64.StdEncoding.DecodeString(base64kek)
+ dek, err := NewDEK()
if err != nil {
- panic(err)
+ log.Fatal(err)
}
- dek := GenCipherKey()
+ edek, err := WrapDEK(dek, kek)
+ if err != nil {
+ log.Fatal(err)
+ }
- encryptedDek, err := Encrypt(dek, kek)
+ ct, err := EncryptData([]byte("hello"), dek)
if err != nil {
- panic(err)
+ log.Fatal(err)
}
- fmt.Println("edek", encryptedDek)
+ // Print as base64 for readability/transport.
+ fmt.Println("edek_b64:", base64.StdEncoding.EncodeToString(edek))
+ fmt.Println("ct_b64: ", base64.StdEncoding.EncodeToString(ct))
- cipherText, err := Encrypt([]byte("hello"), dek)
+ // Round-trip demo.
+ dek2, err := UnwrapDEK(edek, kek)
if err != nil {
- panic(err)
+ log.Fatal(err)
}
- fmt.Println("ciphertext", cipherText)
+ pt, err := DecryptData(ct, dek2)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println("pt:", string(pt))
}
-type CipherKey []byte
+func NewKEKFromEnvB64(envVar string) (KEK, error) {
+ b64 := os.Getenv(envVar)
+ if b64 == "" {
+ return nil, fmt.Errorf("%s is not set", envVar)
+ }
+
+ raw, err := base64.StdEncoding.DecodeString(b64)
+ if err != nil {
+ return nil, fmt.Errorf("decode %s base64: %w", envVar, err)
+ }
+
+ if !validAESKeyLen(len(raw)) {
+ return nil, errBadKeyLn
+ }
-func GenCipherKey() CipherKey {
- key := make([]byte, 32)
+ return KEK(raw), nil
+}
+
+func NewDEK() (DEK, error) {
+ key := make([]byte, 32) // AES-256
if _, err := io.ReadFull(rand.Reader, key); err != nil {
- log.Fatalf("random key gen: %v", err)
+ return nil, fmt.Errorf("random DEK gen: %w", err)
}
- return CipherKey(key)
+ return DEK(key), nil
+}
+
+func WrapDEK(dek DEK, kek KEK) (WrappedDEK, error) {
+ edek, err := encryptAEAD([]byte(dek), []byte(kek), aadWrapDEK)
+ return WrappedDEK(edek), err
+}
+
+func UnwrapDEK(edek WrappedDEK, kek KEK) (DEK, error) {
+ dek, err := decryptAEAD([]byte(edek), []byte(kek), aadWrapDEK)
+ return DEK(dek), err
+}
+
+func EncryptData(plaintext []byte, dek DEK) (Ciphertext, error) {
+ ct, err := encryptAEAD(plaintext, []byte(dek), aadDataMsg)
+ return Ciphertext(ct), err
}
-func Encrypt(plaintext []byte, key CipherKey) ([]byte, error) {
- c, err := aes.NewCipher(key)
+func DecryptData(ct Ciphertext, dek DEK) ([]byte, error) {
+ return decryptAEAD([]byte(ct), []byte(dek), aadDataMsg)
+}
+
+// encryptAEAD returns: nonce || ciphertext
+func encryptAEAD(plaintext, key, aad []byte) ([]byte, error) {
+ if !validAESKeyLen(len(key)) {
+ return nil, errBadKeyLn
+ }
+
+ block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
- gcm, err := cipher.NewGCM(c)
+ gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
- if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
- return gcm.Seal(nonce, nonce, plaintext, nil), nil
+ return gcm.Seal(nonce, nonce, plaintext, aad), nil
}
-func Decrypt(ciphertext []byte, key CipherKey) ([]byte, error) {
- c, err := aes.NewCipher(key)
+func decryptAEAD(ciphertext, key, aad []byte) ([]byte, error) {
+ if !validAESKeyLen(len(key)) {
+ return nil, errBadKeyLn
+ }
+
+ block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
- gcm, err := cipher.NewGCM(c)
+ gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
- nonceSize := gcm.NonceSize()
- if len(ciphertext) < nonceSize {
+ ns := gcm.NonceSize()
+ if len(ciphertext) < ns {
return nil, errors.New("ciphertext too short")
}
- nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
- return gcm.Open(nil, nonce, ciphertext, nil)
+ nonce := ciphertext[:ns]
+ body := ciphertext[ns:]
+ return gcm.Open(nil, nonce, body, aad)
+}
+
+func validAESKeyLen(n int) bool {
+ return n == 16 || n == 24 || n == 32
}