mirror of
https://github.com/aykhans/dodo.git
synced 2025-09-07 11:30:47 +00:00
add 'Validate' method to the 'Config'
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/aykhans/dodo/pkg/types"
|
||||
@@ -29,7 +31,10 @@ var Defaults = struct {
|
||||
SkipVerify: false,
|
||||
}
|
||||
|
||||
var SupportedProxySchemes = []string{"http", "socks5", "socks5h"}
|
||||
var (
|
||||
ValidProxySchemes = []string{"http", "socks5", "socks5h"}
|
||||
ValidRequestURLSchemes = []string{"http", "https"}
|
||||
)
|
||||
|
||||
type IParser interface {
|
||||
Parse() (*Config, error)
|
||||
@@ -120,12 +125,73 @@ func (config *Config) SetDefaults() {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the config fields.
|
||||
// It can return the following errors:
|
||||
// - types.FieldValidationErrors
|
||||
func (config Config) Validate() error {
|
||||
validationErrors := make([]types.FieldValidationError, 0)
|
||||
|
||||
if config.Method == nil {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Method", "", errors.New("method is required")))
|
||||
}
|
||||
|
||||
if config.URL == nil {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", "", errors.New("URL is required")))
|
||||
} else if !slices.Contains(ValidRequestURLSchemes, config.URL.Scheme) {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", config.URL.String(), fmt.Errorf("URL scheme must be one of: %v", ValidRequestURLSchemes)))
|
||||
}
|
||||
|
||||
if config.DodosCount == nil {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "", errors.New("dodos count is required")))
|
||||
} else if *config.DodosCount == 0 {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "0", errors.New("dodos count must be greater than 0")))
|
||||
}
|
||||
|
||||
switch {
|
||||
case config.RequestCount == nil && config.Duration == nil:
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "", errors.New("either request count or duration must be specified")))
|
||||
case (config.RequestCount != nil && config.Duration != nil) && (*config.RequestCount == 0 && *config.Duration == 0):
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "0", errors.New("both request count and duration cannot be zero")))
|
||||
case config.RequestCount != nil && config.Duration == nil && *config.RequestCount == 0:
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count", "0", errors.New("request count must be greater than 0")))
|
||||
case config.RequestCount == nil && config.Duration != nil && *config.Duration == 0:
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Duration", "0", errors.New("duration must be greater than 0")))
|
||||
}
|
||||
|
||||
if config.Yes == nil {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Yes", "", errors.New("yes field is required")))
|
||||
}
|
||||
|
||||
if config.SkipVerify == nil {
|
||||
validationErrors = append(validationErrors, types.NewFieldValidationError("Skip Verify", "", errors.New("skip verify field is required")))
|
||||
}
|
||||
|
||||
for i, proxy := range config.Proxies {
|
||||
if !slices.Contains(ValidProxySchemes, proxy.Scheme) {
|
||||
validationErrors = append(
|
||||
validationErrors,
|
||||
types.NewFieldValidationError(
|
||||
fmt.Sprintf("Proxy[%d]", i),
|
||||
proxy.String(),
|
||||
fmt.Errorf("proxy scheme must be one of: %v", ValidProxySchemes),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(validationErrors) > 0 {
|
||||
return types.NewFieldValidationErrors(validationErrors)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadAllConfigs() *Config {
|
||||
envParser := NewConfigENVParser("DODO")
|
||||
envConfig, err := envParser.Parse()
|
||||
_ = utils.HandleErrorOrDie(err,
|
||||
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
||||
printValidationErrors("ENV", err.Errors...)
|
||||
printParseErrors("ENV", err.Errors...)
|
||||
fmt.Println()
|
||||
os.Exit(1)
|
||||
return nil
|
||||
@@ -148,7 +214,7 @@ func ReadAllConfigs() *Config {
|
||||
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
||||
cliParser.PrintHelp()
|
||||
fmt.Println()
|
||||
printValidationErrors("CLI", err.Errors...)
|
||||
printParseErrors("CLI", err.Errors...)
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
@@ -169,7 +235,7 @@ func ReadAllConfigs() *Config {
|
||||
return nil
|
||||
}),
|
||||
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
||||
printValidationErrors(fmt.Sprintf("CONFIG FILE '%s'", configFile.Path()), err.Errors...)
|
||||
printParseErrors(fmt.Sprintf("CONFIG FILE '%s'", configFile.Path()), err.Errors...)
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
@@ -178,6 +244,23 @@ func ReadAllConfigs() *Config {
|
||||
envConfig.Merge(fileConfig)
|
||||
}
|
||||
|
||||
envConfig.SetDefaults()
|
||||
|
||||
err = envConfig.Validate()
|
||||
_ = utils.HandleErrorOrDie(err,
|
||||
utils.OnCustomError(func(err types.FieldValidationErrors) error {
|
||||
for _, fieldErr := range err.Errors {
|
||||
if fieldErr.Value == "" {
|
||||
utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s': %v", fieldErr.Field, fieldErr.Err)
|
||||
} else {
|
||||
utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s' (%s): %v", fieldErr.Field, fieldErr.Value, fieldErr.Err)
|
||||
}
|
||||
}
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
return envConfig
|
||||
}
|
||||
|
||||
@@ -210,7 +293,7 @@ func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error)
|
||||
return fileConfig, nil
|
||||
}
|
||||
|
||||
func printValidationErrors(parserName string, errors ...types.FieldParseError) {
|
||||
func printParseErrors(parserName string, errors ...types.FieldParseError) {
|
||||
for _, fieldErr := range errors {
|
||||
if fieldErr.Value == "" {
|
||||
utils.PrintErr(text.FgYellow, "[%s] Field '%s': %v", parserName, fieldErr.Field, fieldErr.Err)
|
||||
|
Reference in New Issue
Block a user