Rotate rotates the database encryption keys by re-encrypting all user tokens with the first cipher and revoking all other ciphers.
(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Cipher)
| 14 | // Rotate rotates the database encryption keys by re-encrypting all user tokens |
| 15 | // with the first cipher and revoking all other ciphers. |
| 16 | func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Cipher) error { |
| 17 | db := database.New(sqlDB) |
| 18 | cryptDB, err := New(ctx, db, ciphers...) |
| 19 | if err != nil { |
| 20 | return xerrors.Errorf("create cryptdb: %w", err) |
| 21 | } |
| 22 | |
| 23 | userIDs, err := db.AllUserIDs(ctx, false) |
| 24 | if err != nil { |
| 25 | return xerrors.Errorf("get users: %w", err) |
| 26 | } |
| 27 | log.Info(ctx, "encrypting user tokens", slog.F("user_count", len(userIDs))) |
| 28 | for idx, uid := range userIDs { |
| 29 | err := cryptDB.InTx(func(cryptTx database.Store) error { |
| 30 | userLinks, err := cryptTx.GetUserLinksByUserID(ctx, uid) |
| 31 | if err != nil { |
| 32 | return xerrors.Errorf("get user links for user: %w", err) |
| 33 | } |
| 34 | for _, userLink := range userLinks { |
| 35 | if userLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && userLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { |
| 36 | log.Debug(ctx, "skipping user link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) |
| 37 | continue |
| 38 | } |
| 39 | if _, err := cryptTx.UpdateUserLink(ctx, database.UpdateUserLinkParams{ |
| 40 | OAuthAccessToken: userLink.OAuthAccessToken, |
| 41 | OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required |
| 42 | OAuthRefreshToken: userLink.OAuthRefreshToken, |
| 43 | OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required |
| 44 | OAuthExpiry: userLink.OAuthExpiry, |
| 45 | UserID: uid, |
| 46 | LoginType: userLink.LoginType, |
| 47 | Claims: userLink.Claims, |
| 48 | }); err != nil { |
| 49 | return xerrors.Errorf("update user link user_id=%s linked_id=%s: %w", userLink.UserID, userLink.LinkedID, err) |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | externalAuthLinks, err := cryptTx.GetExternalAuthLinksByUserID(ctx, uid) |
| 54 | if err != nil { |
| 55 | return xerrors.Errorf("get git auth links for user: %w", err) |
| 56 | } |
| 57 | for _, externalAuthLink := range externalAuthLinks { |
| 58 | if externalAuthLink.OAuthAccessTokenKeyID.String == ciphers[0].HexDigest() && externalAuthLink.OAuthRefreshTokenKeyID.String == ciphers[0].HexDigest() { |
| 59 | log.Debug(ctx, "skipping external auth link", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) |
| 60 | continue |
| 61 | } |
| 62 | if _, err := cryptTx.UpdateExternalAuthLink(ctx, database.UpdateExternalAuthLinkParams{ |
| 63 | ProviderID: externalAuthLink.ProviderID, |
| 64 | UserID: uid, |
| 65 | UpdatedAt: externalAuthLink.UpdatedAt, |
| 66 | OAuthAccessToken: externalAuthLink.OAuthAccessToken, |
| 67 | OAuthAccessTokenKeyID: sql.NullString{}, // dbcrypt will update as required |
| 68 | OAuthRefreshToken: externalAuthLink.OAuthRefreshToken, |
| 69 | OAuthRefreshTokenKeyID: sql.NullString{}, // dbcrypt will update as required |
| 70 | OAuthExpiry: externalAuthLink.OAuthExpiry, |
| 71 | OAuthExtra: externalAuthLink.OAuthExtra, |
| 72 | }); err != nil { |
| 73 | return xerrors.Errorf("update external auth link user_id=%s provider_id=%s: %w", externalAuthLink.UserID, externalAuthLink.ProviderID, err) |
no test coverage detected