diff --git a/go.mod b/go.mod index 24a96c3..cf48eaf 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/gocql/gocql v1.7.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.4.0 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect diff --git a/go.sum b/go.sum index 423a4b9..a472229 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= diff --git a/pkg/api/base.go b/pkg/api/base.go index 287a675..9f44d82 100644 --- a/pkg/api/base.go +++ b/pkg/api/base.go @@ -7,6 +7,7 @@ import ( "time" "github.com/aykhans/bsky-feedgen/pkg/api/handler" + "github.com/aykhans/bsky-feedgen/pkg/api/middleware" "github.com/aykhans/bsky-feedgen/pkg/config" "github.com/aykhans/bsky-feedgen/pkg/feed" "github.com/aykhans/bsky-feedgen/pkg/logger" @@ -23,13 +24,15 @@ func Run( } feedHandler := handler.NewFeedHandler(feeds, apiConfig.FeedgenPublisherDID) + authMiddleware := middleware.NewAuth(apiConfig.ServiceDID) + mux := http.NewServeMux() mux.HandleFunc("GET /.well-known/did.json", baseHandler.GetWellKnownDIDDoc) mux.HandleFunc("GET /xrpc/app.bsky.feed.describeFeedGenerator", feedHandler.DescribeFeeds) - mux.HandleFunc( + mux.Handle( "GET /xrpc/app.bsky.feed.getFeedSkeleton", - feedHandler.GetFeedSkeleton, + authMiddleware.JWTAuthMiddleware(http.HandlerFunc(feedHandler.GetFeedSkeleton)), ) httpServer := &http.Server{ diff --git a/pkg/api/handler/feed.go b/pkg/api/handler/feed.go index bf123ae..d4b7471 100644 --- a/pkg/api/handler/feed.go +++ b/pkg/api/handler/feed.go @@ -50,7 +50,7 @@ func (handler *FeedHandler) DescribeFeeds(w http.ResponseWriter, r *http.Request } func (handler *FeedHandler) GetFeedSkeleton(w http.ResponseWriter, r *http.Request) { - userDID, _ := r.Context().Value(middleware.UserDIDKey).(string) + userDID, _ := middleware.GetValue[string](r, middleware.UserDIDKey) feedQuery := r.URL.Query().Get("feed") if feedQuery == "" { diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 2ed8585..ee5e012 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -6,68 +6,28 @@ import ( "errors" "fmt" "net/http" + "slices" "strings" "time" - "github.com/aykhans/bsky-feedgen/pkg/logger" "github.com/bluesky-social/indigo/atproto/identity" "github.com/bluesky-social/indigo/atproto/syntax" "github.com/golang-jwt/jwt/v5" + "github.com/whyrusleeping/go-did" ) const UserDIDKey ContextKey = "user_did" -func JWTAuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - // No auth header, continue without authentication - next.ServeHTTP(w, r) - return - } - - did, err := ValidateAuth(r.Context(), r) - if err != nil { - logger.Log.Error(err.Error()) - } - ctx := context.WithValue(r.Context(), UserDIDKey, did) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -/////////////////////////////////////////////////// - const ( authorizationHeaderName = "Authorization" authorizationHeaderValuePrefix = "Bearer " ) -// AuthorizationError is a custom error type for authorization failures. -type AuthorizationError struct { - Message string - Err error // Wrapped error -} - -// Error returns the formatted error message. -func (e *AuthorizationError) Error() string { - if e.Err != nil { - return fmt.Sprintf("%s: %v", e.Message, e.Err) - } - return e.Message -} - -// Unwrap returns the wrapped error, allowing for errors.Is and errors.As. -func (e *AuthorizationError) Unwrap() error { - return e.Err -} - // Global (or dependency-injected) DID resolver with caching. var didResolver *identity.CacheDirectory -// init function is called once when the package is initialized. func init() { - // Initialize the base directory for actual DID resolution. - baseDir := identity.BaseDirectory{} // Zero value is usable. + baseDir := identity.BaseDirectory{} // Configure cache with appropriate TTLs. // Capacity 0 means unlimited cache size. @@ -84,9 +44,49 @@ func init() { didResolver = &resolver } +type AuthorizationError struct { + Message string + Err error +} + +func (e *AuthorizationError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Err) + } + return e.Message +} + +func (e *AuthorizationError) Unwrap() error { + return e.Err +} + +type Auth struct { + serviceDID *did.DID +} + +func NewAuth(serviceDID *did.DID) *Auth { + return &Auth{serviceDID} +} + +func (auth *Auth) JWTAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + // No auth header, continue without authentication + next.ServeHTTP(w, r) + return + } + + userDID, _ := auth.validateAuth(r.Context(), r) + ctx := context.WithValue(r.Context(), UserDIDKey, userDID) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + // getDIDSigningKey resolves a DID and extracts its public signing key. // It leverages indigo's identity package which handles multibase decoding and key parsing. -func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) { +func (auth *Auth) getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) { atID, err := syntax.ParseAtIdentifier(did) if err != nil { return nil, fmt.Errorf("invalid DID syntax: %w", err) @@ -103,7 +103,6 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) return nil, fmt.Errorf("DID resolution returned empty identity for %s", did) } - // Get the public key using the PublicKey() method from the Identity struct. publicKey, err := identity.PublicKey() if err != nil { return nil, fmt.Errorf("failed to get signing key for DID %s: %w", did, err) @@ -113,8 +112,7 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) } // ValidateAuth validates the authorization header and returns the requester's DID. -// It requires a context.Context for DID resolution to allow for timeouts and cancellation. -func ValidateAuth(ctx context.Context, r *http.Request) (string, error) { +func (auth *Auth) validateAuth(ctx context.Context, r *http.Request) (string, error) { authHeader := r.Header.Get(authorizationHeaderName) if authHeader == "" { return "", &AuthorizationError{Message: "Authorization header is missing"} @@ -127,12 +125,9 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) { jwtString := strings.TrimPrefix(authHeader, authorizationHeaderValuePrefix) jwtString = strings.TrimSpace(jwtString) - // Define a custom claims struct if needed, otherwise use jwt.RegisteredClaims. claims := jwt.RegisteredClaims{} - // Keyfunc callback to dynamically fetch the public key based on the JWT's issuer (iss). - keyFunc := func(token *jwt.Token) (interface{}, error) { - // Assert claims to RegisteredClaims to get the issuer (iss). + keyFunc := func(token *jwt.Token) (any, error) { regClaims, ok := token.Claims.(*jwt.RegisteredClaims) if !ok { return nil, fmt.Errorf("invalid JWT claims type") @@ -143,8 +138,7 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) { return nil, fmt.Errorf("JWT 'iss' claim is missing") } - // Resolve the DID and get the public signing key. - publicKey, err := getDIDSigningKey(ctx, issuerDID) + publicKey, err := auth.getDIDSigningKey(ctx, issuerDID) if err != nil { return nil, fmt.Errorf("failed to get signing key for DID %s: %w", issuerDID, err) } @@ -166,7 +160,6 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) { if errors.Is(err, jwt.ErrTokenMalformed) { return "", &AuthorizationError{Message: "Malformed token", Err: err} } - // Catch other generic parsing or validation errors. return "", &AuthorizationError{Message: "Failed to parse or validate JWT", Err: err} } @@ -174,6 +167,10 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) { return "", &AuthorizationError{Message: "Token is invalid"} } + if slices.Contains(claims.Audience, auth.serviceDID.String()) { + return "", &AuthorizationError{Message: fmt.Sprintf("Invalid audience (expected %s)", auth.serviceDID)} + } + // Return the issuer's DID. return claims.Issuer, nil } diff --git a/pkg/api/middleware/base.go b/pkg/api/middleware/base.go index 58e4e76..491d21b 100644 --- a/pkg/api/middleware/base.go +++ b/pkg/api/middleware/base.go @@ -1,3 +1,19 @@ package middleware +import ( + "net/http" + + "github.com/aykhans/bsky-feedgen/pkg/types" +) + type ContextKey string + +func GetValue[T any](r *http.Request, key ContextKey) (T, error) { + value, ok := r.Context().Value(key).(T) + if ok == false { + var zero T + return zero, types.ErrNotfound + } + + return value, nil +} diff --git a/pkg/api/middleware/es256k.go b/pkg/api/middleware/es256k.go new file mode 100644 index 0000000..994dc06 --- /dev/null +++ b/pkg/api/middleware/es256k.go @@ -0,0 +1,91 @@ +package middleware + +// copied from https://gist.github.com/bnewbold/bc9b97c9b281295da1fa47c03b0b3c69 + +import ( + "crypto" + "errors" + "fmt" + + atcrypto "github.com/bluesky-social/indigo/atproto/crypto" + "github.com/golang-jwt/jwt/v5" +) + +var ( + SigningMethodES256K *SigningMethodAtproto + SigningMethodES256 *SigningMethodAtproto +) + +type SigningMethodAtproto struct { + alg string + hash crypto.Hash + toOutSig toOutSig + sigLen int +} + +type toOutSig func(sig []byte) []byte + +func init() { + SigningMethodES256K = &SigningMethodAtproto{ + alg: "ES256K", + hash: crypto.SHA256, + toOutSig: toES256K, + sigLen: 64, + } + jwt.RegisterSigningMethod(SigningMethodES256K.Alg(), func() jwt.SigningMethod { + return SigningMethodES256K + }) + SigningMethodES256 = &SigningMethodAtproto{ + alg: "ES256", + hash: crypto.SHA256, + toOutSig: toES256, + sigLen: 64, + } + jwt.RegisterSigningMethod(SigningMethodES256.Alg(), func() jwt.SigningMethod { + return SigningMethodES256 + }) + fmt.Println("init Completed") +} + +// Errors returned on different problems. +var ( + ErrWrongKeyFormat = errors.New("wrong key type") + ErrBadSignature = errors.New("bad signature") + ErrVerification = errors.New("signature verification failed") + ErrFailedSigning = errors.New("failed generating signature") + ErrHashUnavailable = errors.New("hasher unavailable") +) + +func (sm *SigningMethodAtproto) Verify(signingString string, sig []byte, key any) error { + pub, ok := key.(atcrypto.PublicKey) + if !ok { + return ErrWrongKeyFormat + } + + if !sm.hash.Available() { + return ErrHashUnavailable + } + + if len(sig) != sm.sigLen { + return ErrBadSignature + } + + return pub.HashAndVerifyLenient([]byte(signingString), sig) +} + +func (sm *SigningMethodAtproto) Sign(signingString string, key any) ([]byte, error) { + // TODO: implement signatures + return nil, ErrFailedSigning +} + +func (sm *SigningMethodAtproto) Alg() string { + return sm.alg +} + +func toES256K(sig []byte) []byte { + return sig[:64] +} + +func toES256(sig []byte) []byte { + return sig[:64] +} diff --git a/pkg/consumer/base.go b/pkg/consumer/base.go index b503f94..e16b18e 100644 --- a/pkg/consumer/base.go +++ b/pkg/consumer/base.go @@ -1,5 +1,9 @@ package consumer +// This file contains code for consuming and processing the Bluesky firehose event stream. +// Most of this implementation is copied and inspired from the original source at: +// https://github.com/bluesky-social/indigo/blob/main/cmd/beemo/firehose_consumer.go + import ( "bytes" "context" diff --git a/pkg/types/errors.go b/pkg/types/errors.go index d06da35..341c79b 100644 --- a/pkg/types/errors.go +++ b/pkg/types/errors.go @@ -4,4 +4,5 @@ import "errors" var ( ErrInternal = errors.New("internal error") + ErrNotfound = errors.New("not found") )