Add auth middleware

This commit is contained in:
2025-05-24 02:51:25 +04:00
parent b6eaaf7331
commit 1eecbafd07
9 changed files with 172 additions and 57 deletions

View File

@@ -6,68 +6,28 @@ import (
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/aykhans/bsky-feedgen/pkg/logger"
"github.com/bluesky-social/indigo/atproto/identity"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/golang-jwt/jwt/v5"
"github.com/whyrusleeping/go-did"
)
const UserDIDKey ContextKey = "user_did"
func JWTAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// No auth header, continue without authentication
next.ServeHTTP(w, r)
return
}
did, err := ValidateAuth(r.Context(), r)
if err != nil {
logger.Log.Error(err.Error())
}
ctx := context.WithValue(r.Context(), UserDIDKey, did)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
///////////////////////////////////////////////////
const (
authorizationHeaderName = "Authorization"
authorizationHeaderValuePrefix = "Bearer "
)
// AuthorizationError is a custom error type for authorization failures.
type AuthorizationError struct {
Message string
Err error // Wrapped error
}
// Error returns the formatted error message.
func (e *AuthorizationError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %v", e.Message, e.Err)
}
return e.Message
}
// Unwrap returns the wrapped error, allowing for errors.Is and errors.As.
func (e *AuthorizationError) Unwrap() error {
return e.Err
}
// Global (or dependency-injected) DID resolver with caching.
var didResolver *identity.CacheDirectory
// init function is called once when the package is initialized.
func init() {
// Initialize the base directory for actual DID resolution.
baseDir := identity.BaseDirectory{} // Zero value is usable.
baseDir := identity.BaseDirectory{}
// Configure cache with appropriate TTLs.
// Capacity 0 means unlimited cache size.
@@ -84,9 +44,49 @@ func init() {
didResolver = &resolver
}
type AuthorizationError struct {
Message string
Err error
}
func (e *AuthorizationError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %v", e.Message, e.Err)
}
return e.Message
}
func (e *AuthorizationError) Unwrap() error {
return e.Err
}
type Auth struct {
serviceDID *did.DID
}
func NewAuth(serviceDID *did.DID) *Auth {
return &Auth{serviceDID}
}
func (auth *Auth) JWTAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// No auth header, continue without authentication
next.ServeHTTP(w, r)
return
}
userDID, _ := auth.validateAuth(r.Context(), r)
ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// getDIDSigningKey resolves a DID and extracts its public signing key.
// It leverages indigo's identity package which handles multibase decoding and key parsing.
func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) {
func (auth *Auth) getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) {
atID, err := syntax.ParseAtIdentifier(did)
if err != nil {
return nil, fmt.Errorf("invalid DID syntax: %w", err)
@@ -103,7 +103,6 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error)
return nil, fmt.Errorf("DID resolution returned empty identity for %s", did)
}
// Get the public key using the PublicKey() method from the Identity struct.
publicKey, err := identity.PublicKey()
if err != nil {
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", did, err)
@@ -113,8 +112,7 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error)
}
// ValidateAuth validates the authorization header and returns the requester's DID.
// It requires a context.Context for DID resolution to allow for timeouts and cancellation.
func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
func (auth *Auth) validateAuth(ctx context.Context, r *http.Request) (string, error) {
authHeader := r.Header.Get(authorizationHeaderName)
if authHeader == "" {
return "", &AuthorizationError{Message: "Authorization header is missing"}
@@ -127,12 +125,9 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
jwtString := strings.TrimPrefix(authHeader, authorizationHeaderValuePrefix)
jwtString = strings.TrimSpace(jwtString)
// Define a custom claims struct if needed, otherwise use jwt.RegisteredClaims.
claims := jwt.RegisteredClaims{}
// Keyfunc callback to dynamically fetch the public key based on the JWT's issuer (iss).
keyFunc := func(token *jwt.Token) (interface{}, error) {
// Assert claims to RegisteredClaims to get the issuer (iss).
keyFunc := func(token *jwt.Token) (any, error) {
regClaims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, fmt.Errorf("invalid JWT claims type")
@@ -143,8 +138,7 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
return nil, fmt.Errorf("JWT 'iss' claim is missing")
}
// Resolve the DID and get the public signing key.
publicKey, err := getDIDSigningKey(ctx, issuerDID)
publicKey, err := auth.getDIDSigningKey(ctx, issuerDID)
if err != nil {
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", issuerDID, err)
}
@@ -166,7 +160,6 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
if errors.Is(err, jwt.ErrTokenMalformed) {
return "", &AuthorizationError{Message: "Malformed token", Err: err}
}
// Catch other generic parsing or validation errors.
return "", &AuthorizationError{Message: "Failed to parse or validate JWT", Err: err}
}
@@ -174,6 +167,10 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
return "", &AuthorizationError{Message: "Token is invalid"}
}
if slices.Contains(claims.Audience, auth.serviceDID.String()) {
return "", &AuthorizationError{Message: fmt.Sprintf("Invalid audience (expected %s)", auth.serviceDID)}
}
// Return the issuer's DID.
return claims.Issuer, nil
}