diff options
| author | Levi Durfee <levi.durfee@gmail.com> | 2026-01-06 18:08:41 -0500 |
|---|---|---|
| committer | Levi Durfee <levi.durfee@gmail.com> | 2026-01-06 18:08:50 -0500 |
| commit | 6630c8cb513941b4bb2f0a24f23143665f4b9476 (patch) | |
| tree | e97c64c61a218b2866f5d3668a3869c978f6d722 | |
| parent | ca629087012f6651131ea99805286423aa21c5f8 (diff) | |
Update
| -rw-r--r-- | main.go | 132 |
1 files changed, 103 insertions, 29 deletions
@@ -14,81 +14,155 @@ import ( "github.com/joho/godotenv" ) +type ( + KEK []byte + DEK []byte + WrappedDEK []byte + Ciphertext []byte +) + +var ( + aadWrapDEK = []byte("wrap:dek:v1") + aadDataMsg = []byte("data:msg:v1") + errBadKeyLn = errors.New("invalid key length: must be 16, 24, or 32 bytes") +) + func main() { - err := godotenv.Load() - if err != nil { + if err := godotenv.Load(); err != nil { log.Fatal("Error loading .env file") } - base64kek := os.Getenv("SECRET_KEY") + kek, err := NewKEKFromEnvB64("SECRET_KEY") + if err != nil { + log.Fatal(err) + } - kek, err := base64.StdEncoding.DecodeString(base64kek) + dek, err := NewDEK() if err != nil { - panic(err) + log.Fatal(err) } - dek := GenCipherKey() + edek, err := WrapDEK(dek, kek) + if err != nil { + log.Fatal(err) + } - encryptedDek, err := Encrypt(dek, kek) + ct, err := EncryptData([]byte("hello"), dek) if err != nil { - panic(err) + log.Fatal(err) } - fmt.Println("edek", encryptedDek) + // Print as base64 for readability/transport. + fmt.Println("edek_b64:", base64.StdEncoding.EncodeToString(edek)) + fmt.Println("ct_b64: ", base64.StdEncoding.EncodeToString(ct)) - cipherText, err := Encrypt([]byte("hello"), dek) + // Round-trip demo. + dek2, err := UnwrapDEK(edek, kek) if err != nil { - panic(err) + log.Fatal(err) } - fmt.Println("ciphertext", cipherText) + pt, err := DecryptData(ct, dek2) + if err != nil { + log.Fatal(err) + } + fmt.Println("pt:", string(pt)) } -type CipherKey []byte +func NewKEKFromEnvB64(envVar string) (KEK, error) { + b64 := os.Getenv(envVar) + if b64 == "" { + return nil, fmt.Errorf("%s is not set", envVar) + } + + raw, err := base64.StdEncoding.DecodeString(b64) + if err != nil { + return nil, fmt.Errorf("decode %s base64: %w", envVar, err) + } + + if !validAESKeyLen(len(raw)) { + return nil, errBadKeyLn + } -func GenCipherKey() CipherKey { - key := make([]byte, 32) + return KEK(raw), nil +} + +func NewDEK() (DEK, error) { + key := make([]byte, 32) // AES-256 if _, err := io.ReadFull(rand.Reader, key); err != nil { - log.Fatalf("random key gen: %v", err) + return nil, fmt.Errorf("random DEK gen: %w", err) } - return CipherKey(key) + return DEK(key), nil +} + +func WrapDEK(dek DEK, kek KEK) (WrappedDEK, error) { + edek, err := encryptAEAD([]byte(dek), []byte(kek), aadWrapDEK) + return WrappedDEK(edek), err +} + +func UnwrapDEK(edek WrappedDEK, kek KEK) (DEK, error) { + dek, err := decryptAEAD([]byte(edek), []byte(kek), aadWrapDEK) + return DEK(dek), err +} + +func EncryptData(plaintext []byte, dek DEK) (Ciphertext, error) { + ct, err := encryptAEAD(plaintext, []byte(dek), aadDataMsg) + return Ciphertext(ct), err } -func Encrypt(plaintext []byte, key CipherKey) ([]byte, error) { - c, err := aes.NewCipher(key) +func DecryptData(ct Ciphertext, dek DEK) ([]byte, error) { + return decryptAEAD([]byte(ct), []byte(dek), aadDataMsg) +} + +// encryptAEAD returns: nonce || ciphertext +func encryptAEAD(plaintext, key, aad []byte) ([]byte, error) { + if !validAESKeyLen(len(key)) { + return nil, errBadKeyLn + } + + block, err := aes.NewCipher(key) if err != nil { return nil, err } - gcm, err := cipher.NewGCM(c) + gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } - return gcm.Seal(nonce, nonce, plaintext, nil), nil + return gcm.Seal(nonce, nonce, plaintext, aad), nil } -func Decrypt(ciphertext []byte, key CipherKey) ([]byte, error) { - c, err := aes.NewCipher(key) +func decryptAEAD(ciphertext, key, aad []byte) ([]byte, error) { + if !validAESKeyLen(len(key)) { + return nil, errBadKeyLn + } + + block, err := aes.NewCipher(key) if err != nil { return nil, err } - gcm, err := cipher.NewGCM(c) + gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } - nonceSize := gcm.NonceSize() - if len(ciphertext) < nonceSize { + ns := gcm.NonceSize() + if len(ciphertext) < ns { return nil, errors.New("ciphertext too short") } - nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] - return gcm.Open(nil, nonce, ciphertext, nil) + nonce := ciphertext[:ns] + body := ciphertext[ns:] + return gcm.Open(nil, nonce, body, aad) +} + +func validAESKeyLen(n int) bool { + return n == 16 || n == 24 || n == 32 } |
