Compare commits

12 Commits

13 changed files with 333 additions and 37 deletions

View File

@@ -10,7 +10,7 @@ import (
"syscall"
"time"
"github.com/aykhans/bsky-feedgen/pkg/generator"
feedgenAz "github.com/aykhans/bsky-feedgen/pkg/generator/az"
"github.com/aykhans/bsky-feedgen/pkg/types"
"github.com/aykhans/bsky-feedgen/pkg/config"
@@ -87,7 +87,7 @@ Flags:
os.Exit(1)
}
feedGeneratorAz := generator.NewFeedGeneratorAz(postCollection, feedAzCollection)
feedGeneratorAz := feedgenAz.NewGenerator(postCollection, feedAzCollection)
startCrons(ctx, feedGenAzConfig, feedGeneratorAz, feedAzCollection, cursorOption)
logger.Log.Info("Cron jobs started")
@@ -98,7 +98,7 @@ Flags:
func startCrons(
ctx context.Context,
feedGenAzConfig *config.FeedGenAzConfig,
feedGeneratorAz *generator.FeedGeneratorAz,
feedGeneratorAz *feedgenAz.Generator,
feedAzCollection *collections.FeedAzCollection,
cursorOption types.GeneratorCursor,
) {

1
go.mod
View File

@@ -21,6 +21,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/gocql/gocql v1.7.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.4.0 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect

2
go.sum
View File

@@ -46,6 +46,8 @@ github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/aykhans/bsky-feedgen/pkg/api/handler"
"github.com/aykhans/bsky-feedgen/pkg/api/middleware"
"github.com/aykhans/bsky-feedgen/pkg/config"
"github.com/aykhans/bsky-feedgen/pkg/feed"
"github.com/aykhans/bsky-feedgen/pkg/logger"
@@ -23,13 +24,15 @@ func Run(
}
feedHandler := handler.NewFeedHandler(feeds, apiConfig.FeedgenPublisherDID)
authMiddleware := middleware.NewAuth(apiConfig.ServiceDID)
mux := http.NewServeMux()
mux.HandleFunc("GET /.well-known/did.json", baseHandler.GetWellKnownDIDDoc)
mux.HandleFunc("GET /xrpc/app.bsky.feed.describeFeedGenerator", feedHandler.DescribeFeeds)
mux.HandleFunc(
mux.Handle(
"GET /xrpc/app.bsky.feed.getFeedSkeleton",
feedHandler.GetFeedSkeleton,
authMiddleware.JWTAuthMiddleware(http.HandlerFunc(feedHandler.GetFeedSkeleton)),
)
httpServer := &http.Server{

View File

@@ -50,7 +50,7 @@ func (handler *FeedHandler) DescribeFeeds(w http.ResponseWriter, r *http.Request
}
func (handler *FeedHandler) GetFeedSkeleton(w http.ResponseWriter, r *http.Request) {
userDID, _ := r.Context().Value(middleware.UserDIDKey).(string)
userDID, _ := middleware.GetValue[string](r, middleware.UserDIDKey)
feedQuery := r.URL.Query().Get("feed")
if feedQuery == "" {

View File

@@ -2,12 +2,73 @@ package middleware
import (
"context"
"crypto"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/bluesky-social/indigo/atproto/identity"
"github.com/bluesky-social/indigo/atproto/syntax"
"github.com/golang-jwt/jwt/v5"
"github.com/whyrusleeping/go-did"
)
const UserDIDKey ContextKey = "user_did"
func JWTAuthMiddleware(next http.Handler) http.Handler {
const (
authorizationHeaderName = "Authorization"
authorizationHeaderValuePrefix = "Bearer "
)
// Global (or dependency-injected) DID resolver with caching.
var didResolver *identity.CacheDirectory
func init() {
baseDir := identity.BaseDirectory{}
// Configure cache with appropriate TTLs.
// Capacity 0 means unlimited cache size.
// hitTTL: 24 hours for successful resolutions.
// errTTL: 5 minutes for failed resolutions.
// invalidHandleTTL: also 5 minutes for invalid handles.
resolver := identity.NewCacheDirectory(
&baseDir,
0, // Unlimited capacity
24*time.Hour, // hitTTL
5*time.Minute, // errTTL
5*time.Minute, // invalidHandleTTL
)
didResolver = &resolver
}
type AuthorizationError struct {
Message string
Err error
}
func (e *AuthorizationError) Error() string {
if e.Err != nil {
return fmt.Sprintf("%s: %v", e.Message, e.Err)
}
return e.Message
}
func (e *AuthorizationError) Unwrap() error {
return e.Err
}
type Auth struct {
serviceDID *did.DID
}
func NewAuth(serviceDID *did.DID) *Auth {
return &Auth{serviceDID}
}
func (auth *Auth) JWTAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
@@ -16,8 +77,100 @@ func JWTAuthMiddleware(next http.Handler) http.Handler {
return
}
// TODO: Add auth verification
ctx := context.WithValue(r.Context(), UserDIDKey, "")
userDID, _ := auth.validateAuth(r.Context(), r)
ctx := context.WithValue(r.Context(), UserDIDKey, userDID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// getDIDSigningKey resolves a DID and extracts its public signing key.
// It leverages indigo's identity package which handles multibase decoding and key parsing.
func (auth *Auth) getDIDSigningKey(ctx context.Context, did string) (crypto.PublicKey, error) {
atID, err := syntax.ParseAtIdentifier(did)
if err != nil {
return nil, fmt.Errorf("invalid DID syntax: %w", err)
}
// Use Lookup for bi-directional verification (handle -> DID -> handle).
// The `Lookup` method returns an `Identity` struct which contains `PublicKey()` method
// to get the signing key.
identity, err := didResolver.Lookup(ctx, *atID)
if err != nil {
return nil, fmt.Errorf("DID resolution failed for %s: %w", did, err)
}
if identity == nil || identity.DID.String() == "" {
return nil, fmt.Errorf("DID resolution returned empty identity for %s", did)
}
publicKey, err := identity.PublicKey()
if err != nil {
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", did, err)
}
return publicKey, nil
}
// ValidateAuth validates the authorization header and returns the requester's DID.
func (auth *Auth) validateAuth(ctx context.Context, r *http.Request) (string, error) {
authHeader := r.Header.Get(authorizationHeaderName)
if authHeader == "" {
return "", &AuthorizationError{Message: "Authorization header is missing"}
}
if !strings.HasPrefix(authHeader, authorizationHeaderValuePrefix) {
return "", &AuthorizationError{Message: "Invalid authorization header format"}
}
jwtString := strings.TrimPrefix(authHeader, authorizationHeaderValuePrefix)
jwtString = strings.TrimSpace(jwtString)
claims := jwt.RegisteredClaims{}
keyFunc := func(token *jwt.Token) (any, error) {
regClaims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok {
return nil, fmt.Errorf("invalid JWT claims type")
}
issuerDID := regClaims.Issuer
if issuerDID == "" {
return nil, fmt.Errorf("JWT 'iss' claim is missing")
}
publicKey, err := auth.getDIDSigningKey(ctx, issuerDID)
if err != nil {
return nil, fmt.Errorf("failed to get signing key for DID %s: %w", issuerDID, err)
}
return publicKey, nil
}
token, err := jwt.ParseWithClaims(jwtString, &claims, keyFunc)
if err != nil {
if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
return "", &AuthorizationError{Message: "Invalid signature", Err: err}
}
if errors.Is(err, jwt.ErrTokenExpired) {
return "", &AuthorizationError{Message: "Token expired", Err: err}
}
if errors.Is(err, jwt.ErrTokenNotValidYet) {
return "", &AuthorizationError{Message: "Token not valid yet", Err: err}
}
if errors.Is(err, jwt.ErrTokenMalformed) {
return "", &AuthorizationError{Message: "Malformed token", Err: err}
}
return "", &AuthorizationError{Message: "Failed to parse or validate JWT", Err: err}
}
if !token.Valid {
return "", &AuthorizationError{Message: "Token is invalid"}
}
if slices.Contains(claims.Audience, auth.serviceDID.String()) {
return "", &AuthorizationError{Message: fmt.Sprintf("Invalid audience (expected %s)", auth.serviceDID)}
}
// Return the issuer's DID.
return claims.Issuer, nil
}

View File

@@ -1,3 +1,19 @@
package middleware
import (
"net/http"
"github.com/aykhans/bsky-feedgen/pkg/types"
)
type ContextKey string
func GetValue[T any](r *http.Request, key ContextKey) (T, error) {
value, ok := r.Context().Value(key).(T)
if ok == false {
var zero T
return zero, types.ErrNotfound
}
return value, nil
}

View File

@@ -0,0 +1,91 @@
package middleware
// copied from https://gist.github.com/bnewbold/bc9b97c9b281295da1fa47c03b0b3c69
import (
"crypto"
"errors"
"fmt"
atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
"github.com/golang-jwt/jwt/v5"
)
var (
SigningMethodES256K *SigningMethodAtproto
SigningMethodES256 *SigningMethodAtproto
)
type SigningMethodAtproto struct {
alg string
hash crypto.Hash
toOutSig toOutSig
sigLen int
}
type toOutSig func(sig []byte) []byte
func init() {
SigningMethodES256K = &SigningMethodAtproto{
alg: "ES256K",
hash: crypto.SHA256,
toOutSig: toES256K,
sigLen: 64,
}
jwt.RegisterSigningMethod(SigningMethodES256K.Alg(), func() jwt.SigningMethod {
return SigningMethodES256K
})
SigningMethodES256 = &SigningMethodAtproto{
alg: "ES256",
hash: crypto.SHA256,
toOutSig: toES256,
sigLen: 64,
}
jwt.RegisterSigningMethod(SigningMethodES256.Alg(), func() jwt.SigningMethod {
return SigningMethodES256
})
fmt.Println("init Completed")
}
// Errors returned on different problems.
var (
ErrWrongKeyFormat = errors.New("wrong key type")
ErrBadSignature = errors.New("bad signature")
ErrVerification = errors.New("signature verification failed")
ErrFailedSigning = errors.New("failed generating signature")
ErrHashUnavailable = errors.New("hasher unavailable")
)
func (sm *SigningMethodAtproto) Verify(signingString string, sig []byte, key any) error {
pub, ok := key.(atcrypto.PublicKey)
if !ok {
return ErrWrongKeyFormat
}
if !sm.hash.Available() {
return ErrHashUnavailable
}
if len(sig) != sm.sigLen {
return ErrBadSignature
}
return pub.HashAndVerifyLenient([]byte(signingString), sig)
}
func (sm *SigningMethodAtproto) Sign(signingString string, key any) ([]byte, error) {
// TODO: implement signatures
return nil, ErrFailedSigning
}
func (sm *SigningMethodAtproto) Alg() string {
return sm.alg
}
func toES256K(sig []byte) []byte {
return sig[:64]
}
func toES256(sig []byte) []byte {
return sig[:64]
}

View File

@@ -1,5 +1,9 @@
package consumer
// This file contains code for consuming and processing the Bluesky firehose event stream.
// Most of this implementation is copied and inspired from the original source at:
// https://github.com/bluesky-social/indigo/blob/main/cmd/beemo/firehose_consumer.go
import (
"bytes"
"context"
@@ -264,7 +268,7 @@ func ConsumeAndSaveToMongoDB(
case <-ticker.C:
if len(postBatch) > 0 {
consumerLastFlushingTime = time.Now()
logger.Log.Info("flushing post batch", "count", len(postBatch))
// logger.Log.Info("flushing post batch", "count", len(postBatch))
err := postCollection.Insert(ctx, true, postBatch...)
if err != nil {
return fmt.Errorf("mongodb post insert error: %v", err)
@@ -272,7 +276,7 @@ func ConsumeAndSaveToMongoDB(
postBatch = []*collections.Post{} // Clear batch after insert
} else {
// If we haven't seen any data for 25 seconds, cancel the consumer connection
if consumerLastFlushingTime.Add(time.Second*25).Before(time.Now()) {
if consumerLastFlushingTime.Add(time.Second * 25).Before(time.Now()) {
cancel()
}
}

View File

@@ -1,4 +1,4 @@
package generator
package az
import (
"context"
@@ -13,39 +13,24 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)
var azInvalidUser []string = []string{
"did:plc:5zww7zorx2ajw7hqrhuix3ba",
"did:plc:c4vhz47h566t2ntgd7gtawen",
}
var azValidUsers []string = []string{
"did:plc:jbt4qi6psd7rutwzedtecsq7",
"did:plc:yzgdpxsklrmfgqmjghdvw3ti",
"did:plc:g7ebgiai577ln3avsi2pt3sn",
"did:plc:phtq2rhgbwipyx5ie3apw44j",
"did:plc:jfdvklrs5n5qv7f25v6swc5h",
"did:plc:u5ez5w6qslh6advti4wyddba",
"did:plc:cs2cbzojm6hmx5lfxiuft3mq",
}
type FeedGeneratorAz struct {
type Generator struct {
postCollection *collections.PostCollection
feedAzCollection *collections.FeedAzCollection
textRegex *regexp.Regexp
}
func NewFeedGeneratorAz(
func NewGenerator(
postCollection *collections.PostCollection,
feedAzCollection *collections.FeedAzCollection,
) *FeedGeneratorAz {
return &FeedGeneratorAz{
) *Generator {
return &Generator{
postCollection: postCollection,
feedAzCollection: feedAzCollection,
textRegex: regexp.MustCompile("(?i)(azerbaijan|azərbaycan|aзербайджан|azerbaycan)"),
}
}
func (generator *FeedGeneratorAz) Start(ctx context.Context, cursorOption types.GeneratorCursor, batchSize int) error {
func (generator *Generator) Start(ctx context.Context, cursorOption types.GeneratorCursor, batchSize int) error {
var mongoCursor *mongo.Cursor
switch cursorOption {
case types.GeneratorCursorLastGenerated:
@@ -124,17 +109,16 @@ func (generator *FeedGeneratorAz) Start(ctx context.Context, cursorOption types.
return nil
}
func (generator *FeedGeneratorAz) IsValid(post *collections.Post) bool {
func (generator *Generator) IsValid(post *collections.Post) bool {
if post.Reply != nil && post.Reply.RootURI != post.Reply.ParentURI {
return false
}
if slices.Contains(azInvalidUser, post.DID) {
return false
if isValidUser := users.IsValid(post.DID); isValidUser != nil {
return *isValidUser
}
if slices.Contains(azValidUsers, post.DID) || // Posts from always-valid users
(slices.Contains(post.Langs, "az") && len(post.Langs) < 3) || // Posts in Azerbaijani language with fewer than 3 languages
if (slices.Contains(post.Langs, "az") && len(post.Langs) < 3) || // Posts in Azerbaijani language with fewer than 3 languages
generator.textRegex.MatchString(post.Text) { // Posts containing Azerbaijan-related keywords
return true
}

27
pkg/generator/az/lists.go Normal file
View File

@@ -0,0 +1,27 @@
package az
import "github.com/aykhans/bsky-feedgen/pkg/generator"
var users = generator.Users{
// Invalid
"did:plc:5zww7zorx2ajw7hqrhuix3ba": false,
"did:plc:c4vhz47h566t2ntgd7gtawen": false,
"did:plc:lc7j7xdq67gn7vc6vzmydfqk": false,
"did:plc:msian4dqa2rqalf3biilnf3m": false,
"did:plc:gtosalycg7snvodjhsze35jm": false,
"did:plc:i53e6y3liw2oaw4s6e6odw5m": false,
"did:plc:pvdqvmpkeermkhy7fezam473": false,
"did:plc:5vwjnzaibnwscbbcvkzhy57v": false,
"did:plc:6mfp3coadoobuvlg6w2avw6x": false,
"did:plc:lm2uhaoqoe6yo76oeihndfyi": false,
// Valid
"did:plc:jbt4qi6psd7rutwzedtecsq7": true,
"did:plc:yzgdpxsklrmfgqmjghdvw3ti": true,
"did:plc:g7ebgiai577ln3avsi2pt3sn": true,
"did:plc:phtq2rhgbwipyx5ie3apw44j": true,
"did:plc:jfdvklrs5n5qv7f25v6swc5h": true,
"did:plc:u5ez5w6qslh6advti4wyddba": true,
"did:plc:cs2cbzojm6hmx5lfxiuft3mq": true,
"did:plc:x7alwnnjygt2aqcwblhazko7": true,
}

14
pkg/generator/base.go Normal file
View File

@@ -0,0 +1,14 @@
package generator
import "github.com/aykhans/bsky-feedgen/pkg/utils"
type Users map[string]bool
func (u Users) IsValid(did string) *bool {
isValid, ok := u[did]
if ok == false {
return nil
}
return utils.ToPtr(isValid)
}

View File

@@ -4,4 +4,5 @@ import "errors"
var (
ErrInternal = errors.New("internal error")
ErrNotfound = errors.New("not found")
)