mirror of
https://github.com/aykhans/bsky-feedgen.git
synced 2025-05-30 02:30:03 +00:00
Add auth middleware
This commit is contained in:
parent
b6eaaf7331
commit
1eecbafd07
1
go.mod
1
go.mod
@ -21,6 +21,7 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/gocql/gocql v1.7.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/uuid v1.4.0 // indirect
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||
|
2
go.sum
2
go.sum
@ -46,6 +46,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/handler"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/middleware"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/config"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/feed"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
@ -23,13 +24,15 @@ func Run(
|
||||
}
|
||||
feedHandler := handler.NewFeedHandler(feeds, apiConfig.FeedgenPublisherDID)
|
||||
|
||||
authMiddleware := middleware.NewAuth(apiConfig.ServiceDID)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("GET /.well-known/did.json", baseHandler.GetWellKnownDIDDoc)
|
||||
mux.HandleFunc("GET /xrpc/app.bsky.feed.describeFeedGenerator", feedHandler.DescribeFeeds)
|
||||
mux.HandleFunc(
|
||||
mux.Handle(
|
||||
"GET /xrpc/app.bsky.feed.getFeedSkeleton",
|
||||
feedHandler.GetFeedSkeleton,
|
||||
authMiddleware.JWTAuthMiddleware(http.HandlerFunc(feedHandler.GetFeedSkeleton)),
|
||||
)
|
||||
|
||||
httpServer := &http.Server{
|
||||
|
@ -50,7 +50,7 @@ func (handler *FeedHandler) DescribeFeeds(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
func (handler *FeedHandler) GetFeedSkeleton(w http.ResponseWriter, r *http.Request) {
|
||||
userDID, _ := r.Context().Value(middleware.UserDIDKey).(string)
|
||||
userDID, _ := middleware.GetValue[string](r, middleware.UserDIDKey)
|
||||
|
||||
feedQuery := r.URL.Query().Get("feed")
|
||||
if feedQuery == "" {
|
||||
|
@ -6,68 +6,28 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
"github.com/bluesky-social/indigo/atproto/identity"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
const UserDIDKey ContextKey = "user_did"
|
||||
|
||||
func JWTAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// No auth header, continue without authentication
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
did, err := ValidateAuth(r.Context(), r)
|
||||
if err != nil {
|
||||
logger.Log.Error(err.Error())
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), UserDIDKey, did)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////
|
||||
|
||||
const (
|
||||
authorizationHeaderName = "Authorization"
|
||||
authorizationHeaderValuePrefix = "Bearer "
|
||||
)
|
||||
|
||||
// AuthorizationError is a custom error type for authorization failures.
|
||||
type AuthorizationError struct {
|
||||
Message string
|
||||
Err error // Wrapped error
|
||||
}
|
||||
|
||||
// Error returns the formatted error message.
|
||||
func (e *AuthorizationError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s: %v", e.Message, e.Err)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped error, allowing for errors.Is and errors.As.
|
||||
func (e *AuthorizationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// Global (or dependency-injected) DID resolver with caching.
|
||||
var didResolver *identity.CacheDirectory
|
||||
|
||||
// init function is called once when the package is initialized.
|
||||
func init() {
|
||||
// Initialize the base directory for actual DID resolution.
|
||||
baseDir := identity.BaseDirectory{} // Zero value is usable.
|
||||
baseDir := identity.BaseDirectory{}
|
||||
|
||||
// Configure cache with appropriate TTLs.
|
||||
// Capacity 0 means unlimited cache size.
|
||||
@ -84,9 +44,49 @@ func init() {
|
||||
didResolver = &resolver
|
||||
}
|
||||
|
||||
type AuthorizationError struct {
|
||||
Message string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AuthorizationError) Error() string {
|
||||
if e.Err != nil {
|
||||
return fmt.Sprintf("%s: %v", e.Message, e.Err)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func (e *AuthorizationError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
serviceDID *did.DID
|
||||
}
|
||||
|
||||
func NewAuth(serviceDID *did.DID) *Auth {
|
||||
return &Auth{serviceDID}
|
||||
}
|
||||
|
||||
func (auth *Auth) JWTAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// No auth header, continue without authentication
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userDID, _ := auth.validateAuth(r.Context(), r)
|
||||
ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// getDIDSigningKey resolves a DID and extracts its public signing key.
|
||||
// It leverages indigo's identity package which handles multibase decoding and key parsing.
|
||||
func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) {
|
||||
func (auth *Auth) getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) {
|
||||
atID, err := syntax.ParseAtIdentifier(did)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid DID syntax: %w", err)
|
||||
@ -103,7 +103,6 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error)
|
||||
return nil, fmt.Errorf("DID resolution returned empty identity for %s", did)
|
||||
}
|
||||
|
||||
// Get the public key using the PublicKey() method from the Identity struct.
|
||||
publicKey, err := identity.PublicKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", did, err)
|
||||
@ -113,8 +112,7 @@ func getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error)
|
||||
}
|
||||
|
||||
// ValidateAuth validates the authorization header and returns the requester's DID.
|
||||
// It requires a context.Context for DID resolution to allow for timeouts and cancellation.
|
||||
func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
func (auth *Auth) validateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get(authorizationHeaderName)
|
||||
if authHeader == "" {
|
||||
return "", &AuthorizationError{Message: "Authorization header is missing"}
|
||||
@ -127,12 +125,9 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
jwtString := strings.TrimPrefix(authHeader, authorizationHeaderValuePrefix)
|
||||
jwtString = strings.TrimSpace(jwtString)
|
||||
|
||||
// Define a custom claims struct if needed, otherwise use jwt.RegisteredClaims.
|
||||
claims := jwt.RegisteredClaims{}
|
||||
|
||||
// Keyfunc callback to dynamically fetch the public key based on the JWT's issuer (iss).
|
||||
keyFunc := func(token *jwt.Token) (interface{}, error) {
|
||||
// Assert claims to RegisteredClaims to get the issuer (iss).
|
||||
keyFunc := func(token *jwt.Token) (any, error) {
|
||||
regClaims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid JWT claims type")
|
||||
@ -143,8 +138,7 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
return nil, fmt.Errorf("JWT 'iss' claim is missing")
|
||||
}
|
||||
|
||||
// Resolve the DID and get the public signing key.
|
||||
publicKey, err := getDIDSigningKey(ctx, issuerDID)
|
||||
publicKey, err := auth.getDIDSigningKey(ctx, issuerDID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", issuerDID, err)
|
||||
}
|
||||
@ -166,7 +160,6 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
if errors.Is(err, jwt.ErrTokenMalformed) {
|
||||
return "", &AuthorizationError{Message: "Malformed token", Err: err}
|
||||
}
|
||||
// Catch other generic parsing or validation errors.
|
||||
return "", &AuthorizationError{Message: "Failed to parse or validate JWT", Err: err}
|
||||
}
|
||||
|
||||
@ -174,6 +167,10 @@ func ValidateAuth(ctx context.Context, r *http.Request) (string, error) {
|
||||
return "", &AuthorizationError{Message: "Token is invalid"}
|
||||
}
|
||||
|
||||
if slices.Contains(claims.Audience, auth.serviceDID.String()) {
|
||||
return "", &AuthorizationError{Message: fmt.Sprintf("Invalid audience (expected %s)", auth.serviceDID)}
|
||||
}
|
||||
|
||||
// Return the issuer's DID.
|
||||
return claims.Issuer, nil
|
||||
}
|
||||
|
@ -1,3 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
|
||||
func GetValue[T any](r *http.Request, key ContextKey) (T, error) {
|
||||
value, ok := r.Context().Value(key).(T)
|
||||
if ok == false {
|
||||
var zero T
|
||||
return zero, types.ErrNotfound
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
91
pkg/api/middleware/es256k.go
Normal file
91
pkg/api/middleware/es256k.go
Normal file
@ -0,0 +1,91 @@
|
||||
package middleware
|
||||
|
||||
// copied from https://gist.github.com/bnewbold/bc9b97c9b281295da1fa47c03b0b3c69
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
SigningMethodES256K *SigningMethodAtproto
|
||||
SigningMethodES256 *SigningMethodAtproto
|
||||
)
|
||||
|
||||
type SigningMethodAtproto struct {
|
||||
alg string
|
||||
hash crypto.Hash
|
||||
toOutSig toOutSig
|
||||
sigLen int
|
||||
}
|
||||
|
||||
type toOutSig func(sig []byte) []byte
|
||||
|
||||
func init() {
|
||||
SigningMethodES256K = &SigningMethodAtproto{
|
||||
alg: "ES256K",
|
||||
hash: crypto.SHA256,
|
||||
toOutSig: toES256K,
|
||||
sigLen: 64,
|
||||
}
|
||||
jwt.RegisterSigningMethod(SigningMethodES256K.Alg(), func() jwt.SigningMethod {
|
||||
return SigningMethodES256K
|
||||
})
|
||||
SigningMethodES256 = &SigningMethodAtproto{
|
||||
alg: "ES256",
|
||||
hash: crypto.SHA256,
|
||||
toOutSig: toES256,
|
||||
sigLen: 64,
|
||||
}
|
||||
jwt.RegisterSigningMethod(SigningMethodES256.Alg(), func() jwt.SigningMethod {
|
||||
return SigningMethodES256
|
||||
})
|
||||
fmt.Println("init Completed")
|
||||
}
|
||||
|
||||
// Errors returned on different problems.
|
||||
var (
|
||||
ErrWrongKeyFormat = errors.New("wrong key type")
|
||||
ErrBadSignature = errors.New("bad signature")
|
||||
ErrVerification = errors.New("signature verification failed")
|
||||
ErrFailedSigning = errors.New("failed generating signature")
|
||||
ErrHashUnavailable = errors.New("hasher unavailable")
|
||||
)
|
||||
|
||||
func (sm *SigningMethodAtproto) Verify(signingString string, sig []byte, key any) error {
|
||||
pub, ok := key.(atcrypto.PublicKey)
|
||||
if !ok {
|
||||
return ErrWrongKeyFormat
|
||||
}
|
||||
|
||||
if !sm.hash.Available() {
|
||||
return ErrHashUnavailable
|
||||
}
|
||||
|
||||
if len(sig) != sm.sigLen {
|
||||
return ErrBadSignature
|
||||
}
|
||||
|
||||
return pub.HashAndVerifyLenient([]byte(signingString), sig)
|
||||
}
|
||||
|
||||
func (sm *SigningMethodAtproto) Sign(signingString string, key any) ([]byte, error) {
|
||||
// TODO: implement signatures
|
||||
return nil, ErrFailedSigning
|
||||
}
|
||||
|
||||
func (sm *SigningMethodAtproto) Alg() string {
|
||||
return sm.alg
|
||||
}
|
||||
|
||||
func toES256K(sig []byte) []byte {
|
||||
return sig[:64]
|
||||
}
|
||||
|
||||
func toES256(sig []byte) []byte {
|
||||
return sig[:64]
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
package consumer
|
||||
|
||||
// This file contains code for consuming and processing the Bluesky firehose event stream.
|
||||
// Most of this implementation is copied and inspired from the original source at:
|
||||
// https://github.com/bluesky-social/indigo/blob/main/cmd/beemo/firehose_consumer.go
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
@ -4,4 +4,5 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrInternal = errors.New("internal error")
|
||||
ErrNotfound = errors.New("not found")
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user