// Mirror of https://github.com/aykhans/dodo.git
// (synced 2025-09-07 11:30:47 +00:00; 304 lines, 9.0 KiB, Go)

package config
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"os"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/aykhans/dodo/pkg/types"
|
|
"github.com/aykhans/dodo/pkg/utils"
|
|
"github.com/jedib0t/go-pretty/v6/text"
|
|
)
|
|
|
|
// VERSION is the dodo release version. It is also used to build the default
// User-Agent value (Defaults.UserAgent, e.g. "dodo/1.0.0").
const VERSION = "1.0.0"
|
|
|
|
var Defaults = struct {
|
|
UserAgent string
|
|
Method string
|
|
RequestTimeout time.Duration
|
|
DodosCount uint
|
|
Yes bool
|
|
SkipVerify bool
|
|
}{
|
|
UserAgent: "dodo/" + VERSION,
|
|
Method: "GET",
|
|
RequestTimeout: time.Second * 10,
|
|
DodosCount: 1,
|
|
Yes: false,
|
|
SkipVerify: false,
|
|
}
|
|
|
|
var (
	// ValidProxySchemes lists the URL schemes Config.Validate accepts for
	// every entry in Config.Proxies.
	ValidProxySchemes = []string{"http", "socks5", "socks5h"}

	// ValidRequestURLSchemes lists the URL schemes Config.Validate accepts
	// for Config.URL.
	ValidRequestURLSchemes = []string{"http", "https"}
)
|
|
|
|
type IParser interface {
|
|
Parse() (*Config, error)
|
|
}
|
|
|
|
type Config struct {
|
|
Files []types.ConfigFile
|
|
Method *string
|
|
URL *url.URL
|
|
Timeout *time.Duration
|
|
DodosCount *uint
|
|
RequestCount *uint
|
|
Duration *time.Duration
|
|
Yes *bool
|
|
SkipVerify *bool
|
|
Params types.Params
|
|
Headers types.Headers
|
|
Cookies types.Cookies
|
|
Bodies types.Bodies
|
|
Proxies types.Proxies
|
|
}
|
|
|
|
func NewConfig() *Config {
|
|
return &Config{}
|
|
}
|
|
|
|
func (config *Config) Merge(newConfig *Config) {
|
|
config.Files = append(config.Files, newConfig.Files...)
|
|
if newConfig.Method != nil {
|
|
config.Method = newConfig.Method
|
|
}
|
|
if newConfig.URL != nil {
|
|
config.URL = newConfig.URL
|
|
}
|
|
if newConfig.Timeout != nil {
|
|
config.Timeout = newConfig.Timeout
|
|
}
|
|
if newConfig.DodosCount != nil {
|
|
config.DodosCount = newConfig.DodosCount
|
|
}
|
|
if newConfig.RequestCount != nil {
|
|
config.RequestCount = newConfig.RequestCount
|
|
}
|
|
if newConfig.Duration != nil {
|
|
config.Duration = newConfig.Duration
|
|
}
|
|
if newConfig.Yes != nil {
|
|
config.Yes = newConfig.Yes
|
|
}
|
|
if newConfig.SkipVerify != nil {
|
|
config.SkipVerify = newConfig.SkipVerify
|
|
}
|
|
if len(newConfig.Params) != 0 {
|
|
config.Params.Append(newConfig.Params...)
|
|
}
|
|
if len(newConfig.Headers) != 0 {
|
|
config.Headers.Append(newConfig.Headers...)
|
|
}
|
|
if len(newConfig.Cookies) != 0 {
|
|
config.Cookies.Append(newConfig.Cookies...)
|
|
}
|
|
if len(newConfig.Bodies) != 0 {
|
|
config.Bodies.Append(newConfig.Bodies...)
|
|
}
|
|
if len(newConfig.Proxies) != 0 {
|
|
config.Proxies.Append(newConfig.Proxies...)
|
|
}
|
|
}
|
|
|
|
func (config *Config) SetDefaults() {
|
|
if config.Method == nil {
|
|
config.Method = utils.ToPtr(Defaults.Method)
|
|
}
|
|
if config.Timeout == nil {
|
|
config.Timeout = &Defaults.RequestTimeout
|
|
}
|
|
if config.DodosCount == nil {
|
|
config.DodosCount = utils.ToPtr(Defaults.DodosCount)
|
|
}
|
|
if config.Yes == nil {
|
|
config.Yes = utils.ToPtr(Defaults.Yes)
|
|
}
|
|
if config.SkipVerify == nil {
|
|
config.SkipVerify = utils.ToPtr(Defaults.SkipVerify)
|
|
}
|
|
if !config.Headers.Has("User-Agent") {
|
|
config.Headers = append(config.Headers, types.Header{Key: "User-Agent", Value: []string{Defaults.UserAgent}})
|
|
}
|
|
}
|
|
|
|
// Validate validates the config fields.
|
|
// It can return the following errors:
|
|
// - types.FieldValidationErrors
|
|
func (config Config) Validate() error {
|
|
validationErrors := make([]types.FieldValidationError, 0)
|
|
|
|
if config.Method == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Method", "", errors.New("method is required")))
|
|
}
|
|
|
|
if config.URL == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", "", errors.New("URL is required")))
|
|
} else if !slices.Contains(ValidRequestURLSchemes, config.URL.Scheme) {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("URL", config.URL.String(), fmt.Errorf("URL scheme must be one of: %v", ValidRequestURLSchemes)))
|
|
}
|
|
|
|
if config.DodosCount == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "", errors.New("dodos count is required")))
|
|
} else if *config.DodosCount == 0 {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "0", errors.New("dodos count must be greater than 0")))
|
|
}
|
|
|
|
switch {
|
|
case config.RequestCount == nil && config.Duration == nil:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "", errors.New("either request count or duration must be specified")))
|
|
case (config.RequestCount != nil && config.Duration != nil) && (*config.RequestCount == 0 && *config.Duration == 0):
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "0", errors.New("both request count and duration cannot be zero")))
|
|
case config.RequestCount != nil && config.Duration == nil && *config.RequestCount == 0:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count", "0", errors.New("request count must be greater than 0")))
|
|
case config.RequestCount == nil && config.Duration != nil && *config.Duration == 0:
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Duration", "0", errors.New("duration must be greater than 0")))
|
|
}
|
|
|
|
if config.Yes == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Yes", "", errors.New("yes field is required")))
|
|
}
|
|
|
|
if config.SkipVerify == nil {
|
|
validationErrors = append(validationErrors, types.NewFieldValidationError("Skip Verify", "", errors.New("skip verify field is required")))
|
|
}
|
|
|
|
for i, proxy := range config.Proxies {
|
|
if !slices.Contains(ValidProxySchemes, proxy.Scheme) {
|
|
validationErrors = append(
|
|
validationErrors,
|
|
types.NewFieldValidationError(
|
|
fmt.Sprintf("Proxy[%d]", i),
|
|
proxy.String(),
|
|
fmt.Errorf("proxy scheme must be one of: %v", ValidProxySchemes),
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
if len(validationErrors) > 0 {
|
|
return types.NewFieldValidationErrors(validationErrors)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ReadAllConfigs() *Config {
|
|
envParser := NewConfigENVParser("DODO")
|
|
envConfig, err := envParser.Parse()
|
|
_ = utils.HandleErrorOrDie(err,
|
|
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
|
printParseErrors("ENV", err.Errors...)
|
|
fmt.Println()
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
cliParser := NewConfigCLIParser(os.Args)
|
|
cliConf, err := cliParser.Parse()
|
|
_ = utils.HandleErrorOrDie(err,
|
|
utils.OnSentinelError(types.ErrCLINoArgs, func(err error) error {
|
|
cliParser.PrintHelp()
|
|
utils.PrintErrAndExit(text.FgYellow, 1, "\nNo arguments provided.")
|
|
return nil
|
|
}),
|
|
utils.OnCustomError(func(err types.CLIUnexpectedArgsError) error {
|
|
cliParser.PrintHelp()
|
|
utils.PrintErrAndExit(text.FgYellow, 1, "\nUnexpected CLI arguments provided: %v", err.Args)
|
|
return nil
|
|
}),
|
|
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
|
cliParser.PrintHelp()
|
|
fmt.Println()
|
|
printParseErrors("CLI", err.Errors...)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
envConfig.Merge(cliConf)
|
|
|
|
for _, configFile := range envConfig.Files {
|
|
fileConfig, err := parseConfigFile(configFile, 10)
|
|
_ = utils.HandleErrorOrDie(err,
|
|
utils.OnCustomError(func(err types.ConfigFileReadError) error {
|
|
cliParser.PrintHelp()
|
|
utils.PrintErrAndExit(text.FgYellow, 1, "\nFailed to read config file '%s': %v", configFile.Path(), err)
|
|
return nil
|
|
}),
|
|
utils.OnCustomError(func(err types.UnmarshalError) error {
|
|
utils.PrintErrAndExit(text.FgYellow, 1, "\nFailed to unmarshal config file '%s': %v", configFile.Path(), err)
|
|
return nil
|
|
}),
|
|
utils.OnCustomError(func(err types.FieldParseErrors) error {
|
|
printParseErrors(fmt.Sprintf("CONFIG FILE '%s'", configFile.Path()), err.Errors...)
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
envConfig.Merge(fileConfig)
|
|
}
|
|
|
|
envConfig.SetDefaults()
|
|
|
|
err = envConfig.Validate()
|
|
_ = utils.HandleErrorOrDie(err,
|
|
utils.OnCustomError(func(err types.FieldValidationErrors) error {
|
|
for _, fieldErr := range err.Errors {
|
|
if fieldErr.Value == "" {
|
|
utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s': %v", fieldErr.Field, fieldErr.Err)
|
|
} else {
|
|
utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s' (%s): %v", fieldErr.Field, fieldErr.Value, fieldErr.Err)
|
|
}
|
|
}
|
|
os.Exit(1)
|
|
return nil
|
|
}),
|
|
)
|
|
|
|
return envConfig
|
|
}
|
|
|
|
// parseConfigFile recursively parses a config file and its nested files up to maxDepth levels.
|
|
// Returns the merged configuration or an error if parsing fails.
|
|
// It can return the following errors:
|
|
// - types.ConfigFileReadError
|
|
// - types.UnmarshalError
|
|
// - types.FieldParseErrors
|
|
func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error) {
|
|
configFileParser := NewConfigFileParser(configFile)
|
|
fileConfig, err := configFileParser.Parse()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if maxDepth <= 0 {
|
|
return fileConfig, nil
|
|
}
|
|
|
|
for _, c := range fileConfig.Files {
|
|
innerFileConfig, err := parseConfigFile(c, maxDepth-1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fileConfig.Merge(innerFileConfig)
|
|
}
|
|
|
|
return fileConfig, nil
|
|
}
|
|
|
|
func printParseErrors(parserName string, errors ...types.FieldParseError) {
|
|
for _, fieldErr := range errors {
|
|
if fieldErr.Value == "" {
|
|
utils.PrintErr(text.FgYellow, "[%s] Field '%s': %v", parserName, fieldErr.Field, fieldErr.Err)
|
|
}
|
|
utils.PrintErr(text.FgYellow, "[%s] Field '%s' (%s): %v", parserName, fieldErr.Field, fieldErr.Value, fieldErr.Err)
|
|
}
|
|
}
|