mirror of
https://github.com/aykhans/sarin.git
synced 2026-02-28 14:59:14 +00:00
Introduce structured error types and bump Go/linter versions
Replace ad-hoc fmt.Errorf/errors.New calls with typed error structs across config, sarin, and script packages to enable type-based error handling. Add script-specific error handlers in CLI entry point. Fix variable shadowing bug in Worker for scriptTransformer. Bump Go to 1.25.7 and golangci-lint to v2.8.0.
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -95,6 +94,9 @@ func NewHostClients(
|
||||
return []*fasthttp.HostClient{client}, nil
|
||||
}
|
||||
|
||||
// NewProxyDialFunc creates a dial function for the given proxy URL.
|
||||
// It can return the following errors:
|
||||
// - types.ProxyUnsupportedSchemeError
|
||||
func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Duration) (fasthttp.DialFunc, error) {
|
||||
var (
|
||||
dialer fasthttp.DialFunc
|
||||
@@ -117,16 +119,14 @@ func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Durat
|
||||
case "https":
|
||||
dialer = fasthttpHTTPSDialerDualStackTimeout(proxyURL, timeout)
|
||||
default:
|
||||
return nil, errors.New("unsupported proxy scheme")
|
||||
}
|
||||
|
||||
if dialer == nil {
|
||||
return nil, errors.New("internal error: proxy dialer is nil")
|
||||
return nil, types.NewProxyUnsupportedSchemeError(proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return dialer, nil
|
||||
}
|
||||
|
||||
// The returned dial function can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL, timeout time.Duration, resolveLocally bool) (fasthttp.DialFunc, error) {
|
||||
netDialer := &net.Dialer{}
|
||||
|
||||
@@ -147,12 +147,18 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyStr := proxyURL.String()
|
||||
|
||||
// Assert to ContextDialer for timeout support
|
||||
contextDialer, ok := socksDialer.(proxy.ContextDialer)
|
||||
if !ok {
|
||||
// Fallback without timeout (should not happen with net.Dialer)
|
||||
return func(addr string) (net.Conn, error) {
|
||||
return socksDialer.Dial("tcp", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
return conn, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -163,7 +169,7 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
if resolveLocally {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Cap DNS resolution to half the timeout to reserve time for dial
|
||||
@@ -171,10 +177,10 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
ips, err := net.DefaultResolver.LookupIP(dnsCtx, "ip", host)
|
||||
dnsCancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, errors.New("no IP addresses found for host: " + host)
|
||||
return nil, types.NewProxyDialError(proxyStr, types.NewProxyResolveError(host))
|
||||
}
|
||||
|
||||
// Use the first resolved IP
|
||||
@@ -184,16 +190,22 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
// Use remaining time for dial
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return nil, context.DeadlineExceeded
|
||||
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, remaining)
|
||||
defer dialCancel()
|
||||
|
||||
return contextDialer.DialContext(dialCtx, "tcp", addr)
|
||||
conn, err := contextDialer.DialContext(dialCtx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
return conn, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// The returned dial function can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duration) fasthttp.DialFunc {
|
||||
proxyAddr := proxyURL.Host
|
||||
if proxyURL.Port() == "" {
|
||||
@@ -209,24 +221,26 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
}
|
||||
|
||||
proxyStr := proxyURL.String()
|
||||
|
||||
return func(addr string) (net.Conn, error) {
|
||||
// Establish TCP connection to proxy with timeout
|
||||
start := time.Now()
|
||||
conn, err := fasthttp.DialDualStackTimeout(proxyAddr, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
remaining := timeout - time.Since(start)
|
||||
if remaining <= 0 {
|
||||
conn.Close() //nolint:errcheck,gosec
|
||||
return nil, context.DeadlineExceeded
|
||||
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// Set deadline for the TLS handshake and CONNECT request
|
||||
if err := conn.SetDeadline(time.Now().Add(remaining)); err != nil {
|
||||
conn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Upgrade to TLS
|
||||
@@ -235,7 +249,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Build and send CONNECT request
|
||||
@@ -251,7 +265,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
|
||||
if err := connectReq.Write(tlsConn); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Read response using buffered reader, but return wrapped connection
|
||||
@@ -260,19 +274,19 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
resp, err := http.ReadResponse(bufReader, connectReq)
|
||||
if err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
resp.Body.Close() //nolint:errcheck,gosec
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, errors.New("proxy CONNECT failed: " + resp.Status)
|
||||
return nil, types.NewProxyDialError(proxyStr, types.NewProxyConnectError(resp.Status))
|
||||
}
|
||||
|
||||
// Clear deadline for the tunneled connection
|
||||
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Return wrapped connection that uses the buffered reader
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package sarin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -10,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// CachedFile holds the cached content and metadata of a file.
|
||||
@@ -31,6 +32,10 @@ func NewFileCache(requestTimeout time.Duration) *FileCache {
|
||||
|
||||
// GetOrLoad retrieves a file from cache or loads it using the provided source.
|
||||
// The source can be a local file path or an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
|
||||
if val, ok := fc.cache.Load(source); ok {
|
||||
return val.(*CachedFile), nil
|
||||
@@ -59,14 +64,21 @@ func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
|
||||
return actual.(*CachedFile), nil
|
||||
}
|
||||
|
||||
// readLocalFile reads a file from the local filesystem and returns its content and filename.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
func (fc *FileCache) readLocalFile(filePath string) ([]byte, string, error) {
|
||||
content, err := os.ReadFile(filePath) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read file %s: %w", filePath, err)
|
||||
return nil, "", types.NewFileReadError(filePath, err)
|
||||
}
|
||||
return content, filepath.Base(filePath), nil
|
||||
}
|
||||
|
||||
// fetchURL downloads file contents from an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
|
||||
client := &http.Client{
|
||||
Timeout: fc.requestTimeout,
|
||||
@@ -74,17 +86,17 @@ func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to fetch URL %s: %w", url, err)
|
||||
return nil, "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("failed to fetch URL %s: HTTP %d", url, resp.StatusCode)
|
||||
return nil, "", types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read response body from %s: %w", url, err)
|
||||
return nil, "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
// Extract filename from URL path
|
||||
|
||||
@@ -2,7 +2,6 @@ package sarin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/rand/v2"
|
||||
"net/url"
|
||||
@@ -261,12 +260,12 @@ func NewValuesGeneratorFunc(values []string, templateFunctions template.FuncMap)
|
||||
for _, generator := range generators {
|
||||
rendered, err = generator(nil)
|
||||
if err != nil {
|
||||
return valuesData{}, fmt.Errorf("values rendering: %w", err)
|
||||
return valuesData{}, types.NewTemplateRenderError(err)
|
||||
}
|
||||
|
||||
data, err = godotenv.Unmarshal(rendered)
|
||||
if err != nil {
|
||||
return valuesData{}, fmt.Errorf("values rendering: %w", err)
|
||||
return valuesData{}, types.NewTemplateRenderError(err)
|
||||
}
|
||||
|
||||
maps.Copy(result, data)
|
||||
@@ -283,7 +282,7 @@ func createTemplateFunc(value string, templateFunctions template.FuncMap) (func(
|
||||
return func(data any) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err = tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("template rendering: %w", err)
|
||||
return "", types.NewTemplateRenderError(err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}, true
|
||||
|
||||
@@ -58,8 +58,9 @@ type sarin struct {
|
||||
|
||||
// NewSarin creates a new sarin instance for load testing.
|
||||
// It can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
// - script loading errors
|
||||
// - types.ProxyDialError
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ScriptLoadError
|
||||
func NewSarin(
|
||||
ctx context.Context,
|
||||
methods []string,
|
||||
@@ -216,7 +217,8 @@ func (q sarin) Worker(
|
||||
// Scripts are pre-validated in NewSarin, so this should not fail
|
||||
var scriptTransformer *script.Transformer
|
||||
if !q.scriptChain.IsEmpty() {
|
||||
scriptTransformer, err := q.scriptChain.NewTransformer()
|
||||
var err error
|
||||
scriptTransformer, err = q.scriptChain.NewTransformer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package sarin
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math/rand/v2"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v7"
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) template.FuncMap {
|
||||
@@ -90,7 +90,7 @@ func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) tem
|
||||
// {{ file_Base64 "https://example.com/image.png" }}
|
||||
"file_Base64": func(source string) (string, error) {
|
||||
if fileCache == nil {
|
||||
return "", errors.New("file cache is not initialized")
|
||||
return "", types.ErrFileCacheNotInitialized
|
||||
}
|
||||
cached, err := fileCache.GetOrLoad(source)
|
||||
if err != nil {
|
||||
@@ -582,7 +582,7 @@ func NewDefaultBodyTemplateFuncMap(
|
||||
// {{ body_FormData "name" "John" "avatar" "@/path/to/photo.jpg" "doc" "@https://example.com/file.pdf" }}
|
||||
funcMap["body_FormData"] = func(pairs ...string) (string, error) {
|
||||
if len(pairs)%2 != 0 {
|
||||
return "", errors.New("body_FormData requires an even number of arguments (key-value pairs)")
|
||||
return "", types.ErrFormDataOddArgs
|
||||
}
|
||||
|
||||
var multipartData bytes.Buffer
|
||||
@@ -602,7 +602,7 @@ func NewDefaultBodyTemplateFuncMap(
|
||||
case strings.HasPrefix(val, "@"):
|
||||
// File (local path or remote URL)
|
||||
if fileCache == nil {
|
||||
return "", errors.New("file cache is not initialized")
|
||||
return "", types.ErrFileCacheNotInitialized
|
||||
}
|
||||
source := val[1:]
|
||||
cached, err := fileCache.GetOrLoad(source)
|
||||
|
||||
Reference in New Issue
Block a user