diff --git a/go.mod b/go.mod index 24a96c3..cf48eaf 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/gocql/gocql v1.7.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/golang/snappy v0.0.4 // indirect github.com/google/uuid v1.4.0 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect diff --git a/go.sum b/go.sum index 423a4b9..a472229 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= diff --git a/pkg/api/base.go b/pkg/api/base.go index 287a675..9f44d82 100644 --- a/pkg/api/base.go +++ b/pkg/api/base.go @@ -7,6 +7,7 @@ import ( "time" "github.com/aykhans/bsky-feedgen/pkg/api/handler" + "github.com/aykhans/bsky-feedgen/pkg/api/middleware" "github.com/aykhans/bsky-feedgen/pkg/config" "github.com/aykhans/bsky-feedgen/pkg/feed" "github.com/aykhans/bsky-feedgen/pkg/logger" @@ -23,13 +24,15 @@ func Run( } feedHandler := handler.NewFeedHandler(feeds, apiConfig.FeedgenPublisherDID) + authMiddleware := middleware.NewAuth(apiConfig.ServiceDID) + mux := http.NewServeMux() mux.HandleFunc("GET /.well-known/did.json", baseHandler.GetWellKnownDIDDoc) mux.HandleFunc("GET /xrpc/app.bsky.feed.describeFeedGenerator", feedHandler.DescribeFeeds) - mux.HandleFunc( + mux.Handle( "GET /xrpc/app.bsky.feed.getFeedSkeleton", - feedHandler.GetFeedSkeleton, + authMiddleware.JWTAuthMiddleware(http.HandlerFunc(feedHandler.GetFeedSkeleton)), ) httpServer := &http.Server{ diff --git a/pkg/api/handler/feed.go b/pkg/api/handler/feed.go index bf123ae..d4b7471 100644 --- a/pkg/api/handler/feed.go +++ b/pkg/api/handler/feed.go @@ -50,7 +50,7 @@ func (handler *FeedHandler) DescribeFeeds(w http.ResponseWriter, r *http.Request } func (handler *FeedHandler) GetFeedSkeleton(w http.ResponseWriter, r *http.Request) { - userDID, _ := r.Context().Value(middleware.UserDIDKey).(string) + userDID, _ := middleware.GetValue[string](r, middleware.UserDIDKey) feedQuery := r.URL.Query().Get("feed") if feedQuery == "" { diff --git a/pkg/api/middleware/auth.go b/pkg/api/middleware/auth.go index 920674f..ee5e012 100644 --- a/pkg/api/middleware/auth.go +++ b/pkg/api/middleware/auth.go @@ -2,12 +2,73 @@ package middleware import ( "context" + "crypto" + "errors" + "fmt" "net/http" + "slices" + "strings" + "time" + + "github.com/bluesky-social/indigo/atproto/identity" + "github.com/bluesky-social/indigo/atproto/syntax" + "github.com/golang-jwt/jwt/v5" + "github.com/whyrusleeping/go-did" ) const UserDIDKey ContextKey = "user_did" -func JWTAuthMiddleware(next http.Handler) http.Handler { +const ( + authorizationHeaderName = "Authorization" + authorizationHeaderValuePrefix = "Bearer " +) + +// Global (or dependency-injected) DID resolver with caching. +var didResolver *identity.CacheDirectory + +func init() { + baseDir := identity.BaseDirectory{} + + // Configure cache with appropriate TTLs. + // Capacity 0 means unlimited cache size. + // hitTTL: 24 hours for successful resolutions. + // errTTL: 5 minutes for failed resolutions. + // invalidHandleTTL: also 5 minutes for invalid handles. + resolver := identity.NewCacheDirectory( + &baseDir, + 0, // Unlimited capacity + 24*time.Hour, // hitTTL + 5*time.Minute, // errTTL + 5*time.Minute, // invalidHandleTTL + ) + didResolver = &resolver +} + +type AuthorizationError struct { + Message string + Err error +} + +func (e *AuthorizationError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Err) + } + return e.Message +} + +func (e *AuthorizationError) Unwrap() error { + return e.Err +} + +type Auth struct { + serviceDID *did.DID +} + +func NewAuth(serviceDID *did.DID) *Auth { + return &Auth{serviceDID} +} + +func (auth *Auth) JWTAuthMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { @@ -16,8 +77,100 @@ func JWTAuthMiddleware(next http.Handler) http.Handler { return } - // TODO: Add auth verification - ctx := context.WithValue(r.Context(), UserDIDKey, "") + userDID, _ := auth.validateAuth(r.Context(), r) + ctx := context.WithValue(r.Context(), UserDIDKey, userDID) + next.ServeHTTP(w, r.WithContext(ctx)) }) } + +// getDIDSigningKey resolves a DID and extracts its public signing key. +// It leverages indigo's identity package which handles multibase decoding and key parsing. +func (auth *Auth) getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) { + atID, err := syntax.ParseAtIdentifier(did) + if err != nil { + return nil, fmt.Errorf("invalid DID syntax: %w", err) + } + + // Use Lookup for bi-directional verification (handle -> DID -> handle). + // The `Lookup` method returns an `Identity` struct which contains `PublicKey()` method + // to get the signing key. + identity, err := didResolver.Lookup(ctx, *atID) + if err != nil { + return nil, fmt.Errorf("DID resolution failed for %s: %w", did, err) + } + if identity == nil || identity.DID.String() == "" { + return nil, fmt.Errorf("DID resolution returned empty identity for %s", did) + } + + publicKey, err := identity.PublicKey() + if err != nil { + return nil, fmt.Errorf("failed to get signing key for DID %s: %w", did, err) + } + + return publicKey, nil +} + +// ValidateAuth validates the authorization header and returns the requester's DID. +func (auth *Auth) validateAuth(ctx context.Context, r *http.Request) (string, error) { + authHeader := r.Header.Get(authorizationHeaderName) + if authHeader == "" { + return "", &AuthorizationError{Message: "Authorization header is missing"} + } + + if !strings.HasPrefix(authHeader, authorizationHeaderValuePrefix) { + return "", &AuthorizationError{Message: "Invalid authorization header format"} + } + + jwtString := strings.TrimPrefix(authHeader, authorizationHeaderValuePrefix) + jwtString = strings.TrimSpace(jwtString) + + claims := jwt.RegisteredClaims{} + + keyFunc := func(token *jwt.Token) (any, error) { + regClaims, ok := token.Claims.(*jwt.RegisteredClaims) + if !ok { + return nil, fmt.Errorf("invalid JWT claims type") + } + + issuerDID := regClaims.Issuer + if issuerDID == "" { + return nil, fmt.Errorf("JWT 'iss' claim is missing") + } + + publicKey, err := auth.getDIDSigningKey(ctx, issuerDID) + if err != nil { + return nil, fmt.Errorf("failed to get signing key for DID %s: %w", issuerDID, err) + } + + return publicKey, nil + } + + token, err := jwt.ParseWithClaims(jwtString, &claims, keyFunc) + if err != nil { + if errors.Is(err, jwt.ErrTokenSignatureInvalid) { + return "", &AuthorizationError{Message: "Invalid signature", Err: err} + } + if errors.Is(err, jwt.ErrTokenExpired) { + return "", &AuthorizationError{Message: "Token expired", Err: err} + } + if errors.Is(err, jwt.ErrTokenNotValidYet) { + return "", &AuthorizationError{Message: "Token not valid yet", Err: err} + } + if errors.Is(err, jwt.ErrTokenMalformed) { + return "", &AuthorizationError{Message: "Malformed token", Err: err} + } + return "", &AuthorizationError{Message: "Failed to parse or validate JWT", Err: err} + } + + if !token.Valid { + return "", &AuthorizationError{Message: "Token is invalid"} + } + + if slices.Contains(claims.Audience, auth.serviceDID.String()) { + return "", &AuthorizationError{Message: fmt.Sprintf("Invalid audience (expected %s)", auth.serviceDID)} + } + + // Return the issuer's DID. + return claims.Issuer, nil +} diff --git a/pkg/api/middleware/base.go b/pkg/api/middleware/base.go index 58e4e76..491d21b 100644 --- a/pkg/api/middleware/base.go +++ b/pkg/api/middleware/base.go @@ -1,3 +1,19 @@ package middleware +import ( + "net/http" + + "github.com/aykhans/bsky-feedgen/pkg/types" +) + type ContextKey string + +func GetValue[T any](r *http.Request, key ContextKey) (T, error) { + value, ok := r.Context().Value(key).(T) + if ok == false { + var zero T + return zero, types.ErrNotfound + } + + return value, nil +} diff --git a/pkg/api/middleware/es256k.go b/pkg/api/middleware/es256k.go new file mode 100644 index 0000000..994dc06 --- /dev/null +++ b/pkg/api/middleware/es256k.go @@ -0,0 +1,91 @@ +package middleware + +// copied from https://gist.github.com/bnewbold/bc9b97c9b281295da1fa47c03b0b3c69 + +import ( + "crypto" + "errors" + "fmt" + + atcrypto "github.com/bluesky-social/indigo/atproto/crypto" + "github.com/golang-jwt/jwt/v5" +) + +var ( + SigningMethodES256K *SigningMethodAtproto + SigningMethodES256 *SigningMethodAtproto +) + +type SigningMethodAtproto struct { + alg string + hash crypto.Hash + toOutSig toOutSig + sigLen int +} + +type toOutSig func(sig []byte) []byte + +func init() { + SigningMethodES256K = &SigningMethodAtproto{ + alg: "ES256K", + hash: crypto.SHA256, + toOutSig: toES256K, + sigLen: 64, + } + jwt.RegisterSigningMethod(SigningMethodES256K.Alg(), func() jwt.SigningMethod { + return SigningMethodES256K + }) + SigningMethodES256 = &SigningMethodAtproto{ + alg: "ES256", + hash: crypto.SHA256, + toOutSig: toES256, + sigLen: 64, + } + jwt.RegisterSigningMethod(SigningMethodES256.Alg(), func() jwt.SigningMethod { + return SigningMethodES256 + }) + fmt.Println("init Completed") +} + +// Errors returned on different problems. +var ( + ErrWrongKeyFormat = errors.New("wrong key type") + ErrBadSignature = errors.New("bad signature") + ErrVerification = errors.New("signature verification failed") + ErrFailedSigning = errors.New("failed generating signature") + ErrHashUnavailable = errors.New("hasher unavailable") +) + +func (sm *SigningMethodAtproto) Verify(signingString string, sig []byte, key any) error { + pub, ok := key.(atcrypto.PublicKey) + if !ok { + return ErrWrongKeyFormat + } + + if !sm.hash.Available() { + return ErrHashUnavailable + } + + if len(sig) != sm.sigLen { + return ErrBadSignature + } + + return pub.HashAndVerifyLenient([]byte(signingString), sig) +} + +func (sm *SigningMethodAtproto) Sign(signingString string, key any) ([]byte, error) { + // TODO: implement signatures + return nil, ErrFailedSigning +} + +func (sm *SigningMethodAtproto) Alg() string { + return sm.alg +} + +func toES256K(sig []byte) []byte { + return sig[:64] +} + +func toES256(sig []byte) []byte { + return sig[:64] +} diff --git a/pkg/consumer/base.go b/pkg/consumer/base.go index b503f94..e16b18e 100644 --- a/pkg/consumer/base.go +++ b/pkg/consumer/base.go @@ -1,5 +1,9 @@ package consumer +// This file contains code for consuming and processing the Bluesky firehose event stream. +// Most of this implementation is copied and inspired from the original source at: +// https://github.com/bluesky-social/indigo/blob/main/cmd/beemo/firehose_consumer.go + import ( "bytes" "context" diff --git a/pkg/types/errors.go b/pkg/types/errors.go index d06da35..341c79b 100644 --- a/pkg/types/errors.go +++ b/pkg/types/errors.go @@ -4,4 +4,5 @@ import "errors" var ( ErrInternal = errors.New("internal error") + ErrNotfound = errors.New("not found") )