diff --git a/go.mod b/go.mod index 8b6979f..924aa15 100644 --- a/go.mod +++ b/go.mod @@ -65,14 +65,13 @@ require ( ) require ( - github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 github.com/h2non/filetype v1.1.3 github.com/improbable-eng/grpc-web v0.15.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/mssola/useragent v1.0.0 - github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/posthog/posthog-go v0.0.0-20240327112532-87b23fe11103 golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 diff --git a/go.sum b/go.sum index e9e7107..7c60663 100644 --- a/go.sum +++ b/go.sum @@ -96,8 +96,8 @@ github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -278,8 +278,6 @@ github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnh github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg= github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= diff --git a/server/route/api/v1/acl.go b/server/route/api/v1/acl.go index 6b0b508..bd5fddc 100644 --- a/server/route/api/v1/acl.go +++ b/server/route/api/v1/acl.go @@ -5,7 +5,7 @@ import ( "net/http" "strings" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/server/route/api/v1/user_service.go b/server/route/api/v1/user_service.go index 943eb7c..f9970c3 100644 --- a/server/route/api/v1/user_service.go +++ b/server/route/api/v1/user_service.go @@ -4,7 +4,7 @@ import ( "context" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slices" diff --git a/server/route/auth/auth.go b/server/route/auth/auth.go index 2ce7d3b..458fa55 100644 --- a/server/route/auth/auth.go +++ b/server/route/auth/auth.go @@ -4,7 +4,7 @@ import ( "fmt" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) const ( diff --git a/server/service/license/cache.go b/server/service/license/cache.go deleted file mode 100644 index 8591713..0000000 --- a/server/service/license/cache.go +++ /dev/null @@ -1,24 +0,0 @@ -package license - -import ( - "fmt" - "time" - - "github.com/patrickmn/go-cache" -) - -var ( - licenseCache = cache.New(24*time.Hour, 24*time.Hour) -) - -func SetLicenseCache(licenseKey, instanceName string, license LicenseKey) { - licenseCache.Set(fmt.Sprintf("%s-%s", licenseKey, instanceName), license, 24*time.Hour) -} - -func GetLicenseCache(licenseKey, instanceName string) *LicenseKey { - cache, ok := licenseCache.Get(fmt.Sprintf("%s-%s", licenseKey, instanceName)) - if !ok { - return nil - } - return cache.(*LicenseKey) -} diff --git a/server/service/license/lemonsqueezy/lemonsqueezy.go b/server/service/license/lemonsqueezy/lemonsqueezy.go new file mode 100644 index 0000000..10ba8f2 --- /dev/null +++ b/server/service/license/lemonsqueezy/lemonsqueezy.go @@ -0,0 +1 @@ +package lemonsqueezy diff --git a/server/service/license/requests.go b/server/service/license/lemonsqueezy/requests.go similarity index 95% rename from server/service/license/requests.go rename to server/service/license/lemonsqueezy/requests.go index 4d60e29..93ac5ff 100644 --- a/server/service/license/requests.go +++ b/server/service/license/lemonsqueezy/requests.go @@ -1,4 +1,4 @@ -package license +package lemonsqueezy import ( "bytes" @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/pkg/errors" ) @@ -57,7 +56,7 @@ type ActiveLicenseKeyResponse struct { Meta *LicenseKeyMeta `json:"meta"` } -func validateLicenseKey(licenseKey string, instanceName string) (*ValidateLicenseKeyResponse, error) { +func ValidateLicenseKey(licenseKey string, instanceName string) (*ValidateLicenseKeyResponse, error) { data := map[string]string{"license_key": licenseKey} if instanceName != "" { data["instance_name"] = instanceName @@ -98,11 +97,10 @@ func validateLicenseKey(licenseKey string, instanceName string) (*ValidateLicens return nil, errors.New("invalid store or product id") } } - licenseCache.Set("key", "value", 24*time.Hour) return &response, nil } -func activeLicenseKey(licenseKey string, instanceName string) (*ActiveLicenseKeyResponse, error) { +func ActiveLicenseKey(licenseKey string, instanceName string) (*ActiveLicenseKeyResponse, error) { data := map[string]string{"license_key": licenseKey, "instance_name": instanceName} payload, err := json.Marshal(data) if err != nil { diff --git a/server/service/license/requests_test.go b/server/service/license/lemonsqueezy/requests_test.go similarity index 88% rename from server/service/license/requests_test.go rename to server/service/license/lemonsqueezy/requests_test.go index 87fb42f..44092e7 100644 --- a/server/service/license/requests_test.go +++ b/server/service/license/lemonsqueezy/requests_test.go @@ -1,4 +1,4 @@ -package license +package lemonsqueezy import ( "testing" @@ -27,7 +27,7 @@ func TestValidateLicenseKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := validateLicenseKey(tt.key, "test-instance") + response, err := ValidateLicenseKey(tt.key, "test-instance") if tt.err != nil { require.EqualError(t, err, tt.err.Error()) return @@ -58,7 +58,7 @@ func TestActiveLicenseKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - response, err := activeLicenseKey(tt.key, "test-instance") + response, err := ActiveLicenseKey(tt.key, "test-instance") require.NoError(t, err) require.Equal(t, tt.expected, response.Activated) }) diff --git a/server/service/license/license.go b/server/service/license/license.go index 701b8c3..3430abd 100644 --- a/server/service/license/license.go +++ b/server/service/license/license.go @@ -2,17 +2,23 @@ package license import ( "context" + _ "embed" "time" + "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" "google.golang.org/protobuf/types/known/timestamppb" apiv1pb "github.com/yourselfhosted/slash/proto/gen/api/v1" storepb "github.com/yourselfhosted/slash/proto/gen/store" "github.com/yourselfhosted/slash/server/profile" + "github.com/yourselfhosted/slash/server/service/license/lemonsqueezy" "github.com/yourselfhosted/slash/store" ) +//go:embed slash.public.pem +var slashPublicRSAKey string + type LicenseService struct { Profile *profile.Profile Store *store.Store @@ -49,25 +55,15 @@ func (s *LicenseService) LoadSubscription(ctx context.Context) (*apiv1pb.Subscri return subscription, nil } - validateResponse, err := validateLicenseKey(licenseKey, "") + result, err := validateLicenseKey(licenseKey) if err != nil { return nil, errors.Wrap(err, "failed to validate license key") } - if validateResponse.Valid { - subscription.Plan = apiv1pb.PlanType_PRO - if validateResponse.LicenseKey.ExpiresAt != nil && *validateResponse.LicenseKey.ExpiresAt != "" { - expiresTime, err := time.Parse(time.RFC3339Nano, *validateResponse.LicenseKey.ExpiresAt) - if err != nil { - return nil, errors.Wrap(err, "failed to parse license key expires time") - } - subscription.ExpiresTime = timestamppb.New(expiresTime) - } - startedTime, err := time.Parse(time.RFC3339Nano, validateResponse.LicenseKey.CreatedAt) - if err != nil { - return nil, errors.Wrap(err, "failed to parse license key created time") - } - subscription.StartedTime = timestamppb.New(startedTime) + if result == nil { + return subscription, nil } + subscription.Plan = result.Plan + subscription.ExpiresTime = timestamppb.New(result.ExpiresTime) s.cachedSubscription = subscription return subscription, nil } @@ -76,11 +72,11 @@ func (s *LicenseService) UpdateSubscription(ctx context.Context, licenseKey stri if licenseKey == "" { return nil, errors.New("license key is required") } - validateResponse, err := validateLicenseKey(licenseKey, "") + result, err := validateLicenseKey(licenseKey) if err != nil { return nil, errors.Wrap(err, "failed to validate license key") } - if !validateResponse.Valid { + if result == nil { return nil, errors.New("invalid license key") } _, err = s.Store.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{ @@ -96,7 +92,14 @@ func (s *LicenseService) UpdateSubscription(ctx context.Context, licenseKey stri } func (s *LicenseService) GetSubscription(ctx context.Context) (*apiv1pb.Subscription, error) { - return s.LoadSubscription(ctx) + subscription, err := s.LoadSubscription(ctx) + if err != nil || subscription.Plan == apiv1pb.PlanType_PLAN_TYPE_UNSPECIFIED { + // nolint + return &apiv1pb.Subscription{ + Plan: apiv1pb.PlanType_FREE, + }, nil + } + return subscription, nil } func (s *LicenseService) IsFeatureEnabled(feature FeatureType) bool { @@ -106,3 +109,80 @@ func (s *LicenseService) IsFeatureEnabled(feature FeatureType) bool { } return matrix[s.cachedSubscription.Plan-1] } + +type ValidateResult struct { + Plan apiv1pb.PlanType + ExpiresTime time.Time +} + +type Claims struct { + jwt.RegisteredClaims + + Owner string `json:"owner"` + Plan string `json:"plan"` + Trial bool `json:"trial"` +} + +func validateLicenseKey(licenseKey string) (*ValidateResult, error) { + // Try to parse the license key as a JWT token. + claims, _ := parseLicenseKey(licenseKey) + if claims != nil { + plan := apiv1pb.PlanType(apiv1pb.PlanType_value[claims.Plan]) + if plan == apiv1pb.PlanType_PLAN_TYPE_UNSPECIFIED { + return nil, errors.New("invalid plan") + } + return &ValidateResult{ + Plan: apiv1pb.PlanType(apiv1pb.PlanType_value[claims.Plan]), + ExpiresTime: claims.ExpiresAt.Time, + }, nil + } + + // Try to validate the license key with the license server. + validateResponse, err := lemonsqueezy.ValidateLicenseKey(licenseKey, "") + if err != nil { + return nil, errors.Wrap(err, "failed to validate license key") + } + if validateResponse.Valid { + result := &ValidateResult{ + Plan: apiv1pb.PlanType_PRO, + } + if validateResponse.LicenseKey.ExpiresAt != nil && *validateResponse.LicenseKey.ExpiresAt != "" { + expiresTime, err := time.Parse(time.RFC3339Nano, *validateResponse.LicenseKey.ExpiresAt) + if err != nil { + return nil, errors.Wrap(err, "failed to parse license key expires time") + } + result.ExpiresTime = expiresTime + } + return result, nil + } + + // Otherwise, return an error. + return nil, errors.New("invalid license key") +} + +func parseLicenseKey(licenseKey string) (*Claims, error) { + token, err := jwt.ParseWithClaims(licenseKey, &Claims{}, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, errors.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(slashPublicRSAKey)) + if err != nil { + return nil, err + } + + return key, nil + }) + if err != nil { + return nil, errors.Wrap(err, "failed to parse token") + } + if token == nil || !token.Valid { + return nil, errors.New("invalid token") + } + + claims, ok := token.Claims.(*Claims) + if !ok { + return nil, errors.New("invalid claims") + } + return claims, nil +} diff --git a/server/service/license/slash.public.pem b/server/service/license/slash.public.pem new file mode 100644 index 0000000..72b718d --- /dev/null +++ b/server/service/license/slash.public.pem @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsNHQEGf4EiGUKG/buu1d +llNjxwmKiUX0htAoBa7JPqNjlQqyd27gBQCJ9b1d4gor3SBbEdKKirph6I/jJ2in +LQwSVtIuQUdILC0PSEyUZ1t/QOOfgNuAW15cvj7e1W2I3GqTy/PwQ08+xTziDiU0 +j9fM15vMEx/G378ikPaSfaoLueugI/tpta3Ho6wJqpNr2pL2+pIb1LUurltufA/O +5mIcxorlu+1iSB5PLB6X1ptipDkD+ZdHlDLzKgzkUoIrqDynC7jlwhiDtqDl2q+j +VTyZKhP6PB81rI4/DXAafrl4ndAGxWiZj83/+m/uzqIx8XWMA10jnSeGvBxbVWrw +rwIDAQAB +-----END PUBLIC KEY----- \ No newline at end of file