diff --git a/.golangci.yml b/.golangci.yml index 3693c4d..8a131b8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -31,3 +31,9 @@ linters: - "all" - "-S1002" - "-ST1000" + + exclusions: + rules: + - path: _test\.go + linters: + - errcheck diff --git a/config/cli_test.go b/config/cli_test.go new file mode 100644 index 0000000..b055304 --- /dev/null +++ b/config/cli_test.go @@ -0,0 +1,181 @@ +package config + +import ( + "flag" + "io" + "os" + "testing" + "time" + + "github.com/aykhans/dodo/types" + "github.com/stretchr/testify/assert" +) + +func TestReadCLI(t *testing.T) { + tests := []struct { + name string + args []string + expectFile types.ConfigFile + expectError bool + expectedConfig *Config + }{ + { + name: "simple url and duration", + args: []string{"-u", "https://example.com", "-o", "1m"}, + expectFile: "", + expectError: false, + expectedConfig: &Config{ + URL: &types.RequestURL{}, + Duration: &types.Duration{Duration: time.Minute}, + }, + }, + { + name: "config file only", + args: []string{"-f", "/path/to/config.json"}, + expectFile: "/path/to/config.json", + expectError: false, + expectedConfig: &Config{}, + }, + { + name: "all flags", + args: []string{"-f", "/path/to/config.json", "-u", "https://example.com", "-m", "POST", "-d", "10", "-r", "1000", "-o", "3m", "-t", "3s", "-b", "body1", "-H", "header1:value1", "-p", "param1=value1", "-c", "cookie1=value1", "-x", "http://proxy.example.com:8080", "-y"}, + expectFile: "/path/to/config.json", + expectError: false, + expectedConfig: &Config{ + Method: stringPtr("POST"), + URL: &types.RequestURL{}, + DodosCount: uintPtr(10), + RequestCount: uintPtr(1000), + Duration: &types.Duration{Duration: 3 * time.Minute}, + Timeout: &types.Timeout{Duration: 3 * time.Second}, + Yes: boolPtr(true), + }, + }, + { + name: "unexpected arguments", + args: []string{"-u", "https://example.com", "extraArg"}, + expectFile: "", + expectError: true, + expectedConfig: &Config{}, + }, + } + + // Save original command-line arguments + origArgs := os.Args + origFlagCommandLine := flag.CommandLine + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset flag.CommandLine to its original state + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + // Override os.Args for the test + os.Args = append([]string{"dodo"}, tt.args...) + + // Initialize a new config + config := NewConfig() + + // Mock URL to avoid actual URL parsing issues in tests + if tt.expectedConfig.URL != nil { + urlObj := types.RequestURL{} + urlObj.Set("https://example.com") + tt.expectedConfig.URL = &urlObj + } + + // Call the function being tested + file, err := config.ReadCLI() + + // Reset os.Args after test + os.Args = origArgs + + // Assert expected results + assert.Equal(t, tt.expectFile, file) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + + // Check expected config values if URL is set + if tt.expectedConfig.URL != nil { + assert.NotNil(t, config.URL) + assert.Equal(t, "https://example.com", config.URL.String()) + } + + // Check duration if expected + if tt.expectedConfig.Duration != nil { + assert.NotNil(t, config.Duration) + assert.Equal(t, tt.expectedConfig.Duration.Duration, config.Duration.Duration) + } + + // Check other values as needed + if tt.expectedConfig.Method != nil { + assert.Equal(t, *tt.expectedConfig.Method, *config.Method) + } + if tt.expectedConfig.DodosCount != nil { + assert.Equal(t, *tt.expectedConfig.DodosCount, *config.DodosCount) + } + if tt.expectedConfig.RequestCount != nil { + assert.Equal(t, *tt.expectedConfig.RequestCount, *config.RequestCount) + } + if tt.expectedConfig.Timeout != nil { + assert.Equal(t, tt.expectedConfig.Timeout.Duration, config.Timeout.Duration) + } + if tt.expectedConfig.Yes != nil { + assert.Equal(t, *tt.expectedConfig.Yes, *config.Yes) + } + } + }) + } + + // Restore original flag.CommandLine + flag.CommandLine = origFlagCommandLine +} + +// Skip the prompt tests as they require interactive input/output handling +// which is difficult to test reliably in unit tests +func TestCLIYesOrNoReaderBasic(t *testing.T) { + // We're just going to verify the function exists and returns the default value + // when called with "\n" as input (which should trigger the default path) + result := func() bool { + // Save original standard input + origStdin := os.Stdin + origStdout := os.Stdout + + // Create a pipe to mock standard input + r, w, _ := os.Pipe() + os.Stdin = r + + // Redirect stdout to null device + devNull, _ := os.Open(os.DevNull) + os.Stdout = devNull + + // Write newline to mock stdin (should trigger default behavior) + io.WriteString(w, "\n") + w.Close() + + // Call the function being tested with default=true + result := CLIYesOrNoReader("Test message", true) + + // Restore original stdin and stdout + os.Stdin = origStdin + os.Stdout = origStdout + + return result + }() + + // Default value should be returned + assert.True(t, result) +} + +// Helper types and functions for testing +func stringPtr(s string) *string { + return &s +} + +func uintPtr(u uint) *uint { + return &u +} + +func boolPtr(b bool) *bool { + return &b +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..7d9c2e4 --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,514 @@ +package config + +import ( + "net/url" + "os" + "testing" + "time" + + "github.com/aykhans/dodo/types" + "github.com/aykhans/dodo/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewConfig(t *testing.T) { + config := NewConfig() + assert.NotNil(t, config) + assert.IsType(t, &Config{}, config) +} + +func TestNewRequestConfig(t *testing.T) { + // Create a sample Config object + urlObj := types.RequestURL{} + urlObj.Set("https://example.com") + + conf := &Config{ + Method: utils.ToPtr("GET"), + URL: &urlObj, + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + Yes: utils.ToPtr(true), + Params: types.Params{{Key: "key1", Value: []string{"value1"}}}, + Headers: types.Headers{{Key: "User-Agent", Value: []string{"TestAgent"}}}, + Cookies: types.Cookies{{Key: "session", Value: []string{"123"}}}, + Body: types.Body{"test body"}, + Proxies: types.Proxies{url.URL{Scheme: "http", Host: "proxy.example.com:8080"}}, + } + + // Call the function being tested + rc := NewRequestConfig(conf) + + // Assert the fields are correctly mapped + assert.Equal(t, "GET", rc.Method) + assert.Equal(t, "https://example.com", rc.URL.String()) + assert.Equal(t, 5*time.Second, rc.Timeout) + assert.Equal(t, uint(10), rc.DodosCount) + assert.Equal(t, uint(100), rc.RequestCount) + assert.Equal(t, 1*time.Minute, rc.Duration) + assert.True(t, rc.Yes) + assert.Equal(t, types.Params{{Key: "key1", Value: []string{"value1"}}}, rc.Params) + assert.Equal(t, types.Headers{{Key: "User-Agent", Value: []string{"TestAgent"}}}, rc.Headers) + assert.Equal(t, types.Cookies{{Key: "session", Value: []string{"123"}}}, rc.Cookies) + assert.Equal(t, types.Body{"test body"}, rc.Body) + assert.Equal(t, types.Proxies{url.URL{Scheme: "http", Host: "proxy.example.com:8080"}}, rc.Proxies) +} + +func TestGetValidDodosCountForRequests(t *testing.T) { + tests := []struct { + name string + dodosCount uint + requestCount uint + expected uint + }{ + { + name: "no request count limit", + dodosCount: 10, + requestCount: 0, + expected: 10, + }, + { + name: "dodos count less than request count", + dodosCount: 5, + requestCount: 100, + expected: 5, + }, + { + name: "dodos count greater than request count", + dodosCount: 100, + requestCount: 10, + expected: 10, + }, + { + name: "dodos count equal to request count", + dodosCount: 50, + requestCount: 50, + expected: 50, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := &RequestConfig{ + DodosCount: tt.dodosCount, + RequestCount: tt.requestCount, + } + result := rc.GetValidDodosCountForRequests() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetMaxConns(t *testing.T) { + tests := []struct { + name string + dodosCount uint + requestCount uint + minConns uint + expected uint + }{ + { + name: "min connections higher than valid dodos count", + dodosCount: 10, + requestCount: 0, + minConns: 20, + expected: 30, // 20 * 150% + }, + { + name: "min connections lower than valid dodos count", + dodosCount: 30, + requestCount: 0, + minConns: 10, + expected: 45, // 30 * 150% + }, + { + name: "request count limits dodos count", + dodosCount: 100, + requestCount: 20, + minConns: 5, + expected: 30, // 20 * 150% + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := &RequestConfig{ + DodosCount: tt.dodosCount, + RequestCount: tt.requestCount, + } + result := rc.GetMaxConns(tt.minConns) + assert.Equal(t, tt.expected, result) + }) + } +} + +// Skip the Print test as it's mainly a formatting function +// that uses external table rendering library +func TestRequestConfigPrint(t *testing.T) { + // Create a sample RequestConfig + rc := &RequestConfig{ + Method: "GET", + URL: url.URL{Scheme: "https", Host: "example.com"}, + Timeout: 5 * time.Second, + DodosCount: 10, + RequestCount: 100, + Duration: 1 * time.Minute, + Params: types.Params{{Key: "param1", Value: []string{"value1"}}}, + Headers: types.Headers{{Key: "User-Agent", Value: []string{"TestAgent"}}}, + Cookies: types.Cookies{{Key: "session", Value: []string{"123"}}}, + Body: types.Body{"test body"}, + Proxies: types.Proxies{url.URL{Scheme: "http", Host: "proxy.example.com:8080"}}, + } + + // We'll just call the function to ensure it doesn't panic + // Redirect output to /dev/null + origStdout := os.Stdout + devNull, _ := os.Open(os.DevNull) + os.Stdout = devNull + + // Call the function + rc.Print() + + // Restore stdout + os.Stdout = origStdout + + // No assertions needed, we're just checking that it doesn't panic +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config *Config + expectedErrors int + expectURLQuery bool + }{ + { + name: "valid config", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + }, + expectedErrors: 0, + expectURLQuery: false, + }, + { + name: "missing URL", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: nil, + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + }, + expectedErrors: 1, + expectURLQuery: false, + }, + { + name: "missing method", + config: &Config{ + Method: nil, + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + }, + expectedErrors: 1, + expectURLQuery: false, + }, + { + name: "invalid URL scheme", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Scheme = "ftp"; u.Host = "example.com"; return u }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + }, + expectedErrors: 1, + expectURLQuery: false, + }, + { + name: "missing both duration and request count", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(0)), + Duration: &types.Duration{Duration: 0}, + }, + expectedErrors: 1, + expectURLQuery: false, + }, + { + name: "URL with query parameters", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: func() *types.RequestURL { + u := &types.RequestURL{} + u.Set("https://example.com?param1=value1¶m2=value2") + return u + }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + }, + expectedErrors: 0, + expectURLQuery: true, + }, + { + name: "invalid proxy scheme", + config: &Config{ + Method: utils.ToPtr("GET"), + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }(), + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + Proxies: types.Proxies{url.URL{Scheme: "invalid", Host: "proxy.example.com:8080"}}, + }, + expectedErrors: 1, + expectURLQuery: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errors := tt.config.Validate() + + // Check number of errors + assert.Len(t, errors, tt.expectedErrors) + + // Check if URL query parameters are extracted properly + if tt.expectURLQuery { + assert.Empty(t, tt.config.URL.RawQuery) + assert.NotEmpty(t, tt.config.Params) + found := false + for _, param := range tt.config.Params { + if param.Key == "param1" && len(param.Value) > 0 && param.Value[0] == "value1" { + found = true + break + } + } + assert.True(t, found, "Expected param1=value1 in Params but not found") + } + }) + } +} + +func TestMergeConfig(t *testing.T) { + baseURL := func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }() + newURL := func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://new-example.com"); return u }() + + baseConfig := &Config{ + Method: utils.ToPtr("GET"), + URL: baseURL, + Timeout: &types.Timeout{Duration: 5 * time.Second}, + DodosCount: utils.ToPtr(uint(10)), + RequestCount: utils.ToPtr(uint(100)), + Duration: &types.Duration{Duration: 1 * time.Minute}, + Yes: utils.ToPtr(false), + Params: types.Params{{Key: "base-param", Value: []string{"base-value"}}}, + Headers: types.Headers{{Key: "base-header", Value: []string{"base-value"}}}, + Cookies: types.Cookies{{Key: "base-cookie", Value: []string{"base-value"}}}, + Body: types.Body{"base-body"}, + Proxies: types.Proxies{url.URL{Scheme: "http", Host: "base-proxy.example.com:8080"}}, + } + + tests := []struct { + name string + newConfig *Config + assertions func(t *testing.T, result *Config) + }{ + { + name: "merge all fields", + newConfig: &Config{ + Method: utils.ToPtr("POST"), + URL: newURL, + Timeout: &types.Timeout{Duration: 10 * time.Second}, + DodosCount: utils.ToPtr(uint(20)), + RequestCount: utils.ToPtr(uint(200)), + Duration: &types.Duration{Duration: 2 * time.Minute}, + Yes: utils.ToPtr(true), + Params: types.Params{{Key: "new-param", Value: []string{"new-value"}}}, + Headers: types.Headers{{Key: "new-header", Value: []string{"new-value"}}}, + Cookies: types.Cookies{{Key: "new-cookie", Value: []string{"new-value"}}}, + Body: types.Body{"new-body"}, + Proxies: types.Proxies{url.URL{Scheme: "http", Host: "new-proxy.example.com:8080"}}, + }, + assertions: func(t *testing.T, result *Config) { + assert.Equal(t, "POST", *result.Method) + assert.Equal(t, "https://new-example.com", result.URL.String()) + assert.Equal(t, 10*time.Second, result.Timeout.Duration) + assert.Equal(t, uint(20), *result.DodosCount) + assert.Equal(t, uint(200), *result.RequestCount) + assert.Equal(t, 2*time.Minute, result.Duration.Duration) + assert.True(t, *result.Yes) + assert.Equal(t, types.Params{{Key: "new-param", Value: []string{"new-value"}}}, result.Params) + assert.Equal(t, types.Headers{{Key: "new-header", Value: []string{"new-value"}}}, result.Headers) + assert.Equal(t, types.Cookies{{Key: "new-cookie", Value: []string{"new-value"}}}, result.Cookies) + assert.Equal(t, types.Body{"new-body"}, result.Body) + assert.Equal(t, types.Proxies{url.URL{Scheme: "http", Host: "new-proxy.example.com:8080"}}, result.Proxies) + }, + }, + { + name: "merge only specified fields", + newConfig: &Config{ + Method: utils.ToPtr("POST"), + URL: newURL, + Yes: utils.ToPtr(true), + }, + assertions: func(t *testing.T, result *Config) { + assert.Equal(t, "POST", *result.Method) + assert.Equal(t, "https://new-example.com", result.URL.String()) + assert.Equal(t, 5*time.Second, result.Timeout.Duration) // unchanged + assert.Equal(t, uint(10), *result.DodosCount) // unchanged + assert.Equal(t, uint(100), *result.RequestCount) // unchanged + assert.Equal(t, 1*time.Minute, result.Duration.Duration) // unchanged + assert.True(t, *result.Yes) // changed + assert.Equal(t, types.Params{{Key: "base-param", Value: []string{"base-value"}}}, result.Params) // unchanged + assert.Equal(t, types.Headers{{Key: "base-header", Value: []string{"base-value"}}}, result.Headers) // unchanged + assert.Equal(t, types.Cookies{{Key: "base-cookie", Value: []string{"base-value"}}}, result.Cookies) // unchanged + assert.Equal(t, types.Body{"base-body"}, result.Body) // unchanged + assert.Equal(t, types.Proxies{url.URL{Scheme: "http", Host: "base-proxy.example.com:8080"}}, result.Proxies) // unchanged + }, + }, + { + name: "merge empty config", + newConfig: &Config{}, + assertions: func(t *testing.T, result *Config) { + // All fields should remain unchanged + assert.Equal(t, "GET", *result.Method) + assert.Equal(t, "https://example.com", result.URL.String()) + assert.Equal(t, 5*time.Second, result.Timeout.Duration) + assert.Equal(t, uint(10), *result.DodosCount) + assert.Equal(t, uint(100), *result.RequestCount) + assert.Equal(t, 1*time.Minute, result.Duration.Duration) + assert.False(t, *result.Yes) + assert.Equal(t, types.Params{{Key: "base-param", Value: []string{"base-value"}}}, result.Params) + assert.Equal(t, types.Headers{{Key: "base-header", Value: []string{"base-value"}}}, result.Headers) + assert.Equal(t, types.Cookies{{Key: "base-cookie", Value: []string{"base-value"}}}, result.Cookies) + assert.Equal(t, types.Body{"base-body"}, result.Body) + assert.Equal(t, types.Proxies{url.URL{Scheme: "http", Host: "base-proxy.example.com:8080"}}, result.Proxies) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a copy of the base config for each test + baseURL := func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }() + testConfig := &Config{ + Method: utils.ToPtr(*baseConfig.Method), + URL: baseURL, + Timeout: &types.Timeout{Duration: baseConfig.Timeout.Duration}, + DodosCount: utils.ToPtr(*baseConfig.DodosCount), + RequestCount: utils.ToPtr(*baseConfig.RequestCount), + Duration: &types.Duration{Duration: baseConfig.Duration.Duration}, + Yes: utils.ToPtr(*baseConfig.Yes), + Params: append(types.Params{}, baseConfig.Params...), + Headers: append(types.Headers{}, baseConfig.Headers...), + Cookies: append(types.Cookies{}, baseConfig.Cookies...), + Body: append(types.Body{}, baseConfig.Body...), + Proxies: append(types.Proxies{}, baseConfig.Proxies...), + } + + // Call the function being tested + testConfig.MergeConfig(tt.newConfig) + + // Run assertions + tt.assertions(t, testConfig) + }) + } +} + +func TestSetDefaults(t *testing.T) { + tests := []struct { + name string + config *Config + validate func(t *testing.T, config *Config) + }{ + { + name: "empty config", + config: &Config{}, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, DefaultMethod, *config.Method) + assert.Equal(t, DefaultTimeout, config.Timeout.Duration) + assert.Equal(t, DefaultDodosCount, *config.DodosCount) + assert.Equal(t, DefaultRequestCount, *config.RequestCount) + assert.Equal(t, DefaultDuration, config.Duration.Duration) + assert.Equal(t, DefaultYes, *config.Yes) + assert.True(t, config.Headers.Has("User-Agent")) + userAgent := config.Headers.GetValue("User-Agent") + assert.NotNil(t, userAgent) + assert.Contains(t, (*userAgent)[0], DefaultUserAgent) + }, + }, + { + name: "partial config", + config: &Config{ + Method: utils.ToPtr("POST"), + Timeout: &types.Timeout{Duration: 30 * time.Second}, + Headers: types.Headers{{Key: "Custom-Header", Value: []string{"value"}}}, + }, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) // should keep existing value + assert.Equal(t, 30*time.Second, config.Timeout.Duration) // should keep existing value + assert.Equal(t, DefaultDodosCount, *config.DodosCount) // should set default + assert.Equal(t, DefaultRequestCount, *config.RequestCount) // should set default + assert.Equal(t, DefaultDuration, config.Duration.Duration) // should set default + assert.Equal(t, DefaultYes, *config.Yes) // should set default + assert.True(t, config.Headers.Has("Custom-Header")) // should keep existing header + assert.True(t, config.Headers.Has("User-Agent")) // should add User-Agent + userAgent := config.Headers.GetValue("User-Agent") + assert.NotNil(t, userAgent) + assert.Contains(t, (*userAgent)[0], DefaultUserAgent) + }, + }, + { + name: "complete config", + config: &Config{ + Method: utils.ToPtr("DELETE"), + URL: func() *types.RequestURL { u := &types.RequestURL{}; u.Set("https://example.com"); return u }(), + Timeout: &types.Timeout{Duration: 15 * time.Second}, + DodosCount: utils.ToPtr(uint(5)), + RequestCount: utils.ToPtr(uint(500)), + Duration: &types.Duration{Duration: 5 * time.Minute}, + Yes: utils.ToPtr(true), + Headers: types.Headers{{Key: "User-Agent", Value: []string{"CustomAgent"}}}, + }, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "DELETE", *config.Method) + assert.Equal(t, 15*time.Second, config.Timeout.Duration) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(500), *config.RequestCount) + assert.Equal(t, 5*time.Minute, config.Duration.Duration) + assert.True(t, *config.Yes) + assert.True(t, config.Headers.Has("User-Agent")) + userAgent := config.Headers.GetValue("User-Agent") + assert.NotNil(t, userAgent) + assert.Equal(t, "CustomAgent", (*userAgent)[0]) // should keep custom user agent + assert.NotEqual(t, DefaultUserAgent, (*userAgent)[0]) // should not overwrite + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Call the function being tested + tt.config.SetDefaults() + + // Validate the result + tt.validate(t, tt.config) + }) + } +} diff --git a/config/file_test.go b/config/file_test.go new file mode 100644 index 0000000..e77506a --- /dev/null +++ b/config/file_test.go @@ -0,0 +1,350 @@ +package config + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/aykhans/dodo/types" + "github.com/stretchr/testify/assert" +) + +func TestReadFile(t *testing.T) { + // Create temporary files for testing + tempDir := t.TempDir() + + // Create a valid JSON config file + validJSONFile := filepath.Join(tempDir, "valid.json") + jsonContent := `{ + "method": "POST", + "url": "https://example.com", + "timeout": "5s", + "dodos": 5, + "requests": 100, + "duration": "1m", + "yes": true, + "headers": [{"Content-Type": "application/json"}] + }` + err := os.WriteFile(validJSONFile, []byte(jsonContent), 0644) + assert.NoError(t, err) + + // Create a valid YAML config file + validYAMLFile := filepath.Join(tempDir, "valid.yaml") + yamlContent := ` +method: POST +url: https://example.com +timeout: 5s +dodos: 5 +requests: 100 +duration: 1m +yes: true +headers: + - Content-Type: application/json +` + err = os.WriteFile(validYAMLFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + // Create an invalid JSON config file + invalidJSONFile := filepath.Join(tempDir, "invalid.json") + invalidJSONContent := `{ + "method": "POST", + "url": "https://example.com", + syntax error + }` + err = os.WriteFile(invalidJSONFile, []byte(invalidJSONContent), 0644) + assert.NoError(t, err) + + // Create a file with unsupported extension + unsupportedFile := filepath.Join(tempDir, "config.txt") + err = os.WriteFile(unsupportedFile, []byte("some content"), 0644) + assert.NoError(t, err) + + // Setup HTTP test server for remote config + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/valid.json": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + io.WriteString(w, jsonContent) + case "/invalid.json": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + io.WriteString(w, invalidJSONContent) + case "/valid.yaml": + w.Header().Set("Content-Type", "application/yaml") + w.WriteHeader(http.StatusOK) + io.WriteString(w, yamlContent) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + tests := []struct { + name string + filePath types.ConfigFile + expectErr bool + validate func(t *testing.T, config *Config) + }{ + { + name: "valid local JSON file", + filePath: types.ConfigFile(validJSONFile), + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + assert.Equal(t, int64(60000000000), config.Duration.Nanoseconds()) + assert.True(t, *config.Yes) + assert.Equal(t, 1, len(config.Headers)) + assert.Equal(t, "Content-Type", config.Headers[0].Key) + assert.Equal(t, "application/json", config.Headers[0].Value[0]) + }, + }, + { + name: "valid local YAML file", + filePath: types.ConfigFile(validYAMLFile), + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + assert.Equal(t, int64(60000000000), config.Duration.Nanoseconds()) + assert.True(t, *config.Yes) + assert.Equal(t, 1, len(config.Headers)) + assert.Equal(t, "Content-Type", config.Headers[0].Key) + assert.Equal(t, "application/json", config.Headers[0].Value[0]) + }, + }, + { + name: "invalid local JSON file", + filePath: types.ConfigFile(invalidJSONFile), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "unsupported file extension", + filePath: types.ConfigFile(unsupportedFile), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "non-existent file", + filePath: types.ConfigFile(filepath.Join(tempDir, "nonexistent.json")), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "valid remote JSON file", + filePath: types.ConfigFile(server.URL + "/valid.json"), + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + assert.Equal(t, int64(60000000000), config.Duration.Nanoseconds()) + assert.True(t, *config.Yes) + assert.Equal(t, 1, len(config.Headers)) + assert.Equal(t, "Content-Type", config.Headers[0].Key) + assert.Equal(t, "application/json", config.Headers[0].Value[0]) + }, + }, + { + name: "valid remote YAML file", + filePath: types.ConfigFile(server.URL + "/valid.yaml"), + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + assert.Equal(t, int64(60000000000), config.Duration.Nanoseconds()) + assert.True(t, *config.Yes) + assert.Equal(t, 1, len(config.Headers)) + assert.Equal(t, "Content-Type", config.Headers[0].Key) + assert.Equal(t, "application/json", config.Headers[0].Value[0]) + }, + }, + { + name: "invalid remote JSON file", + filePath: types.ConfigFile(server.URL + "/invalid.json"), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "non-existent remote file", + filePath: types.ConfigFile(server.URL + "/nonexistent.json"), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "invalid URL", + filePath: types.ConfigFile("http://nonexistent.example.com/config.json"), + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := NewConfig() + err := config.ReadFile(tt.filePath) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + tt.validate(t, config) + } + }) + } +} + +func TestParseJSONConfig(t *testing.T) { + tests := []struct { + name string + jsonData string + expectErr bool + validate func(t *testing.T, config *Config) + }{ + { + name: "valid JSON config", + jsonData: `{ + "method": "POST", + "url": "https://example.com", + "timeout": "5s", + "dodos": 5, + "requests": 100 + }`, + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + }, + }, + { + name: "invalid JSON syntax", + jsonData: `{ + "method": "POST", + "url": "https://example.com", + syntax error + }`, + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "invalid type for field", + jsonData: `{ + "method": "POST", + "url": "https://example.com", + "dodos": "not-a-number" + }`, + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "empty JSON object", + jsonData: `{}`, + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Nil(t, config.Method) + assert.Nil(t, config.URL) + assert.Nil(t, config.Timeout) + assert.Nil(t, config.DodosCount) + assert.Nil(t, config.RequestCount) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := NewConfig() + err := parseJSONConfig([]byte(tt.jsonData), config) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + tt.validate(t, config) + } + }) + } +} + +func TestParseYAMLConfig(t *testing.T) { + tests := []struct { + name string + yamlData string + expectErr bool + validate func(t *testing.T, config *Config) + }{ + { + name: "valid YAML config", + yamlData: ` +method: POST +url: https://example.com +timeout: 5s +dodos: 5 +requests: 100 +`, + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Equal(t, "POST", *config.Method) + assert.Equal(t, "https://example.com", config.URL.String()) + assert.Equal(t, int64(5000000000), config.Timeout.Nanoseconds()) + assert.Equal(t, uint(5), *config.DodosCount) + assert.Equal(t, uint(100), *config.RequestCount) + }, + }, + { + name: "invalid YAML syntax", + yamlData: ` +method: POST +url: https://example.com +dodos: 5 + invalid indentation +`, + expectErr: true, + validate: func(t *testing.T, config *Config) {}, + }, + { + name: "empty YAML", + yamlData: ``, + expectErr: false, + validate: func(t *testing.T, config *Config) { + assert.Nil(t, config.Method) + assert.Nil(t, config.URL) + assert.Nil(t, config.Timeout) + assert.Nil(t, config.DodosCount) + assert.Nil(t, config.RequestCount) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := NewConfig() + err := parseYAMLConfig([]byte(tt.yamlData), config) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + tt.validate(t, config) + } + }) + } +} diff --git a/go.mod b/go.mod index e027790..1dc9c02 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,17 @@ go 1.24.0 require ( github.com/jedib0t/go-pretty/v6 v6.6.7 + github.com/stretchr/testify v1.10.0 github.com/valyala/fasthttp v1.60.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/andybalholm/brotli v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/net v0.38.0 // indirect