mirror of
https://github.com/aykhans/bsky-feedgen.git
synced 2025-07-17 05:14:01 +00:00
🦋
This commit is contained in:
62
pkg/api/base.go
Normal file
62
pkg/api/base.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/handler"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/config"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/feed"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
)
|
||||
|
||||
func Run(
|
||||
ctx context.Context,
|
||||
apiConfig *config.APIConfig,
|
||||
feeds []feed.Feed,
|
||||
) error {
|
||||
baseHandler, err := handler.NewBaseHandler(apiConfig.FeedgenHostname, apiConfig.ServiceDID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
feedHandler := handler.NewFeedHandler(feeds, apiConfig.FeedgenPublisherDID)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("GET /.well-known/did.json", baseHandler.GetWellKnownDIDDoc)
|
||||
mux.HandleFunc("GET /xrpc/app.bsky.feed.describeFeedGenerator", feedHandler.DescribeFeeds)
|
||||
mux.HandleFunc(
|
||||
"GET /xrpc/app.bsky.feed.getFeedSkeleton",
|
||||
feedHandler.GetFeedSkeleton,
|
||||
)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", apiConfig.APIPort),
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
listenerErrChan := make(chan error)
|
||||
|
||||
logger.Log.Info(fmt.Sprintf("Starting server on port %d", apiConfig.APIPort))
|
||||
go func() {
|
||||
listenerErrChan <- httpServer.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-listenerErrChan:
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("error while serving http: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer shutdownCancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("error while shutting down http server: %v", err)
|
||||
}
|
||||
}
|
||||
logger.Log.Info(fmt.Sprintf("Server on port %d stopped", apiConfig.APIPort))
|
||||
|
||||
return nil
|
||||
}
|
49
pkg/api/handler/base.go
Normal file
49
pkg/api/handler/base.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/response"
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
type BaseHandler struct {
|
||||
WellKnownDIDDoc did.Document
|
||||
}
|
||||
|
||||
func NewBaseHandler(serviceEndpoint *url.URL, serviceDID *did.DID) (*BaseHandler, error) {
|
||||
serviceID, err := did.ParseDID("#bsky_fg")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("service ID parse error: %v", err)
|
||||
}
|
||||
|
||||
return &BaseHandler{
|
||||
WellKnownDIDDoc: did.Document{
|
||||
Context: []string{did.CtxDIDv1},
|
||||
ID: *serviceDID,
|
||||
Service: []did.Service{
|
||||
{
|
||||
ID: serviceID,
|
||||
Type: "BskyFeedGenerator",
|
||||
ServiceEndpoint: serviceEndpoint.String(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type WellKnownDidResponse struct {
|
||||
Context []string `json:"@context"`
|
||||
ID string `json:"id"`
|
||||
Service []did.Service `json:"service"`
|
||||
}
|
||||
|
||||
func (handler *BaseHandler) GetWellKnownDIDDoc(w http.ResponseWriter, r *http.Request) {
|
||||
response.JSON(w, 200, WellKnownDidResponse{
|
||||
Context: handler.WellKnownDIDDoc.Context,
|
||||
ID: handler.WellKnownDIDDoc.ID.String(),
|
||||
Service: handler.WellKnownDIDDoc.Service,
|
||||
})
|
||||
}
|
101
pkg/api/handler/feed.go
Normal file
101
pkg/api/handler/feed.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/middleware"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/api/response"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/feed"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
"github.com/bluesky-social/indigo/api/bsky"
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
type FeedHandler struct {
|
||||
feedsOutput []*bsky.FeedDescribeFeedGenerator_Feed
|
||||
feedsMap map[string]feed.Feed
|
||||
publisherDID *did.DID
|
||||
}
|
||||
|
||||
func NewFeedHandler(feeds []feed.Feed, publisherDID *did.DID) *FeedHandler {
|
||||
ctx := context.Background()
|
||||
|
||||
feedsMap := make(map[string]feed.Feed)
|
||||
for _, feed := range feeds {
|
||||
feedsMap[feed.GetName(ctx)] = feed
|
||||
}
|
||||
|
||||
feedsOutput := make([]*bsky.FeedDescribeFeedGenerator_Feed, len(feeds))
|
||||
for i, f := range feeds {
|
||||
feedsOutput[i] = utils.ToPtr(f.Describe(ctx))
|
||||
}
|
||||
|
||||
return &FeedHandler{
|
||||
feedsOutput: feedsOutput,
|
||||
feedsMap: feedsMap,
|
||||
publisherDID: publisherDID,
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *FeedHandler) DescribeFeeds(w http.ResponseWriter, r *http.Request) {
|
||||
response.JSON(w, 200, bsky.FeedDescribeFeedGenerator_Output{
|
||||
Did: handler.publisherDID.String(),
|
||||
Feeds: handler.feedsOutput,
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *FeedHandler) GetFeedSkeleton(w http.ResponseWriter, r *http.Request) {
|
||||
userDID, _ := r.Context().Value(middleware.UserDIDKey).(string)
|
||||
|
||||
feedQuery := r.URL.Query().Get("feed")
|
||||
if feedQuery == "" {
|
||||
response.JSON(w, 400, response.M{"error": "feed query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
feedNameStartingIndex := strings.LastIndex(feedQuery, "/")
|
||||
if feedNameStartingIndex == -1 {
|
||||
response.JSON(w, 400, response.M{"error": "feed query parameter is invalid"})
|
||||
}
|
||||
|
||||
feedName := feedQuery[feedNameStartingIndex+1:]
|
||||
feed := handler.feedsMap[feedName]
|
||||
if feed == nil {
|
||||
response.JSON(w, 400, response.M{"error": "feed not found"})
|
||||
return
|
||||
}
|
||||
|
||||
limitQuery := r.URL.Query().Get("limit")
|
||||
var limit int64 = 50
|
||||
if limitQuery != "" {
|
||||
parsedLimit, err := strconv.ParseInt(limitQuery, 10, 64)
|
||||
if err == nil && parsedLimit >= 1 && parsedLimit <= 100 {
|
||||
limit = parsedLimit
|
||||
}
|
||||
}
|
||||
|
||||
cursor := r.URL.Query().Get("cursor")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
feedItems, newCursor, err := feed.GetPage(ctx, userDID, limit, cursor)
|
||||
if err != nil {
|
||||
if err == types.ErrInternal {
|
||||
response.JSON500(w)
|
||||
return
|
||||
}
|
||||
response.JSON(w, 400, response.M{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
response.JSON(w, 200, bsky.FeedGetFeedSkeleton_Output{
|
||||
Feed: feedItems,
|
||||
Cursor: newCursor,
|
||||
})
|
||||
}
|
23
pkg/api/middleware/auth.go
Normal file
23
pkg/api/middleware/auth.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const UserDIDKey ContextKey = "user_did"
|
||||
|
||||
func JWTAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
// No auth header, continue without authentication
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Add auth verification
|
||||
ctx := context.WithValue(r.Context(), UserDIDKey, "")
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
3
pkg/api/middleware/base.go
Normal file
3
pkg/api/middleware/base.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package middleware
|
||||
|
||||
type ContextKey string
|
27
pkg/api/response/json.go
Normal file
27
pkg/api/response/json.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
)
|
||||
|
||||
type M map[string]any
|
||||
|
||||
func JSON(w http.ResponseWriter, statusCode int, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
logger.Log.Error("Failed to encode JSON response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func JSON500(w http.ResponseWriter) {
|
||||
JSON(w, 500, M{"error": "Internal server error"})
|
||||
}
|
||||
|
||||
func JSON404(w http.ResponseWriter) {
|
||||
JSON(w, 404, M{"error": "Not found"})
|
||||
}
|
24
pkg/api/response/text.go
Normal file
24
pkg/api/response/text.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
)
|
||||
|
||||
func Text(w http.ResponseWriter, statusCode int, content []byte) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(statusCode)
|
||||
if _, err := w.Write(content); err != nil {
|
||||
logger.Log.Error("Failed to write text response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func Text404(w http.ResponseWriter) {
|
||||
Text(w, 404, []byte("Not found"))
|
||||
}
|
||||
|
||||
func Text500(w http.ResponseWriter) {
|
||||
Text(w, 500, []byte("Internal server error"))
|
||||
}
|
70
pkg/config/api.go
Normal file
70
pkg/config/api.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
type APIConfig struct {
|
||||
FeedgenHostname *url.URL
|
||||
ServiceDID *did.DID
|
||||
FeedgenPublisherDID *did.DID
|
||||
APIPort uint16
|
||||
}
|
||||
|
||||
func NewAPIConfig() (*APIConfig, types.ErrMap) {
|
||||
errs := make(types.ErrMap)
|
||||
|
||||
defaultHostname, _ := url.Parse("http://localhost")
|
||||
feedgenHostname, err := utils.GetEnvOr("FEEDGEN_HOSTNAME", defaultHostname)
|
||||
if err != nil {
|
||||
errs["FEEDGEN_HOSTNAME"] = err
|
||||
} else {
|
||||
if !slices.Contains([]string{"", "http", "https"}, feedgenHostname.Scheme) {
|
||||
errs["FEEDGEN_HOSTNAME"] = fmt.Errorf(
|
||||
"invalid schema '%s' for FEEDGEN_HOSTNAME. Accepted schemas are: '', 'http', 'https'",
|
||||
feedgenHostname.Scheme,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
serviceDID, err := did.ParseDID("did:web:" + feedgenHostname.Hostname())
|
||||
if err != nil {
|
||||
errs["SERVICE_DID"] = fmt.Errorf("failed to parse service DID: %w", err)
|
||||
}
|
||||
|
||||
defaultDID, _ := did.ParseDID("did:plc:development")
|
||||
feedgenPublisherDID, err := utils.GetEnvOr("FEEDGEN_PUBLISHER_DID", &defaultDID)
|
||||
if err != nil {
|
||||
errs["FEEDGEN_PUBLISHER_DID"] = err
|
||||
}
|
||||
|
||||
apiPort, err := utils.GetEnv[uint16]("API_PORT")
|
||||
if err != nil {
|
||||
errs["API_PORT"] = err
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
if feedgenHostname.Scheme == "" {
|
||||
if feedgenHostname.Host == "" {
|
||||
feedgenHostname, _ = url.Parse("https://" + feedgenHostname.String())
|
||||
} else {
|
||||
feedgenHostname.Scheme = "https://"
|
||||
}
|
||||
}
|
||||
|
||||
return &APIConfig{
|
||||
FeedgenHostname: feedgenHostname,
|
||||
ServiceDID: &serviceDID,
|
||||
FeedgenPublisherDID: feedgenPublisherDID,
|
||||
APIPort: apiPort,
|
||||
}, nil
|
||||
}
|
40
pkg/config/consumer.go
Normal file
40
pkg/config/consumer.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
)
|
||||
|
||||
type ConsumerConfig struct {
|
||||
PostMaxDate time.Duration
|
||||
PostCollectionCutoffCronDelay time.Duration
|
||||
PostCollectionCutoffCronMaxDocument int64
|
||||
}
|
||||
|
||||
func NewConsumerConfig() (*ConsumerConfig, types.ErrMap) {
|
||||
errs := make(types.ErrMap)
|
||||
maxDate, err := utils.GetEnv[time.Duration]("POST_MAX_DATE")
|
||||
if err != nil {
|
||||
errs["POST_MAX_DATE"] = err
|
||||
}
|
||||
cronDelay, err := utils.GetEnv[time.Duration]("POST_COLLECTION_CUTOFF_CRON_DELAY")
|
||||
if err != nil {
|
||||
errs["POST_COLLECTION_CUTOFF_CRON_DELAY"] = err
|
||||
}
|
||||
cronMaxDocument, err := utils.GetEnv[int64]("POST_COLLECTION_CUTOFF_CRON_MAX_DOCUMENT")
|
||||
if err != nil {
|
||||
errs["POST_COLLECTION_CUTOFF_CRON_MAX_DOCUMENT"] = err
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return &ConsumerConfig{
|
||||
PostMaxDate: maxDate,
|
||||
PostCollectionCutoffCronDelay: cronDelay,
|
||||
PostCollectionCutoffCronMaxDocument: cronMaxDocument,
|
||||
}, nil
|
||||
}
|
40
pkg/config/feedgen.go
Normal file
40
pkg/config/feedgen.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
)
|
||||
|
||||
type FeedGenAzConfig struct {
|
||||
CollectionMaxDocument int64
|
||||
GeneratorCronDelay time.Duration
|
||||
CutoffCronDelay time.Duration
|
||||
}
|
||||
|
||||
func NewFeedGenAzConfig() (*FeedGenAzConfig, types.ErrMap) {
|
||||
errs := make(types.ErrMap)
|
||||
maxDocument, err := utils.GetEnv[int64]("FEED_AZ_COLLECTION_CUTOFF_CRON_MAX_DOCUMENT")
|
||||
if err != nil {
|
||||
errs["FEED_AZ_COLLECTION_CUTOFF_CRON_MAX_DOCUMENT"] = err
|
||||
}
|
||||
generatorCronDelay, err := utils.GetEnv[time.Duration]("FEED_AZ_GENERATER_CRON_DELAY")
|
||||
if err != nil {
|
||||
errs["FEED_AZ_GENERATER_CRON_DELAY"] = err
|
||||
}
|
||||
cutoffCronDelay, err := utils.GetEnv[time.Duration]("FEED_AZ_COLLECTION_CUTOFF_CRON_DELAY")
|
||||
if err != nil {
|
||||
errs["FEED_AZ_COLLECTION_CUTOFF_CRON_DELAY"] = err
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return &FeedGenAzConfig{
|
||||
CollectionMaxDocument: maxDocument,
|
||||
GeneratorCronDelay: generatorCronDelay,
|
||||
CutoffCronDelay: cutoffCronDelay,
|
||||
}, nil
|
||||
}
|
46
pkg/config/mongodb.go
Normal file
46
pkg/config/mongodb.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
)
|
||||
|
||||
const MongoDBBaseDB = "main"
|
||||
|
||||
type MongoDBConfig struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
func NewMongoDBConfig() (*MongoDBConfig, types.ErrMap) {
|
||||
errs := make(types.ErrMap)
|
||||
host, err := utils.GetEnv[string]("MONGODB_HOST")
|
||||
if err != nil {
|
||||
errs["host"] = err
|
||||
}
|
||||
port, err := utils.GetEnv[uint16]("MONGODB_PORT")
|
||||
if err != nil {
|
||||
errs["port"] = err
|
||||
}
|
||||
username, err := utils.GetEnvOr("MONGODB_USERNAME", "")
|
||||
if err != nil {
|
||||
errs["username"] = err
|
||||
}
|
||||
password, err := utils.GetEnvOr("MONGODB_PASSWORD", "")
|
||||
if err != nil {
|
||||
errs["password"] = err
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return &MongoDBConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
273
pkg/consumer/base.go
Normal file
273
pkg/consumer/base.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/storage/mongodb/collections"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
|
||||
comatproto "github.com/bluesky-social/indigo/api/atproto"
|
||||
"github.com/bluesky-social/indigo/api/bsky"
|
||||
"github.com/bluesky-social/indigo/atproto/syntax"
|
||||
"github.com/bluesky-social/indigo/events/schedulers/parallel"
|
||||
lexutil "github.com/bluesky-social/indigo/lex/util"
|
||||
|
||||
"github.com/bluesky-social/indigo/events"
|
||||
"github.com/bluesky-social/indigo/repo"
|
||||
"github.com/bluesky-social/indigo/repomgr"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type CallbackData struct {
|
||||
Sequence int64
|
||||
DID syntax.DID
|
||||
RecordKey syntax.RecordKey
|
||||
Post bsky.FeedPost
|
||||
}
|
||||
|
||||
type CallbackFunc func(int64, syntax.DID, syntax.RecordKey, bsky.FeedPost)
|
||||
|
||||
func RunFirehoseConsumer(
|
||||
ctx context.Context,
|
||||
relayHost string,
|
||||
callbackFunc CallbackFunc,
|
||||
cursor *int64,
|
||||
) error {
|
||||
dialer := websocket.DefaultDialer
|
||||
u, err := url.Parse(relayHost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid relayHost URI: %w", err)
|
||||
}
|
||||
|
||||
u.Path = "xrpc/com.atproto.sync.subscribeRepos"
|
||||
if cursor != nil {
|
||||
q := url.Values{}
|
||||
q.Set("cursor", strconv.FormatInt(*cursor, 10))
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
logger.Log.Info("subscribing to repo event stream", "upstream", relayHost)
|
||||
con, _, err := dialer.Dial(u.String(), http.Header{
|
||||
"User-Agent": []string{"Firehose-Consumer"},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("subscribing to firehose failed (dialing): %w", err)
|
||||
}
|
||||
|
||||
rsc := &events.RepoStreamCallbacks{
|
||||
RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error {
|
||||
return HandleRepoCommit(ctx, evt, callbackFunc)
|
||||
},
|
||||
}
|
||||
|
||||
var scheduler events.Scheduler
|
||||
parallelism := 8
|
||||
scheduler = parallel.NewScheduler(
|
||||
parallelism,
|
||||
100_000,
|
||||
relayHost,
|
||||
rsc.EventHandler,
|
||||
)
|
||||
logger.Log.Info("firehose scheduler configured", "workers", parallelism)
|
||||
|
||||
err = events.HandleRepoStream(ctx, con, scheduler, logger.Log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("repoStream error: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleRepoCommit(
|
||||
ctx context.Context,
|
||||
evt *comatproto.SyncSubscribeRepos_Commit,
|
||||
postCallback CallbackFunc,
|
||||
) error {
|
||||
localLogger := logger.Log.With("event", "commit", "did", evt.Repo, "rev", evt.Rev, "seq", evt.Seq)
|
||||
|
||||
if evt.TooBig {
|
||||
localLogger.Warn("skipping tooBig events for now")
|
||||
return nil
|
||||
}
|
||||
|
||||
did, err := syntax.ParseDID(evt.Repo)
|
||||
if err != nil {
|
||||
localLogger.Error("bad DID syntax in event", "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
rr, err := repo.ReadRepoFromCar(ctx, bytes.NewReader(evt.Blocks))
|
||||
if err != nil {
|
||||
localLogger.Error("failed to read repo from car", "err", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, op := range evt.Ops {
|
||||
localLogger = localLogger.With("eventKind", op.Action, "path", op.Path)
|
||||
collection, rkey, err := syntax.ParseRepoPath(op.Path)
|
||||
if err != nil {
|
||||
localLogger.Error("invalid path in repo op")
|
||||
return nil
|
||||
}
|
||||
|
||||
ek := repomgr.EventKind(op.Action)
|
||||
switch ek {
|
||||
case repomgr.EvtKindCreateRecord, repomgr.EvtKindUpdateRecord:
|
||||
// read the record bytes from blocks, and verify CID
|
||||
rc, recordCBOR, err := rr.GetRecordBytes(ctx, op.Path)
|
||||
if err != nil {
|
||||
localLogger.Error("reading record from event blocks (CAR)", "err", err)
|
||||
continue
|
||||
}
|
||||
if op.Cid == nil || lexutil.LexLink(rc) != *op.Cid {
|
||||
localLogger.Error("mismatch between commit op CID and record block", "recordCID", rc, "opCID", op.Cid)
|
||||
continue
|
||||
}
|
||||
|
||||
switch collection {
|
||||
case "app.bsky.feed.post":
|
||||
var post bsky.FeedPost
|
||||
if err := post.UnmarshalCBOR(bytes.NewReader(*recordCBOR)); err != nil {
|
||||
localLogger.Error("failed to parse app.bsky.feed.post record", "err", err)
|
||||
continue
|
||||
}
|
||||
postCallback(evt.Seq, did, rkey, post)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ConsumeAndSaveToMongoDB(
|
||||
ctx context.Context,
|
||||
postCollection *collections.PostCollection,
|
||||
relayHost string,
|
||||
cursorOption types.ConsumerCursor,
|
||||
oldestPostDuration time.Duration,
|
||||
batchFlushTime time.Duration,
|
||||
) error {
|
||||
firehoseDataChan := make(chan CallbackData, 500)
|
||||
localCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var sequenceCursor *int64
|
||||
switch cursorOption {
|
||||
case types.ConsumerCursorLastConsumed:
|
||||
var err error
|
||||
sequenceCursor, err = postCollection.GetMaxSequence(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case types.ConsumerCursorFirstStream:
|
||||
sequenceCursor = utils.ToPtr[int64](0)
|
||||
case types.ConsumerCursorCurrentStream:
|
||||
sequenceCursor = nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
for {
|
||||
err := RunFirehoseConsumer(
|
||||
ctx,
|
||||
relayHost,
|
||||
func(sequence int64, did syntax.DID, recordKey syntax.RecordKey, post bsky.FeedPost) {
|
||||
firehoseDataChan <- CallbackData{sequence, did, recordKey, post}
|
||||
},
|
||||
sequenceCursor,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
logger.Log.Error(err.Error())
|
||||
if !strings.HasPrefix(err.Error(), "repoStream error") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
sequenceCursor, err = postCollection.GetMaxSequence(ctx)
|
||||
if err != nil {
|
||||
logger.Log.Error(err.Error())
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
postBatch := []*collections.Post{}
|
||||
ticker := time.NewTicker(batchFlushTime)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Log.Info("consumer shutting down")
|
||||
return nil
|
||||
|
||||
case <-localCtx.Done():
|
||||
return nil
|
||||
|
||||
case data := <-firehoseDataChan:
|
||||
facets := &collections.Facets{}
|
||||
for _, facet := range data.Post.Facets {
|
||||
for _, feature := range facet.Features {
|
||||
if feature.RichtextFacet_Mention != nil {
|
||||
facets.Mentions = append(facets.Mentions, feature.RichtextFacet_Mention.Did)
|
||||
}
|
||||
if feature.RichtextFacet_Link != nil {
|
||||
facets.Links = append(facets.Links, feature.RichtextFacet_Link.Uri)
|
||||
}
|
||||
if feature.RichtextFacet_Tag != nil {
|
||||
facets.Tags = append(facets.Tags, feature.RichtextFacet_Tag.Tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reply := &collections.Reply{}
|
||||
if data.Post.Reply != nil {
|
||||
if data.Post.Reply.Root != nil {
|
||||
reply.RootURI = data.Post.Reply.Root.Uri
|
||||
}
|
||||
if data.Post.Reply.Parent != nil {
|
||||
reply.ParentURI = data.Post.Reply.Parent.Uri
|
||||
}
|
||||
}
|
||||
|
||||
createdAt, _ := time.Parse(time.RFC3339, data.Post.CreatedAt)
|
||||
if createdAt.After(time.Now().UTC().Add(-oldestPostDuration)) {
|
||||
postItem := &collections.Post{
|
||||
ID: fmt.Sprintf("%s/%s", data.DID, data.RecordKey),
|
||||
Sequence: data.Sequence,
|
||||
DID: data.DID.String(),
|
||||
RecordKey: data.RecordKey.String(),
|
||||
CreatedAt: createdAt,
|
||||
Langs: data.Post.Langs,
|
||||
Tags: data.Post.Tags,
|
||||
Text: data.Post.Text,
|
||||
Facets: facets,
|
||||
Reply: reply,
|
||||
}
|
||||
postBatch = append(postBatch, postItem)
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
if len(postBatch) > 0 {
|
||||
// logger.Log.Info("flushing post batch", "count", len(postBatch))
|
||||
err := postCollection.Insert(ctx, true, postBatch...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mongodb post insert error: %v", err)
|
||||
}
|
||||
postBatch = []*collections.Post{} // Clear batch after insert
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
80
pkg/feed/az.go
Normal file
80
pkg/feed/az.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package feed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/logger"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/storage/mongodb/collections"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
"github.com/bluesky-social/indigo/api/bsky"
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
type FeedAz struct {
|
||||
name string
|
||||
did *did.DID
|
||||
feedAzCollection *collections.FeedAzCollection
|
||||
}
|
||||
|
||||
func NewFeedAz(name string, publisherDID *did.DID, feedAzCollection *collections.FeedAzCollection) *FeedAz {
|
||||
return &FeedAz{
|
||||
name: name,
|
||||
did: publisherDID,
|
||||
feedAzCollection: feedAzCollection,
|
||||
}
|
||||
}
|
||||
|
||||
func (f FeedAz) GetName(_ context.Context) string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *FeedAz) Describe(_ context.Context) bsky.FeedDescribeFeedGenerator_Feed {
|
||||
return bsky.FeedDescribeFeedGenerator_Feed{
|
||||
Uri: "at://" + f.did.String() + "/app.bsky.feed.generator/" + f.name,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FeedAz) GetPage(
|
||||
ctx context.Context,
|
||||
_ string,
|
||||
limit int64,
|
||||
cursor string,
|
||||
) ([]*bsky.FeedDefs_SkeletonFeedPost, *string, error) {
|
||||
var cursorInt int64 = 0
|
||||
if cursor != "" {
|
||||
var err error
|
||||
cursorInt, err = strconv.ParseInt(cursor, 10, 64)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cursor is not an integer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
feedAzItems, err := f.feedAzCollection.GetByCreatedAt(ctx, cursorInt, limit+1)
|
||||
if err != nil {
|
||||
logger.Log.Error("failed to get feedAzCollection items", "error", err)
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
|
||||
var newCursor *string
|
||||
|
||||
if feedAzItemsLen := int64(len(feedAzItems)); limit >= feedAzItemsLen {
|
||||
posts := make([]*bsky.FeedDefs_SkeletonFeedPost, feedAzItemsLen)
|
||||
for i, feedItem := range feedAzItems {
|
||||
posts[i] = &bsky.FeedDefs_SkeletonFeedPost{
|
||||
Post: "at://" + feedItem.DID + "/app.bsky.feed.post/" + feedItem.RecordKey,
|
||||
}
|
||||
}
|
||||
return posts, newCursor, nil
|
||||
} else {
|
||||
posts := make([]*bsky.FeedDefs_SkeletonFeedPost, feedAzItemsLen-1)
|
||||
for i, feedItem := range feedAzItems[:feedAzItemsLen-1] {
|
||||
posts[i] = &bsky.FeedDefs_SkeletonFeedPost{
|
||||
Post: "at://" + feedItem.DID + "/app.bsky.feed.post/" + feedItem.RecordKey,
|
||||
}
|
||||
}
|
||||
return posts, utils.ToPtr(strconv.FormatInt(cursorInt+limit, 10)), nil
|
||||
}
|
||||
}
|
13
pkg/feed/base.go
Normal file
13
pkg/feed/base.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package feed
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/bluesky-social/indigo/api/bsky"
|
||||
)
|
||||
|
||||
type Feed interface {
|
||||
GetPage(ctx context.Context, userDID string, limit int64, cursor string) (feedPosts []*bsky.FeedDefs_SkeletonFeedPost, newCursor *string, err error)
|
||||
GetName(ctx context.Context) string
|
||||
Describe(ctx context.Context) bsky.FeedDescribeFeedGenerator_Feed
|
||||
}
|
130
pkg/generator/az.go
Normal file
130
pkg/generator/az.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/storage/mongodb/collections"
|
||||
"github.com/aykhans/bsky-feedgen/pkg/types"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
var azValidUsers []string = []string{
|
||||
"did:plc:jbt4qi6psd7rutwzedtecsq7",
|
||||
"did:plc:yzgdpxsklrmfgqmjghdvw3ti",
|
||||
"did:plc:cs2cbzojm6hmx5lfxiuft3mq",
|
||||
}
|
||||
|
||||
type FeedGeneratorAz struct {
|
||||
postCollection *collections.PostCollection
|
||||
feedAzCollection *collections.FeedAzCollection
|
||||
textRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
func NewFeedGeneratorAz(
|
||||
postCollection *collections.PostCollection,
|
||||
feedAzCollection *collections.FeedAzCollection,
|
||||
) *FeedGeneratorAz {
|
||||
return &FeedGeneratorAz{
|
||||
postCollection: postCollection,
|
||||
feedAzCollection: feedAzCollection,
|
||||
textRegex: regexp.MustCompile("(?i)(azerbaijan|azərbaycan|aзербайджан|azerbaycan)"),
|
||||
}
|
||||
}
|
||||
|
||||
func (generator *FeedGeneratorAz) Start(ctx context.Context, cursorOption types.GeneratorCursor, batchSize int) error {
|
||||
var mongoCursor *mongo.Cursor
|
||||
switch cursorOption {
|
||||
case types.GeneratorCursorLastGenerated:
|
||||
sequenceCursor, err := generator.feedAzCollection.GetMaxSequence(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sequenceCursor == nil {
|
||||
mongoCursor, err = generator.postCollection.Collection.Find(
|
||||
ctx,
|
||||
bson.D{},
|
||||
options.Find().SetSort(bson.D{{Key: "sequence", Value: 1}}),
|
||||
)
|
||||
} else {
|
||||
mongoCursor, err = generator.postCollection.Collection.Find(
|
||||
ctx,
|
||||
bson.M{"sequence": bson.M{"$gt": *sequenceCursor}},
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case types.GeneratorCursorFirstPost:
|
||||
var err error
|
||||
mongoCursor, err = generator.postCollection.Collection.Find(
|
||||
ctx,
|
||||
bson.D{},
|
||||
options.Find().SetSort(bson.D{{Key: "sequence", Value: 1}}),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
defer func() { _ = mongoCursor.Close(ctx) }()
|
||||
|
||||
feedAzBatch := []*collections.FeedAz{}
|
||||
for mongoCursor.Next(ctx) {
|
||||
var doc *collections.Post
|
||||
if err := mongoCursor.Decode(&doc); err != nil {
|
||||
return fmt.Errorf("mongodb cursor decode error: %v", err)
|
||||
}
|
||||
|
||||
if generator.IsValid(doc) == false {
|
||||
continue
|
||||
}
|
||||
|
||||
feedAzBatch = append(
|
||||
feedAzBatch,
|
||||
&collections.FeedAz{
|
||||
ID: doc.ID,
|
||||
Sequence: doc.Sequence,
|
||||
DID: doc.DID,
|
||||
RecordKey: doc.RecordKey,
|
||||
CreatedAt: doc.CreatedAt,
|
||||
},
|
||||
)
|
||||
|
||||
if len(feedAzBatch)%batchSize == 0 {
|
||||
err := generator.feedAzCollection.Insert(ctx, true, feedAzBatch...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert FeedAz error: %v", err)
|
||||
}
|
||||
feedAzBatch = []*collections.FeedAz{}
|
||||
}
|
||||
}
|
||||
|
||||
if len(feedAzBatch) > 0 {
|
||||
err := generator.feedAzCollection.Insert(ctx, true, feedAzBatch...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert FeedAz error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generator *FeedGeneratorAz) IsValid(post *collections.Post) bool {
|
||||
if post.Reply != nil && post.Reply.RootURI != post.Reply.ParentURI {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.Contains(azValidUsers, post.DID) || // Posts from always-valid users
|
||||
(slices.Contains(post.Langs, "az") && len(post.Langs) < 3) || // Posts in Azerbaijani language with fewer than 3 languages
|
||||
generator.textRegex.MatchString(post.Text) { // Posts containing Azerbaijan-related keywords
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
17
pkg/logger/base.go
Normal file
17
pkg/logger/base.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
var Log *slog.Logger
|
||||
|
||||
func init() {
|
||||
Log = slog.New(
|
||||
slog.NewTextHandler(
|
||||
os.Stdout,
|
||||
&slog.HandlerOptions{AddSource: true},
|
||||
),
|
||||
)
|
||||
}
|
250
pkg/manage/base.go
Normal file
250
pkg/manage/base.go
Normal file
@@ -0,0 +1,250 @@
|
||||
// This package was primarily developed using LLM models and should NOT be considered reliable.
|
||||
// The purpose of this package is to provide functionality for creating, updating, and deleting feed records on Bluesky, as no suitable tools were found for this purpose.
|
||||
// If a reliable tool becomes available that can perform these operations, this package will be deprecated and the discovered tool will be referenced in the project instead.
|
||||
|
||||
package manage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/utils"
|
||||
"github.com/bluesky-social/indigo/api/atproto"
|
||||
"github.com/bluesky-social/indigo/api/bsky"
|
||||
lexutil "github.com/bluesky-social/indigo/lex/util"
|
||||
"github.com/bluesky-social/indigo/xrpc"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPDSHost = "https://bsky.social"
|
||||
)
|
||||
|
||||
func NewClient(pdsHost *string) *xrpc.Client {
|
||||
if pdsHost == nil {
|
||||
pdsHost = utils.ToPtr(DefaultPDSHost)
|
||||
}
|
||||
|
||||
return &xrpc.Client{
|
||||
Host: *pdsHost,
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientWithAuth(ctx context.Context, client *xrpc.Client, identifier, password string) (*xrpc.Client, error) {
|
||||
if client == nil {
|
||||
client = NewClient(nil)
|
||||
}
|
||||
|
||||
auth, err := atproto.ServerCreateSession(ctx, client, &atproto.ServerCreateSession_Input{
|
||||
Identifier: identifier,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth session: %v", err)
|
||||
}
|
||||
|
||||
client.Auth = &xrpc.AuthInfo{
|
||||
AccessJwt: auth.AccessJwt,
|
||||
RefreshJwt: auth.RefreshJwt,
|
||||
Did: auth.Did,
|
||||
Handle: auth.Handle,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, clientWithAuth *xrpc.Client, avatarPath string) (*atproto.RepoUploadBlob_Output, error) {
|
||||
if clientWithAuth == nil {
|
||||
return nil, fmt.Errorf("client can't be nil")
|
||||
}
|
||||
if clientWithAuth.Auth == nil {
|
||||
return nil, fmt.Errorf("client auth can't be nil")
|
||||
}
|
||||
|
||||
avatarFile, err := os.Open(avatarPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open avatar file: %v", err)
|
||||
}
|
||||
defer func() { _ = avatarFile.Close() }()
|
||||
|
||||
uploadResp, err := atproto.RepoUploadBlob(ctx, clientWithAuth, avatarFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upload avatar: %v", err)
|
||||
}
|
||||
|
||||
return uploadResp, nil
|
||||
}
|
||||
|
||||
func GetFeedGenerator(ctx context.Context, clientWithAuth *xrpc.Client, keyName string) (*atproto.RepoGetRecord_Output, error) {
|
||||
if clientWithAuth == nil {
|
||||
return nil, fmt.Errorf("client can't be nil")
|
||||
}
|
||||
if clientWithAuth.Auth == nil {
|
||||
return nil, fmt.Errorf("client auth can't be nil")
|
||||
}
|
||||
|
||||
record, err := atproto.RepoGetRecord(
|
||||
ctx,
|
||||
clientWithAuth,
|
||||
"",
|
||||
"app.bsky.feed.generator",
|
||||
clientWithAuth.Auth.Did,
|
||||
keyName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get axisting feed generator: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func CreateFeedGenerator(
|
||||
ctx context.Context,
|
||||
clientWithAuth *xrpc.Client,
|
||||
displayName string,
|
||||
description *string,
|
||||
avatarPath *string,
|
||||
did string,
|
||||
keyName string,
|
||||
) error {
|
||||
if clientWithAuth == nil {
|
||||
return fmt.Errorf("client can't be nil")
|
||||
}
|
||||
if clientWithAuth.Auth == nil {
|
||||
return fmt.Errorf("client auth can't be nil")
|
||||
}
|
||||
|
||||
var avatarBlob *lexutil.LexBlob
|
||||
if avatarPath != nil {
|
||||
uploadResp, err := uploadBlob(ctx, clientWithAuth, *avatarPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
avatarBlob = uploadResp.Blob
|
||||
}
|
||||
|
||||
record := bsky.FeedGenerator{
|
||||
DisplayName: displayName,
|
||||
Description: description,
|
||||
Avatar: avatarBlob,
|
||||
Did: did,
|
||||
CreatedAt: time.Now().Format(time.RFC3339Nano),
|
||||
}
|
||||
|
||||
wrappedRecord := &lexutil.LexiconTypeDecoder{
|
||||
Val: &record,
|
||||
}
|
||||
|
||||
_, err := atproto.RepoCreateRecord(ctx, clientWithAuth, &atproto.RepoCreateRecord_Input{
|
||||
Collection: "app.bsky.feed.generator",
|
||||
Repo: clientWithAuth.Auth.Did, // Your DID (the one creating the record)
|
||||
Record: wrappedRecord,
|
||||
Rkey: &keyName,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create feed generator: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateFeedGenerator(
|
||||
ctx context.Context,
|
||||
clientWithAuth *xrpc.Client,
|
||||
displayName *string,
|
||||
description *string,
|
||||
avatarPath *string,
|
||||
did *string,
|
||||
keyName string,
|
||||
) error {
|
||||
if clientWithAuth == nil {
|
||||
return fmt.Errorf("client can't be nil")
|
||||
}
|
||||
if clientWithAuth.Auth == nil {
|
||||
return fmt.Errorf("client auth can't be nil")
|
||||
}
|
||||
|
||||
existingRecord, err := GetFeedGenerator(ctx, clientWithAuth, keyName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get axisting feed generator: %v", err)
|
||||
}
|
||||
|
||||
if existingRecord != nil && existingRecord.Value != nil {
|
||||
if existingFeedgen, ok := existingRecord.Value.Val.(*bsky.FeedGenerator); ok {
|
||||
if avatarPath != nil {
|
||||
uploadResp, err := uploadBlob(ctx, clientWithAuth, *avatarPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existingFeedgen.Avatar = uploadResp.Blob
|
||||
}
|
||||
|
||||
if displayName != nil {
|
||||
existingFeedgen.DisplayName = *displayName
|
||||
}
|
||||
|
||||
if description != nil {
|
||||
existingFeedgen.Description = description
|
||||
}
|
||||
|
||||
if did != nil {
|
||||
existingFeedgen.Did = *did
|
||||
}
|
||||
|
||||
wrappedExistingFeedgen := &lexutil.LexiconTypeDecoder{
|
||||
Val: &bsky.FeedGenerator{
|
||||
DisplayName: existingFeedgen.DisplayName,
|
||||
Description: existingFeedgen.Description,
|
||||
Did: existingFeedgen.Did,
|
||||
Avatar: existingFeedgen.Avatar,
|
||||
CreatedAt: existingFeedgen.CreatedAt,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := atproto.RepoPutRecord(ctx, clientWithAuth, &atproto.RepoPutRecord_Input{
|
||||
Collection: "app.bsky.feed.generator",
|
||||
Repo: clientWithAuth.Auth.Did, // Your DID
|
||||
Rkey: keyName, // The Rkey of the record to update
|
||||
Record: wrappedExistingFeedgen,
|
||||
SwapRecord: existingRecord.Cid,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update feed generator: %v", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("feed generator not found")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteFeedGenerator(
|
||||
ctx context.Context,
|
||||
clientWithAuth *xrpc.Client,
|
||||
keyName string,
|
||||
) error {
|
||||
if clientWithAuth == nil {
|
||||
return fmt.Errorf("client can't be nil")
|
||||
}
|
||||
if clientWithAuth.Auth == nil {
|
||||
return fmt.Errorf("client auth can't be nil")
|
||||
}
|
||||
|
||||
f, err := atproto.RepoDeleteRecord(ctx, clientWithAuth, &atproto.RepoDeleteRecord_Input{
|
||||
Collection: "app.bsky.feed.generator",
|
||||
Repo: clientWithAuth.Auth.Did,
|
||||
Rkey: keyName,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete feed generator: %v", err)
|
||||
}
|
||||
if f.Commit == nil {
|
||||
return fmt.Errorf("feed generator not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
194
pkg/storage/mongodb/collections/feed_az.go
Normal file
194
pkg/storage/mongodb/collections/feed_az.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/config"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type FeedAzCollection struct {
|
||||
Collection *mongo.Collection
|
||||
}
|
||||
|
||||
func NewFeedAzCollection(client *mongo.Client) (*FeedAzCollection, error) {
|
||||
client.Database(config.MongoDBBaseDB).Collection("")
|
||||
coll := client.Database(config.MongoDBBaseDB).Collection("feed_az")
|
||||
|
||||
_, err := coll.Indexes().CreateMany(
|
||||
context.Background(),
|
||||
[]mongo.IndexModel{
|
||||
{
|
||||
Keys: bson.D{{Key: "sequence", Value: -1}},
|
||||
},
|
||||
{
|
||||
Keys: bson.D{{Key: "created_at", Value: -1}},
|
||||
},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &FeedAzCollection{Collection: coll}, nil
|
||||
}
|
||||
|
||||
type FeedAz struct {
|
||||
ID string `bson:"_id"`
|
||||
Sequence int64 `bson:"sequence"`
|
||||
DID string `bson:"did"`
|
||||
RecordKey string `bson:"record_key"`
|
||||
CreatedAt time.Time `bson:"created_at"`
|
||||
}
|
||||
|
||||
func (f FeedAzCollection) GetByCreatedAt(ctx context.Context, skip int64, limit int64) ([]*FeedAz, error) {
|
||||
cursor, err := f.Collection.Find(
|
||||
ctx, bson.D{},
|
||||
options.Find().
|
||||
SetSort(bson.D{{Key: "created_at", Value: -1}}).
|
||||
SetSkip(skip).
|
||||
SetLimit(limit),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = cursor.Close(ctx) }()
|
||||
|
||||
var feedAzItems []*FeedAz
|
||||
if err = cursor.All(ctx, &feedAzItems); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return feedAzItems, nil
|
||||
}
|
||||
|
||||
func (f FeedAzCollection) GetMaxSequence(ctx context.Context) (*int64, error) {
|
||||
pipeline := mongo.Pipeline{
|
||||
{
|
||||
{Key: "$group", Value: bson.D{
|
||||
{Key: "_id", Value: nil},
|
||||
{Key: "maxSequence", Value: bson.D{
|
||||
{Key: "$max", Value: "$sequence"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cursor, err := f.Collection.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = cursor.Close(ctx) }()
|
||||
|
||||
var result struct {
|
||||
MaxSequence int64 `bson:"maxSequence"`
|
||||
}
|
||||
|
||||
if cursor.Next(ctx) {
|
||||
if err := cursor.Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result.MaxSequence, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f FeedAzCollection) Insert(ctx context.Context, overwrite bool, feedAz ...*FeedAz) error {
|
||||
switch len(feedAz) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
if overwrite == false {
|
||||
_, err := f.Collection.InsertOne(ctx, feedAz[0])
|
||||
return err
|
||||
}
|
||||
_, err := f.Collection.ReplaceOne(
|
||||
ctx,
|
||||
bson.M{"_id": feedAz[0].ID},
|
||||
feedAz[0],
|
||||
options.Replace().SetUpsert(true),
|
||||
)
|
||||
return err
|
||||
default:
|
||||
if overwrite == false {
|
||||
documents := make([]any, len(feedAz))
|
||||
for i, feed := range feedAz {
|
||||
documents[i] = feed
|
||||
}
|
||||
|
||||
_, err := f.Collection.InsertMany(ctx, documents)
|
||||
return err
|
||||
}
|
||||
var models []mongo.WriteModel
|
||||
|
||||
for _, feed := range feedAz {
|
||||
filter := bson.M{"_id": feed.ID}
|
||||
model := mongo.NewReplaceOneModel().
|
||||
SetFilter(filter).
|
||||
SetReplacement(feed).
|
||||
SetUpsert(true)
|
||||
models = append(models, model)
|
||||
}
|
||||
|
||||
opts := options.BulkWrite().SetOrdered(false)
|
||||
_, err := f.Collection.BulkWrite(ctx, models, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f FeedAzCollection) CutoffByCount(
|
||||
ctx context.Context,
|
||||
maxDocumentCount int64,
|
||||
) (int64, error) {
|
||||
count, err := f.Collection.CountDocuments(ctx, bson.M{})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if count <= maxDocumentCount {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
deleteCount := count - maxDocumentCount
|
||||
|
||||
findOpts := options.Find().
|
||||
SetSort(bson.D{{Key: "created_at", Value: 1}}).
|
||||
SetLimit(deleteCount)
|
||||
|
||||
cursor, err := f.Collection.Find(ctx, bson.M{}, findOpts)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = cursor.Close(ctx) }()
|
||||
|
||||
var docsToDelete []bson.M
|
||||
if err = cursor.All(ctx, &docsToDelete); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(docsToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ids := make([]any, len(docsToDelete))
|
||||
for i := range docsToDelete {
|
||||
ids[i] = docsToDelete[i]["_id"]
|
||||
}
|
||||
|
||||
result, err := f.Collection.DeleteMany(ctx, bson.M{"_id": bson.M{"$in": ids}})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.DeletedCount, nil
|
||||
}
|
183
pkg/storage/mongodb/collections/post.go
Normal file
183
pkg/storage/mongodb/collections/post.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package collections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/config"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type PostCollection struct {
|
||||
Collection *mongo.Collection
|
||||
}
|
||||
|
||||
func NewPostCollection(client *mongo.Client) (*PostCollection, error) {
|
||||
client.Database(config.MongoDBBaseDB).Collection("")
|
||||
coll := client.Database(config.MongoDBBaseDB).Collection("post")
|
||||
_, err := coll.Indexes().CreateOne(
|
||||
context.Background(),
|
||||
mongo.IndexModel{
|
||||
Keys: bson.D{{Key: "sequence", Value: -1}},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PostCollection{Collection: coll}, nil
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID string `bson:"_id"`
|
||||
Sequence int64 `bson:"sequence"`
|
||||
DID string `bson:"did"`
|
||||
RecordKey string `bson:"record_key"`
|
||||
CreatedAt time.Time `bson:"created_at"`
|
||||
Langs []string `bson:"langs"`
|
||||
Tags []string `bson:"tags"`
|
||||
Text string `bson:"text"`
|
||||
Facets *Facets `bson:"facets"`
|
||||
Reply *Reply `bson:"reply"`
|
||||
}
|
||||
|
||||
type Facets struct {
|
||||
Tags []string `bson:"tags"`
|
||||
Links []string `bson:"links"`
|
||||
Mentions []string `bson:"mentions"`
|
||||
}
|
||||
|
||||
type Reply struct {
|
||||
RootURI string `bson:"root_uri"`
|
||||
ParentURI string `bson:"parent_uri"`
|
||||
}
|
||||
|
||||
func (p PostCollection) CutoffByCount(
|
||||
ctx context.Context,
|
||||
maxDocumentCount int64,
|
||||
) (int64, error) {
|
||||
count, err := p.Collection.CountDocuments(ctx, bson.M{})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if count <= maxDocumentCount {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
deleteCount := count - maxDocumentCount
|
||||
|
||||
findOpts := options.Find().
|
||||
SetSort(bson.D{{Key: "created_at", Value: 1}}).
|
||||
SetLimit(deleteCount)
|
||||
|
||||
cursor, err := p.Collection.Find(ctx, bson.M{}, findOpts)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = cursor.Close(ctx) }()
|
||||
|
||||
var docsToDelete []bson.M
|
||||
if err = cursor.All(ctx, &docsToDelete); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(docsToDelete) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ids := make([]any, len(docsToDelete))
|
||||
for i := range docsToDelete {
|
||||
ids[i] = docsToDelete[i]["_id"]
|
||||
}
|
||||
|
||||
result, err := p.Collection.DeleteMany(ctx, bson.M{"_id": bson.M{"$in": ids}})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result.DeletedCount, nil
|
||||
}
|
||||
|
||||
func (p PostCollection) GetMaxSequence(ctx context.Context) (*int64, error) {
|
||||
pipeline := mongo.Pipeline{
|
||||
{
|
||||
{Key: "$group", Value: bson.D{
|
||||
{Key: "_id", Value: nil},
|
||||
{Key: "maxSequence", Value: bson.D{
|
||||
{Key: "$max", Value: "$sequence"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cursor, err := p.Collection.Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = cursor.Close(ctx) }()
|
||||
|
||||
var result struct {
|
||||
MaxSequence int64 `bson:"maxSequence"`
|
||||
}
|
||||
|
||||
if cursor.Next(ctx) {
|
||||
if err := cursor.Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result.MaxSequence, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p PostCollection) Insert(ctx context.Context, overwrite bool, posts ...*Post) error {
|
||||
switch len(posts) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
if overwrite == false {
|
||||
_, err := p.Collection.InsertOne(ctx, posts[0])
|
||||
return err
|
||||
}
|
||||
_, err := p.Collection.ReplaceOne(
|
||||
ctx,
|
||||
bson.M{"_id": posts[0].ID},
|
||||
posts[0],
|
||||
options.Replace().SetUpsert(true),
|
||||
)
|
||||
return err
|
||||
default:
|
||||
if overwrite == false {
|
||||
documents := make([]any, len(posts))
|
||||
for i, post := range posts {
|
||||
documents[i] = post
|
||||
}
|
||||
|
||||
_, err := p.Collection.InsertMany(ctx, documents)
|
||||
return err
|
||||
}
|
||||
var models []mongo.WriteModel
|
||||
|
||||
for _, post := range posts {
|
||||
filter := bson.M{"_id": post.ID}
|
||||
model := mongo.NewReplaceOneModel().
|
||||
SetFilter(filter).
|
||||
SetReplacement(post).
|
||||
SetUpsert(true)
|
||||
models = append(models, model)
|
||||
}
|
||||
|
||||
opts := options.BulkWrite().SetOrdered(false)
|
||||
_, err := p.Collection.BulkWrite(ctx, models, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
39
pkg/storage/mongodb/db.go
Normal file
39
pkg/storage/mongodb/db.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/bsky-feedgen/pkg/config"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func NewDB(ctx context.Context, dbConfig *config.MongoDBConfig) (*mongo.Client, error) {
|
||||
clientOptions := options.Client().ApplyURI(
|
||||
fmt.Sprintf("mongodb://%s:%v/", dbConfig.Host, dbConfig.Port),
|
||||
)
|
||||
|
||||
if dbConfig.Username != "" {
|
||||
clientOptions.SetAuth(options.Credential{
|
||||
Username: dbConfig.Username,
|
||||
Password: dbConfig.Password,
|
||||
})
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check the connection
|
||||
err = client.Ping(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
50
pkg/types/consumer_cursor.go
Normal file
50
pkg/types/consumer_cursor.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package types
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ConsumerCursor string
|
||||
|
||||
var (
|
||||
ConsumerCursorLastConsumed ConsumerCursor = "last-consumed"
|
||||
ConsumerCursorFirstStream ConsumerCursor = "first-stream"
|
||||
ConsumerCursorCurrentStream ConsumerCursor = "current-stream"
|
||||
)
|
||||
|
||||
func (c ConsumerCursor) String() string {
|
||||
return string(c)
|
||||
}
|
||||
|
||||
func (c ConsumerCursor) IsValid() bool {
|
||||
return c == ConsumerCursorLastConsumed || c == ConsumerCursorFirstStream || c == ConsumerCursorCurrentStream
|
||||
}
|
||||
|
||||
func (c ConsumerCursor) Equal(other ConsumerCursor) bool {
|
||||
return c == other
|
||||
}
|
||||
|
||||
func (c ConsumerCursor) IsLastConsumed() bool {
|
||||
return c == ConsumerCursorLastConsumed
|
||||
}
|
||||
|
||||
func (c ConsumerCursor) IsFirstStream() bool {
|
||||
return c == ConsumerCursorFirstStream
|
||||
}
|
||||
|
||||
func (c ConsumerCursor) IsCurrentStream() bool {
|
||||
return c == ConsumerCursorCurrentStream
|
||||
}
|
||||
|
||||
func (c *ConsumerCursor) Set(value string) error {
|
||||
switch value {
|
||||
case ConsumerCursorLastConsumed.String(), "":
|
||||
*c = ConsumerCursorLastConsumed
|
||||
case ConsumerCursorFirstStream.String():
|
||||
*c = ConsumerCursorFirstStream
|
||||
case ConsumerCursorCurrentStream.String():
|
||||
*c = ConsumerCursorCurrentStream
|
||||
default:
|
||||
return fmt.Errorf("invalid cursor value: %s", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
18
pkg/types/err_map.go
Normal file
18
pkg/types/err_map.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package types
|
||||
|
||||
type ErrMap map[string]error
|
||||
|
||||
func (ErrMap ErrMap) ToStringMap() map[string]string {
|
||||
if len(ErrMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stringMap := make(map[string]string)
|
||||
for key, err := range ErrMap {
|
||||
if err != nil {
|
||||
stringMap[key] = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return stringMap
|
||||
}
|
7
pkg/types/errors.go
Normal file
7
pkg/types/errors.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package types
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInternal = errors.New("internal error")
|
||||
)
|
43
pkg/types/generator_cursor.go
Normal file
43
pkg/types/generator_cursor.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package types
|
||||
|
||||
import "fmt"
|
||||
|
||||
type GeneratorCursor string
|
||||
|
||||
var (
|
||||
GeneratorCursorLastGenerated GeneratorCursor = "last-generated"
|
||||
GeneratorCursorFirstPost GeneratorCursor = "first-post"
|
||||
)
|
||||
|
||||
func (c GeneratorCursor) String() string {
|
||||
return string(c)
|
||||
}
|
||||
|
||||
func (c GeneratorCursor) IsValid() bool {
|
||||
return c == GeneratorCursorLastGenerated || c == GeneratorCursorFirstPost
|
||||
}
|
||||
|
||||
func (c GeneratorCursor) Equal(other GeneratorCursor) bool {
|
||||
return c == other
|
||||
}
|
||||
|
||||
func (c GeneratorCursor) IsLastGenerated() bool {
|
||||
return c == GeneratorCursorLastGenerated
|
||||
}
|
||||
|
||||
func (c GeneratorCursor) IsFirstPost() bool {
|
||||
return c == GeneratorCursorFirstPost
|
||||
}
|
||||
|
||||
func (c *GeneratorCursor) Set(value string) error {
|
||||
switch value {
|
||||
case GeneratorCursorLastGenerated.String(), "":
|
||||
*c = GeneratorCursorLastGenerated
|
||||
case GeneratorCursorFirstPost.String():
|
||||
*c = GeneratorCursorFirstPost
|
||||
default:
|
||||
return fmt.Errorf("invalid cursor value: %s", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
151
pkg/utils/base.go
Normal file
151
pkg/utils/base.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/whyrusleeping/go-did"
|
||||
)
|
||||
|
||||
// ParseString attempts to parse the input string `s` into a value of the specified type T.
|
||||
// It supports parsing into the following types:
|
||||
// - int, int8, int16, int32, int64
|
||||
// - uint, uint8, uint16, uint32, uint64
|
||||
// - float64
|
||||
// - bool
|
||||
// - string
|
||||
// - time.Duration
|
||||
// - url.URL / *url.URL
|
||||
// - did.DID / *did.DID
|
||||
//
|
||||
// If T is not one of these supported types, it returns an error.
|
||||
// If parsing the string `s` fails for a supported type, it returns the zero value of T
|
||||
// and the parsing error.
|
||||
func ParseString[T any](s string) (T, error) {
|
||||
var value T
|
||||
|
||||
switch any(value).(type) {
|
||||
case int:
|
||||
i, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(i).(T)
|
||||
case int8:
|
||||
i, err := strconv.ParseInt(s, 10, 8)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(int8(i)).(T)
|
||||
case int16:
|
||||
i, err := strconv.ParseInt(s, 10, 16)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(int16(i)).(T)
|
||||
case int32:
|
||||
i, err := strconv.ParseInt(s, 10, 32)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(int32(i)).(T)
|
||||
case int64:
|
||||
i, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(i).(T)
|
||||
case uint:
|
||||
u, err := strconv.ParseUint(s, 10, 0)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(uint(u)).(T)
|
||||
case uint8:
|
||||
u, err := strconv.ParseUint(s, 10, 8)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(uint8(u)).(T)
|
||||
case uint16:
|
||||
u, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(uint16(u)).(T)
|
||||
case uint32:
|
||||
u, err := strconv.ParseUint(s, 10, 32)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(uint32(u)).(T)
|
||||
case uint64:
|
||||
u, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(u).(T)
|
||||
case float64:
|
||||
f, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(f).(T)
|
||||
case bool:
|
||||
b, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(b).(T)
|
||||
case string:
|
||||
value = any(s).(T)
|
||||
case []string:
|
||||
var items []string
|
||||
err := json.Unmarshal([]byte(s), &items)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(items).(T)
|
||||
case time.Duration:
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(d).(T)
|
||||
case url.URL:
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(*u).(T)
|
||||
case *url.URL:
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(u).(T)
|
||||
case did.DID:
|
||||
d, err := did.ParseDID(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(d).(T)
|
||||
case *did.DID:
|
||||
d, err := did.ParseDID(s)
|
||||
if err != nil {
|
||||
return value, err
|
||||
}
|
||||
value = any(&d).(T)
|
||||
default:
|
||||
return value, fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func ToPtr[T any](value T) *T {
|
||||
return &value
|
||||
}
|
50
pkg/utils/env.go
Normal file
50
pkg/utils/env.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// GetEnv retrieves the environment variable named by the key envName,
|
||||
// and attempts to parse its value into the specified type T.
|
||||
//
|
||||
// It returns the parsed value of type T and a nil error on success.
|
||||
// It returns a zero value of type T and an error if the environment
|
||||
// variable is not set, or if parsing fails using the ParseValue function.
|
||||
func GetEnv[T any](envName string) (T, error) {
|
||||
var zero T
|
||||
|
||||
envStr := os.Getenv(envName)
|
||||
if envStr == "" {
|
||||
return zero, fmt.Errorf("environment variable %s is not set", envName)
|
||||
}
|
||||
|
||||
parsedEnv, err := ParseString[T](envStr)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("failed to parse environment variable %s: %s", envName, err)
|
||||
}
|
||||
|
||||
return parsedEnv, nil
|
||||
}
|
||||
|
||||
// GetEnvOr retrieves the environment variable named by the key envName,
|
||||
// and attempts to parse its value into the specified type T.
|
||||
//
|
||||
// It returns the parsed value of type T and a nil error on success.
|
||||
// If the environment variable is not set, it returns the provided default value
|
||||
// and a nil error. If parsing fails, it returns a zero value of type T and an error.
|
||||
func GetEnvOr[T any](envName string, defaultValue T) (T, error) {
|
||||
var zero T
|
||||
|
||||
envStr := os.Getenv(envName)
|
||||
if envStr == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
parsedEnv, err := ParseString[T](envStr)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("failed to parse environment variable %s: %s", envName, err)
|
||||
}
|
||||
|
||||
return parsedEnv, nil
|
||||
}
|
Reference in New Issue
Block a user