package config import ( "errors" "fmt" "net/url" "os" "slices" "time" "github.com/aykhans/dodo/pkg/types" "github.com/aykhans/dodo/pkg/utils" "github.com/jedib0t/go-pretty/v6/text" ) const VERSION string = "1.0.0" var Defaults = struct { UserAgent string Method string RequestTimeout time.Duration DodosCount uint Yes bool SkipVerify bool }{ UserAgent: "dodo/" + VERSION, Method: "GET", RequestTimeout: time.Second * 10, DodosCount: 1, Yes: false, SkipVerify: false, } var ( ValidProxySchemes = []string{"http", "socks5", "socks5h"} ValidRequestURLSchemes = []string{"http", "https"} ) type IParser interface { Parse() (*Config, error) } type Config struct { Files []types.ConfigFile Method *string URL *url.URL Timeout *time.Duration DodosCount *uint RequestCount *uint Duration *time.Duration Yes *bool SkipVerify *bool Params types.Params Headers types.Headers Cookies types.Cookies Bodies types.Bodies Proxies types.Proxies } func NewConfig() *Config { return &Config{} } func (config *Config) Merge(newConfig *Config) { config.Files = append(config.Files, newConfig.Files...) if newConfig.Method != nil { config.Method = newConfig.Method } if newConfig.URL != nil { config.URL = newConfig.URL } if newConfig.Timeout != nil { config.Timeout = newConfig.Timeout } if newConfig.DodosCount != nil { config.DodosCount = newConfig.DodosCount } if newConfig.RequestCount != nil { config.RequestCount = newConfig.RequestCount } if newConfig.Duration != nil { config.Duration = newConfig.Duration } if newConfig.Yes != nil { config.Yes = newConfig.Yes } if newConfig.SkipVerify != nil { config.SkipVerify = newConfig.SkipVerify } if len(newConfig.Params) != 0 { config.Params.Append(newConfig.Params...) } if len(newConfig.Headers) != 0 { config.Headers.Append(newConfig.Headers...) } if len(newConfig.Cookies) != 0 { config.Cookies.Append(newConfig.Cookies...) } if len(newConfig.Bodies) != 0 { config.Bodies.Append(newConfig.Bodies...) } if len(newConfig.Proxies) != 0 { config.Proxies.Append(newConfig.Proxies...) } } func (config *Config) SetDefaults() { if config.Method == nil { config.Method = utils.ToPtr(Defaults.Method) } if config.Timeout == nil { config.Timeout = &Defaults.RequestTimeout } if config.DodosCount == nil { config.DodosCount = utils.ToPtr(Defaults.DodosCount) } if config.Yes == nil { config.Yes = utils.ToPtr(Defaults.Yes) } if config.SkipVerify == nil { config.SkipVerify = utils.ToPtr(Defaults.SkipVerify) } if !config.Headers.Has("User-Agent") { config.Headers = append(config.Headers, types.Header{Key: "User-Agent", Value: []string{Defaults.UserAgent}}) } } // Validate validates the config fields. // It can return the following errors: // - types.FieldValidationErrors func (config Config) Validate() error { validationErrors := make([]types.FieldValidationError, 0) if config.Method == nil { validationErrors = append(validationErrors, types.NewFieldValidationError("Method", "", errors.New("method is required"))) } if config.URL == nil { validationErrors = append(validationErrors, types.NewFieldValidationError("URL", "", errors.New("URL is required"))) } else if !slices.Contains(ValidRequestURLSchemes, config.URL.Scheme) { validationErrors = append(validationErrors, types.NewFieldValidationError("URL", config.URL.String(), fmt.Errorf("URL scheme must be one of: %v", ValidRequestURLSchemes))) } if config.DodosCount == nil { validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "", errors.New("dodos count is required"))) } else if *config.DodosCount == 0 { validationErrors = append(validationErrors, types.NewFieldValidationError("Dodos Count", "0", errors.New("dodos count must be greater than 0"))) } switch { case config.RequestCount == nil && config.Duration == nil: validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "", errors.New("either request count or duration must be specified"))) case (config.RequestCount != nil && config.Duration != nil) && (*config.RequestCount == 0 && *config.Duration == 0): validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count / Duration", "0", errors.New("both request count and duration cannot be zero"))) case config.RequestCount != nil && config.Duration == nil && *config.RequestCount == 0: validationErrors = append(validationErrors, types.NewFieldValidationError("Request Count", "0", errors.New("request count must be greater than 0"))) case config.RequestCount == nil && config.Duration != nil && *config.Duration == 0: validationErrors = append(validationErrors, types.NewFieldValidationError("Duration", "0", errors.New("duration must be greater than 0"))) } if config.Yes == nil { validationErrors = append(validationErrors, types.NewFieldValidationError("Yes", "", errors.New("yes field is required"))) } if config.SkipVerify == nil { validationErrors = append(validationErrors, types.NewFieldValidationError("Skip Verify", "", errors.New("skip verify field is required"))) } for i, proxy := range config.Proxies { if !slices.Contains(ValidProxySchemes, proxy.Scheme) { validationErrors = append( validationErrors, types.NewFieldValidationError( fmt.Sprintf("Proxy[%d]", i), proxy.String(), fmt.Errorf("proxy scheme must be one of: %v", ValidProxySchemes), ), ) } } if len(validationErrors) > 0 { return types.NewFieldValidationErrors(validationErrors) } return nil } func ReadAllConfigs() *Config { envParser := NewConfigENVParser("DODO") envConfig, err := envParser.Parse() _ = utils.HandleErrorOrDie(err, utils.OnCustomError(func(err types.FieldParseErrors) error { printParseErrors("ENV", err.Errors...) fmt.Println() os.Exit(1) return nil }), ) cliParser := NewConfigCLIParser(os.Args) cliConf, err := cliParser.Parse() _ = utils.HandleErrorOrDie(err, utils.OnSentinelError(types.ErrCLINoArgs, func(err error) error { cliParser.PrintHelp() utils.PrintErrAndExit(text.FgYellow, 1, "\nNo arguments provided.") return nil }), utils.OnCustomError(func(err types.CLIUnexpectedArgsError) error { cliParser.PrintHelp() utils.PrintErrAndExit(text.FgYellow, 1, "\nUnexpected CLI arguments provided: %v", err.Args) return nil }), utils.OnCustomError(func(err types.FieldParseErrors) error { cliParser.PrintHelp() fmt.Println() printParseErrors("CLI", err.Errors...) os.Exit(1) return nil }), ) envConfig.Merge(cliConf) for _, configFile := range envConfig.Files { fileConfig, err := parseConfigFile(configFile, 10) _ = utils.HandleErrorOrDie(err, utils.OnCustomError(func(err types.ConfigFileReadError) error { cliParser.PrintHelp() utils.PrintErrAndExit(text.FgYellow, 1, "\nFailed to read config file '%s': %v", configFile.Path(), err) return nil }), utils.OnCustomError(func(err types.UnmarshalError) error { utils.PrintErrAndExit(text.FgYellow, 1, "\nFailed to unmarshal config file '%s': %v", configFile.Path(), err) return nil }), utils.OnCustomError(func(err types.FieldParseErrors) error { printParseErrors(fmt.Sprintf("CONFIG FILE '%s'", configFile.Path()), err.Errors...) os.Exit(1) return nil }), ) envConfig.Merge(fileConfig) } envConfig.SetDefaults() err = envConfig.Validate() _ = utils.HandleErrorOrDie(err, utils.OnCustomError(func(err types.FieldValidationErrors) error { for _, fieldErr := range err.Errors { if fieldErr.Value == "" { utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s': %v", fieldErr.Field, fieldErr.Err) } else { utils.PrintErr(text.FgYellow, "[VALIDATION] Field '%s' (%s): %v", fieldErr.Field, fieldErr.Value, fieldErr.Err) } } os.Exit(1) return nil }), ) return envConfig } // parseConfigFile recursively parses a config file and its nested files up to maxDepth levels. // Returns the merged configuration or an error if parsing fails. // It can return the following errors: // - types.ConfigFileReadError // - types.UnmarshalError // - types.FieldParseErrors func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error) { configFileParser := NewConfigFileParser(configFile) fileConfig, err := configFileParser.Parse() if err != nil { return nil, err } if maxDepth <= 0 { return fileConfig, nil } for _, c := range fileConfig.Files { innerFileConfig, err := parseConfigFile(c, maxDepth-1) if err != nil { return nil, err } fileConfig.Merge(innerFileConfig) } return fileConfig, nil } func printParseErrors(parserName string, errors ...types.FieldParseError) { for _, fieldErr := range errors { if fieldErr.Value == "" { utils.PrintErr(text.FgYellow, "[%s] Field '%s': %v", parserName, fieldErr.Field, fieldErr.Err) } utils.PrintErr(text.FgYellow, "[%s] Field '%s' (%s): %v", parserName, fieldErr.Field, fieldErr.Value, fieldErr.Err) } }