mirror of
https://github.com/aykhans/sarin.git
synced 2026-02-28 14:59:14 +00:00
Compare commits
2 Commits
e83eacf380
...
6a713ef241
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a713ef241 | |||
| 6dafc082ed |
4
.github/workflows/lint.yaml
vendored
4
.github/workflows/lint.yaml
vendored
@@ -16,8 +16,8 @@ jobs:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 1.25.5
|
||||
go-version: 1.25.7
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: v2.7.2
|
||||
version: v2.8.0
|
||||
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
run: |
|
||||
echo "VERSION=$(git describe --tags --always)" >> $GITHUB_ENV
|
||||
echo "GIT_COMMIT=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo "GO_VERSION=1.25.5" >> $GITHUB_ENV
|
||||
echo "GO_VERSION=1.25.7" >> $GITHUB_ENV
|
||||
|
||||
- name: Set up Go
|
||||
if: github.event_name == 'release' || inputs.build_binaries
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG GO_VERSION=1.25.5
|
||||
ARG GO_VERSION=1.25.7
|
||||
|
||||
FROM docker.io/library/golang:${GO_VERSION}-alpine AS builder
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ version: "3"
|
||||
|
||||
vars:
|
||||
BIN_DIR: ./bin
|
||||
GOLANGCI_LINT_VERSION: v2.7.2
|
||||
GOLANGCI_LINT_VERSION: v2.8.0
|
||||
GOLANGCI: "{{.BIN_DIR}}/golangci-lint-{{.GOLANGCI_LINT_VERSION}}"
|
||||
|
||||
tasks:
|
||||
|
||||
@@ -55,16 +55,22 @@ func main() {
|
||||
*combinedConfig.DryRun,
|
||||
combinedConfig.Lua, combinedConfig.Js,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[ERROR] ")+err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
_ = utilsErr.MustHandle(err,
|
||||
utilsErr.OnType(func(err types.ProxyDialError) error {
|
||||
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[PROXY] ")+err.Error())
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
utilsErr.OnSentinel(types.ErrScriptEmpty, func(err error) error {
|
||||
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[SCRIPT] ")+err.Error())
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
utilsErr.OnType(func(err types.ScriptLoadError) error {
|
||||
fmt.Fprintln(os.Stderr, config.StyleRed.Render("[SCRIPT] ")+err.Error())
|
||||
os.Exit(1)
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
srn.Start(ctx)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module go.aykhans.me/sarin
|
||||
|
||||
go 1.25.5
|
||||
go 1.25.7
|
||||
|
||||
require (
|
||||
github.com/brianvoe/gofakeit/v7 v7.14.0
|
||||
|
||||
@@ -638,10 +638,16 @@ func parseConfigFile(configFile types.ConfigFile, maxDepth int) (*Config, error)
|
||||
// - Escaped "@": strings starting with "@@" (literal "@" at start)
|
||||
// - File reference: "@/path/to/file" or "@./relative/path"
|
||||
// - URL reference: "@http://..." or "@https://..."
|
||||
//
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ErrScriptSourceEmpty
|
||||
// - types.ErrScriptURLNoHost
|
||||
// - types.URLParseError
|
||||
func validateScriptSource(script string) error {
|
||||
// Empty script is invalid
|
||||
if script == "" {
|
||||
return errors.New("script cannot be empty")
|
||||
return types.ErrScriptEmpty
|
||||
}
|
||||
|
||||
// Not a file/URL reference - it's an inline script
|
||||
@@ -658,17 +664,17 @@ func validateScriptSource(script string) error {
|
||||
source := script[1:] // Remove the @ prefix
|
||||
|
||||
if source == "" {
|
||||
return errors.New("script source cannot be empty after @")
|
||||
return types.ErrScriptSourceEmpty
|
||||
}
|
||||
|
||||
// Check if it's a URL
|
||||
// Check if it's a http(s) URL
|
||||
if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") {
|
||||
parsedURL, err := url.Parse(source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
return types.NewURLParseError(source, err)
|
||||
}
|
||||
if parsedURL.Host == "" {
|
||||
return errors.New("URL must have a host")
|
||||
return types.ErrScriptURLNoHost
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,10 @@ func (parser ConfigFileParser) Parse() (*Config, error) {
|
||||
}
|
||||
|
||||
// fetchFile retrieves file contents from a local path or HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func fetchFile(ctx context.Context, src string) ([]byte, error) {
|
||||
if strings.HasPrefix(src, "http://") || strings.HasPrefix(src, "https://") {
|
||||
return fetchHTTP(ctx, src)
|
||||
@@ -57,25 +61,28 @@ func fetchFile(ctx context.Context, src string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// fetchHTTP downloads file contents from an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func fetchHTTP(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
return nil, types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch file: %w", err)
|
||||
return nil, types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch file: HTTP %d %s", resp.StatusCode, resp.Status)
|
||||
return nil, types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
return nil, types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
@@ -83,19 +90,21 @@ func fetchHTTP(ctx context.Context, url string) ([]byte, error) {
|
||||
|
||||
// fetchLocal reads file contents from the local filesystem.
|
||||
// It resolves relative paths from the current working directory.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
func fetchLocal(src string) ([]byte, error) {
|
||||
path := src
|
||||
if !filepath.IsAbs(src) {
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get working directory: %w", err)
|
||||
return nil, types.NewFileReadError(src, err)
|
||||
}
|
||||
path = filepath.Join(pwd, src)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||
return nil, types.NewFileReadError(path, err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// It can return the following errors:
|
||||
// - types.TemplateParseError
|
||||
func validateTemplateString(value string, funcMap template.FuncMap) error {
|
||||
if value == "" {
|
||||
return nil
|
||||
@@ -15,7 +17,7 @@ func validateTemplateString(value string, funcMap template.FuncMap) error {
|
||||
|
||||
_, err := template.New("").Funcs(funcMap).Parse(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("template parse error: %w", err)
|
||||
return types.NewTemplateParseError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -95,6 +94,9 @@ func NewHostClients(
|
||||
return []*fasthttp.HostClient{client}, nil
|
||||
}
|
||||
|
||||
// NewProxyDialFunc creates a dial function for the given proxy URL.
|
||||
// It can return the following errors:
|
||||
// - types.ProxyUnsupportedSchemeError
|
||||
func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Duration) (fasthttp.DialFunc, error) {
|
||||
var (
|
||||
dialer fasthttp.DialFunc
|
||||
@@ -117,16 +119,14 @@ func NewProxyDialFunc(ctx context.Context, proxyURL *url.URL, timeout time.Durat
|
||||
case "https":
|
||||
dialer = fasthttpHTTPSDialerDualStackTimeout(proxyURL, timeout)
|
||||
default:
|
||||
return nil, errors.New("unsupported proxy scheme")
|
||||
}
|
||||
|
||||
if dialer == nil {
|
||||
return nil, errors.New("internal error: proxy dialer is nil")
|
||||
return nil, types.NewProxyUnsupportedSchemeError(proxyURL.Scheme)
|
||||
}
|
||||
|
||||
return dialer, nil
|
||||
}
|
||||
|
||||
// The returned dial function can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL, timeout time.Duration, resolveLocally bool) (fasthttp.DialFunc, error) {
|
||||
netDialer := &net.Dialer{}
|
||||
|
||||
@@ -147,12 +147,18 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyStr := proxyURL.String()
|
||||
|
||||
// Assert to ContextDialer for timeout support
|
||||
contextDialer, ok := socksDialer.(proxy.ContextDialer)
|
||||
if !ok {
|
||||
// Fallback without timeout (should not happen with net.Dialer)
|
||||
return func(addr string) (net.Conn, error) {
|
||||
return socksDialer.Dial("tcp", addr)
|
||||
conn, err := socksDialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
return conn, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -163,7 +169,7 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
if resolveLocally {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Cap DNS resolution to half the timeout to reserve time for dial
|
||||
@@ -171,10 +177,10 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
ips, err := net.DefaultResolver.LookupIP(dnsCtx, "ip", host)
|
||||
dnsCancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, errors.New("no IP addresses found for host: " + host)
|
||||
return nil, types.NewProxyDialError(proxyStr, types.NewProxyResolveError(host))
|
||||
}
|
||||
|
||||
// Use the first resolved IP
|
||||
@@ -184,16 +190,22 @@ func fasthttpSocksDialerDualStackTimeout(ctx context.Context, proxyURL *url.URL,
|
||||
// Use remaining time for dial
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= 0 {
|
||||
return nil, context.DeadlineExceeded
|
||||
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, remaining)
|
||||
defer dialCancel()
|
||||
|
||||
return contextDialer.DialContext(dialCtx, "tcp", addr)
|
||||
conn, err := contextDialer.DialContext(dialCtx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
return conn, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// The returned dial function can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duration) fasthttp.DialFunc {
|
||||
proxyAddr := proxyURL.Host
|
||||
if proxyURL.Port() == "" {
|
||||
@@ -209,24 +221,26 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
}
|
||||
|
||||
proxyStr := proxyURL.String()
|
||||
|
||||
return func(addr string) (net.Conn, error) {
|
||||
// Establish TCP connection to proxy with timeout
|
||||
start := time.Now()
|
||||
conn, err := fasthttp.DialDualStackTimeout(proxyAddr, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
remaining := timeout - time.Since(start)
|
||||
if remaining <= 0 {
|
||||
conn.Close() //nolint:errcheck,gosec
|
||||
return nil, context.DeadlineExceeded
|
||||
return nil, types.NewProxyDialError(proxyStr, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// Set deadline for the TLS handshake and CONNECT request
|
||||
if err := conn.SetDeadline(time.Now().Add(remaining)); err != nil {
|
||||
conn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Upgrade to TLS
|
||||
@@ -235,7 +249,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Build and send CONNECT request
|
||||
@@ -251,7 +265,7 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
|
||||
if err := connectReq.Write(tlsConn); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Read response using buffered reader, but return wrapped connection
|
||||
@@ -260,19 +274,19 @@ func fasthttpHTTPSDialerDualStackTimeout(proxyURL *url.URL, timeout time.Duratio
|
||||
resp, err := http.ReadResponse(bufReader, connectReq)
|
||||
if err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
resp.Body.Close() //nolint:errcheck,gosec
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, errors.New("proxy CONNECT failed: " + resp.Status)
|
||||
return nil, types.NewProxyDialError(proxyStr, types.NewProxyConnectError(resp.Status))
|
||||
}
|
||||
|
||||
// Clear deadline for the tunneled connection
|
||||
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
||||
tlsConn.Close() //nolint:errcheck,gosec
|
||||
return nil, err
|
||||
return nil, types.NewProxyDialError(proxyStr, err)
|
||||
}
|
||||
|
||||
// Return wrapped connection that uses the buffered reader
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package sarin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -10,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// CachedFile holds the cached content and metadata of a file.
|
||||
@@ -31,6 +32,10 @@ func NewFileCache(requestTimeout time.Duration) *FileCache {
|
||||
|
||||
// GetOrLoad retrieves a file from cache or loads it using the provided source.
|
||||
// The source can be a local file path or an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
|
||||
if val, ok := fc.cache.Load(source); ok {
|
||||
return val.(*CachedFile), nil
|
||||
@@ -59,14 +64,21 @@ func (fc *FileCache) GetOrLoad(source string) (*CachedFile, error) {
|
||||
return actual.(*CachedFile), nil
|
||||
}
|
||||
|
||||
// readLocalFile reads a file from the local filesystem and returns its content and filename.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
func (fc *FileCache) readLocalFile(filePath string) ([]byte, string, error) {
|
||||
content, err := os.ReadFile(filePath) //nolint:gosec
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read file %s: %w", filePath, err)
|
||||
return nil, "", types.NewFileReadError(filePath, err)
|
||||
}
|
||||
return content, filepath.Base(filePath), nil
|
||||
}
|
||||
|
||||
// fetchURL downloads file contents from an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
|
||||
client := &http.Client{
|
||||
Timeout: fc.requestTimeout,
|
||||
@@ -74,17 +86,17 @@ func (fc *FileCache) fetchURL(url string) ([]byte, string, error) {
|
||||
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to fetch URL %s: %w", url, err)
|
||||
return nil, "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("failed to fetch URL %s: HTTP %d", url, resp.StatusCode)
|
||||
return nil, "", types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to read response body from %s: %w", url, err)
|
||||
return nil, "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
// Extract filename from URL path
|
||||
|
||||
@@ -2,7 +2,6 @@ package sarin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/rand/v2"
|
||||
"net/url"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
|
||||
type RequestGenerator func(*fasthttp.Request) error
|
||||
|
||||
type RequestGeneratorWithData func(*fasthttp.Request, any) error
|
||||
type requestDataGenerator func(*script.RequestData, any) error
|
||||
|
||||
type valuesData struct {
|
||||
Values map[string]string
|
||||
@@ -60,13 +59,22 @@ func NewRequestGenerator(
|
||||
|
||||
hasScripts := scriptTransformer != nil && !scriptTransformer.IsEmpty()
|
||||
|
||||
host := requestURL.Host
|
||||
scheme := requestURL.Scheme
|
||||
|
||||
reqData := &script.RequestData{
|
||||
Headers: make(map[string][]string),
|
||||
Params: make(map[string][]string),
|
||||
Cookies: make(map[string][]string),
|
||||
}
|
||||
|
||||
var (
|
||||
data valuesData
|
||||
path string
|
||||
err error
|
||||
)
|
||||
return func(req *fasthttp.Request) error {
|
||||
req.Header.SetHost(requestURL.Host)
|
||||
resetRequestData(reqData)
|
||||
|
||||
data, err = valuesGenerator()
|
||||
if err != nil {
|
||||
@@ -77,44 +85,39 @@ func NewRequestGenerator(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.SetRequestURI(path)
|
||||
reqData.Path = path
|
||||
|
||||
if err = methodGenerator(req, data); err != nil {
|
||||
if err = methodGenerator(reqData, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bodyTemplateFuncMapData.ClearFormDataContenType()
|
||||
if err = bodyGenerator(req, data); err != nil {
|
||||
if err = bodyGenerator(reqData, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = headersGenerator(req, data); err != nil {
|
||||
if err = headersGenerator(reqData, data); err != nil {
|
||||
return err
|
||||
}
|
||||
if bodyTemplateFuncMapData.GetFormDataContenType() != "" {
|
||||
req.Header.Add("Content-Type", bodyTemplateFuncMapData.GetFormDataContenType())
|
||||
reqData.Headers["Content-Type"] = append(reqData.Headers["Content-Type"], bodyTemplateFuncMapData.GetFormDataContenType())
|
||||
}
|
||||
|
||||
if err = paramsGenerator(req, data); err != nil {
|
||||
if err = paramsGenerator(reqData, data); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = cookiesGenerator(req, data); err != nil {
|
||||
if err = cookiesGenerator(reqData, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if requestURL.Scheme == "https" {
|
||||
req.URI().SetScheme("https")
|
||||
}
|
||||
|
||||
// Apply script transformations if any
|
||||
if hasScripts {
|
||||
reqData := script.RequestDataFromFastHTTP(req)
|
||||
if err = scriptTransformer.Transform(reqData); err != nil {
|
||||
return err
|
||||
}
|
||||
script.ApplyToFastHTTP(reqData, req)
|
||||
}
|
||||
|
||||
applyRequestDataToFastHTTP(reqData, req, host, scheme)
|
||||
|
||||
return nil
|
||||
}, isPathGeneratorDynamic ||
|
||||
isMethodGeneratorDynamic ||
|
||||
@@ -125,50 +128,92 @@ func NewRequestGenerator(
|
||||
hasScripts
|
||||
}
|
||||
|
||||
func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
|
||||
func resetRequestData(reqData *script.RequestData) {
|
||||
reqData.Method = ""
|
||||
reqData.Path = ""
|
||||
reqData.Body = ""
|
||||
clear(reqData.Headers)
|
||||
clear(reqData.Params)
|
||||
clear(reqData.Cookies)
|
||||
}
|
||||
|
||||
func applyRequestDataToFastHTTP(reqData *script.RequestData, req *fasthttp.Request, host, scheme string) {
|
||||
req.Header.SetHost(host)
|
||||
req.SetRequestURI(reqData.Path)
|
||||
req.Header.SetMethod(reqData.Method)
|
||||
req.SetBody([]byte(reqData.Body))
|
||||
|
||||
for k, values := range reqData.Headers {
|
||||
for _, v := range values {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
for k, values := range reqData.Params {
|
||||
for _, v := range values {
|
||||
req.URI().QueryArgs().Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqData.Cookies) > 0 {
|
||||
cookieStrings := make([]string, 0, len(reqData.Cookies))
|
||||
for k, values := range reqData.Cookies {
|
||||
for _, v := range values {
|
||||
cookieStrings = append(cookieStrings, k+"="+v)
|
||||
}
|
||||
}
|
||||
req.Header.Add("Cookie", strings.Join(cookieStrings, "; "))
|
||||
}
|
||||
|
||||
if scheme == "https" {
|
||||
req.URI().SetScheme("https")
|
||||
}
|
||||
}
|
||||
|
||||
func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
|
||||
methodGenerator, isDynamic := buildStringSliceGenerator(localRand, methods, templateFunctions)
|
||||
|
||||
var (
|
||||
method string
|
||||
err error
|
||||
)
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
return func(reqData *script.RequestData, data any) error {
|
||||
method, err = methodGenerator()(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.SetMethod(method)
|
||||
reqData.Method = method
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
|
||||
func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
|
||||
func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
|
||||
bodyGenerator, isDynamic := buildStringSliceGenerator(localRand, bodies, templateFunctions)
|
||||
|
||||
var (
|
||||
body string
|
||||
err error
|
||||
)
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
return func(reqData *script.RequestData, data any) error {
|
||||
body, err = bodyGenerator()(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.SetBody([]byte(body))
|
||||
reqData.Body = body
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
|
||||
func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
|
||||
func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
|
||||
generators, isDynamic := buildKeyValueGenerators(localRand, params, templateFunctions)
|
||||
|
||||
var (
|
||||
key, value string
|
||||
err error
|
||||
)
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
return func(reqData *script.RequestData, data any) error {
|
||||
for _, gen := range generators {
|
||||
key, err = gen.Key(data)
|
||||
if err != nil {
|
||||
@@ -180,20 +225,20 @@ func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateF
|
||||
return err
|
||||
}
|
||||
|
||||
req.URI().QueryArgs().Add(key, value)
|
||||
reqData.Params[key] = append(reqData.Params[key], value)
|
||||
}
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
|
||||
func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
|
||||
func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
|
||||
generators, isDynamic := buildKeyValueGenerators(localRand, headers, templateFunctions)
|
||||
|
||||
var (
|
||||
key, value string
|
||||
err error
|
||||
)
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
return func(reqData *script.RequestData, data any) error {
|
||||
for _, gen := range generators {
|
||||
key, err = gen.Key(data)
|
||||
if err != nil {
|
||||
@@ -205,22 +250,20 @@ func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templa
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Add(key, value)
|
||||
reqData.Headers[key] = append(reqData.Headers[key], value)
|
||||
}
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
|
||||
func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) {
|
||||
func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (requestDataGenerator, bool) {
|
||||
generators, isDynamic := buildKeyValueGenerators(localRand, cookies, templateFunctions)
|
||||
|
||||
var (
|
||||
key, value string
|
||||
err error
|
||||
)
|
||||
if len(generators) > 0 {
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
cookieStrings := make([]string, 0, len(generators))
|
||||
return func(reqData *script.RequestData, data any) error {
|
||||
for _, gen := range generators {
|
||||
key, err = gen.Key(data)
|
||||
if err != nil {
|
||||
@@ -232,14 +275,8 @@ func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templa
|
||||
return err
|
||||
}
|
||||
|
||||
cookieStrings = append(cookieStrings, key+"="+value)
|
||||
reqData.Cookies[key] = append(reqData.Cookies[key], value)
|
||||
}
|
||||
req.Header.Add("Cookie", strings.Join(cookieStrings, "; "))
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
|
||||
return func(req *fasthttp.Request, data any) error {
|
||||
return nil
|
||||
}, isDynamic
|
||||
}
|
||||
@@ -261,12 +298,12 @@ func NewValuesGeneratorFunc(values []string, templateFunctions template.FuncMap)
|
||||
for _, generator := range generators {
|
||||
rendered, err = generator(nil)
|
||||
if err != nil {
|
||||
return valuesData{}, fmt.Errorf("values rendering: %w", err)
|
||||
return valuesData{}, types.NewTemplateRenderError(err)
|
||||
}
|
||||
|
||||
data, err = godotenv.Unmarshal(rendered)
|
||||
if err != nil {
|
||||
return valuesData{}, fmt.Errorf("values rendering: %w", err)
|
||||
return valuesData{}, types.NewTemplateRenderError(err)
|
||||
}
|
||||
|
||||
maps.Copy(result, data)
|
||||
@@ -283,7 +320,7 @@ func createTemplateFunc(value string, templateFunctions template.FuncMap) (func(
|
||||
return func(data any) (string, error) {
|
||||
var buf bytes.Buffer
|
||||
if err = tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("template rendering: %w", err)
|
||||
return "", types.NewTemplateRenderError(err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}, true
|
||||
|
||||
@@ -59,7 +59,8 @@ type sarin struct {
|
||||
// NewSarin creates a new sarin instance for load testing.
|
||||
// It can return the following errors:
|
||||
// - types.ProxyDialError
|
||||
// - script loading errors
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ScriptLoadError
|
||||
func NewSarin(
|
||||
ctx context.Context,
|
||||
methods []string,
|
||||
@@ -216,7 +217,8 @@ func (q sarin) Worker(
|
||||
// Scripts are pre-validated in NewSarin, so this should not fail
|
||||
var scriptTransformer *script.Transformer
|
||||
if !q.scriptChain.IsEmpty() {
|
||||
scriptTransformer, err := q.scriptChain.NewTransformer()
|
||||
var err error
|
||||
scriptTransformer, err = q.scriptChain.NewTransformer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package sarin
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math/rand/v2"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v7"
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) template.FuncMap {
|
||||
@@ -90,7 +90,7 @@ func NewDefaultTemplateFuncMap(randSource rand.Source, fileCache *FileCache) tem
|
||||
// {{ file_Base64 "https://example.com/image.png" }}
|
||||
"file_Base64": func(source string) (string, error) {
|
||||
if fileCache == nil {
|
||||
return "", errors.New("file cache is not initialized")
|
||||
return "", types.ErrFileCacheNotInitialized
|
||||
}
|
||||
cached, err := fileCache.GetOrLoad(source)
|
||||
if err != nil {
|
||||
@@ -582,7 +582,7 @@ func NewDefaultBodyTemplateFuncMap(
|
||||
// {{ body_FormData "name" "John" "avatar" "@/path/to/photo.jpg" "doc" "@https://example.com/file.pdf" }}
|
||||
funcMap["body_FormData"] = func(pairs ...string) (string, error) {
|
||||
if len(pairs)%2 != 0 {
|
||||
return "", errors.New("body_FormData requires an even number of arguments (key-value pairs)")
|
||||
return "", types.ErrFormDataOddArgs
|
||||
}
|
||||
|
||||
var multipartData bytes.Buffer
|
||||
@@ -602,7 +602,7 @@ func NewDefaultBodyTemplateFuncMap(
|
||||
case strings.HasPrefix(val, "@"):
|
||||
// File (local path or remote URL)
|
||||
if fileCache == nil {
|
||||
return "", errors.New("file cache is not initialized")
|
||||
return "", types.ErrFileCacheNotInitialized
|
||||
}
|
||||
source := val[1:]
|
||||
cached, err := fileCache.GetOrLoad(source)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package script
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// Chain holds the loaded script sources and can create engine instances.
|
||||
@@ -36,6 +34,8 @@ type Transformer struct {
|
||||
|
||||
// NewTransformer creates engine instances from the chain's sources.
|
||||
// Call this once per worker goroutine.
|
||||
// It can return the following errors:
|
||||
// - types.ScriptChainError
|
||||
func (c *Chain) NewTransformer() (*Transformer, error) {
|
||||
if c.IsEmpty() {
|
||||
return &Transformer{}, nil
|
||||
@@ -51,7 +51,7 @@ func (c *Chain) NewTransformer() (*Transformer, error) {
|
||||
engine, err := NewLuaEngine(src.Content)
|
||||
if err != nil {
|
||||
t.Close() // Clean up already created engines
|
||||
return nil, fmt.Errorf("lua script[%d]: %w", i, err)
|
||||
return nil, types.NewScriptChainError("lua", i, err)
|
||||
}
|
||||
t.luaEngines = append(t.luaEngines, engine)
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (c *Chain) NewTransformer() (*Transformer, error) {
|
||||
engine, err := NewJsEngine(src.Content)
|
||||
if err != nil {
|
||||
t.Close() // Clean up already created engines
|
||||
return nil, fmt.Errorf("js script[%d]: %w", i, err)
|
||||
return nil, types.NewScriptChainError("js", i, err)
|
||||
}
|
||||
t.jsEngines = append(t.jsEngines, engine)
|
||||
}
|
||||
@@ -71,18 +71,20 @@ func (c *Chain) NewTransformer() (*Transformer, error) {
|
||||
|
||||
// Transform applies all scripts to the request data.
|
||||
// Lua scripts run first, then JavaScript scripts.
|
||||
// It can return the following errors:
|
||||
// - types.ScriptChainError
|
||||
func (t *Transformer) Transform(req *RequestData) error {
|
||||
// Run Lua scripts
|
||||
for i, engine := range t.luaEngines {
|
||||
if err := engine.Transform(req); err != nil {
|
||||
return fmt.Errorf("lua script[%d]: %w", i, err)
|
||||
return types.NewScriptChainError("lua", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run JS scripts
|
||||
for i, engine := range t.jsEngines {
|
||||
if err := engine.Transform(req); err != nil {
|
||||
return fmt.Errorf("js script[%d]: %w", i, err)
|
||||
return types.NewScriptChainError("js", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,83 +105,3 @@ func (t *Transformer) Close() {
|
||||
func (t *Transformer) IsEmpty() bool {
|
||||
return len(t.luaEngines) == 0 && len(t.jsEngines) == 0
|
||||
}
|
||||
|
||||
// RequestDataFromFastHTTP extracts RequestData from a fasthttp.Request.
|
||||
func RequestDataFromFastHTTP(req *fasthttp.Request) *RequestData {
|
||||
data := &RequestData{
|
||||
Method: string(req.Header.Method()),
|
||||
URL: string(req.URI().FullURI()),
|
||||
Path: string(req.URI().Path()),
|
||||
Body: string(req.Body()),
|
||||
Headers: make(map[string][]string),
|
||||
Params: make(map[string][]string),
|
||||
Cookies: make(map[string][]string),
|
||||
}
|
||||
|
||||
// Extract headers (supports multiple values per key)
|
||||
req.Header.All()(func(key, value []byte) bool {
|
||||
k := string(key)
|
||||
data.Headers[k] = append(data.Headers[k], string(value))
|
||||
return true
|
||||
})
|
||||
|
||||
// Extract query params (supports multiple values per key)
|
||||
req.URI().QueryArgs().All()(func(key, value []byte) bool {
|
||||
k := string(key)
|
||||
data.Params[k] = append(data.Params[k], string(value))
|
||||
return true
|
||||
})
|
||||
|
||||
// Extract cookies (supports multiple values per key)
|
||||
req.Header.Cookies()(func(key, value []byte) bool {
|
||||
k := string(key)
|
||||
data.Cookies[k] = append(data.Cookies[k], string(value))
|
||||
return true
|
||||
})
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// ApplyToFastHTTP applies the modified RequestData back to a fasthttp.Request.
|
||||
func ApplyToFastHTTP(data *RequestData, req *fasthttp.Request) {
|
||||
// Method
|
||||
req.Header.SetMethod(data.Method)
|
||||
|
||||
// Path (preserve scheme and host)
|
||||
req.URI().SetPath(data.Path)
|
||||
|
||||
// Body
|
||||
req.SetBody([]byte(data.Body))
|
||||
|
||||
// Clear and set headers (supports multiple values per key)
|
||||
req.Header.All()(func(key, _ []byte) bool {
|
||||
keyStr := string(key)
|
||||
if keyStr != "Host" {
|
||||
req.Header.Del(keyStr)
|
||||
}
|
||||
return true
|
||||
})
|
||||
for k, values := range data.Headers {
|
||||
if k != "Host" { // Don't overwrite Host
|
||||
for _, v := range values {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear and set query params (supports multiple values per key)
|
||||
req.URI().QueryArgs().Reset()
|
||||
for k, values := range data.Params {
|
||||
for _, v := range values {
|
||||
req.URI().QueryArgs().Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear and set cookies (supports multiple values per key)
|
||||
req.Header.DelAllCookies()
|
||||
for k, values := range data.Cookies {
|
||||
for _, v := range values {
|
||||
req.Header.SetCookie(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package script
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// JsEngine implements the Engine interface using goja (JavaScript).
|
||||
@@ -20,27 +20,31 @@ type JsEngine struct {
|
||||
// Example JavaScript script:
|
||||
//
|
||||
// function transform(req) {
|
||||
// req.headers["X-Custom"] = "value";
|
||||
// req.headers["X-Custom"] = ["value"];
|
||||
// return req;
|
||||
// }
|
||||
//
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptTransformMissing
|
||||
// - types.ScriptExecutionError
|
||||
func NewJsEngine(scriptContent string) (*JsEngine, error) {
|
||||
vm := goja.New()
|
||||
|
||||
// Execute the script to define the transform function
|
||||
_, err := vm.RunString(scriptContent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute JavaScript script: %w", err)
|
||||
return nil, types.NewScriptExecutionError("JavaScript", err)
|
||||
}
|
||||
|
||||
// Get the transform function
|
||||
transformVal := vm.Get("transform")
|
||||
if transformVal == nil || goja.IsUndefined(transformVal) || goja.IsNull(transformVal) {
|
||||
return nil, errors.New("script must define a global 'transform' function")
|
||||
return nil, types.ErrScriptTransformMissing
|
||||
}
|
||||
|
||||
transform, ok := goja.AssertFunction(transformVal)
|
||||
if !ok {
|
||||
return nil, errors.New("'transform' must be a function")
|
||||
return nil, types.NewScriptExecutionError("JavaScript", errors.New("'transform' must be a function"))
|
||||
}
|
||||
|
||||
return &JsEngine{
|
||||
@@ -50,6 +54,8 @@ func NewJsEngine(scriptContent string) (*JsEngine, error) {
|
||||
}
|
||||
|
||||
// Transform executes the JavaScript transform function with the given request data.
|
||||
// It can return the following errors:
|
||||
// - types.ScriptExecutionError
|
||||
func (e *JsEngine) Transform(req *RequestData) error {
|
||||
// Convert RequestData to JavaScript object
|
||||
reqObj := e.requestDataToObject(req)
|
||||
@@ -57,12 +63,12 @@ func (e *JsEngine) Transform(req *RequestData) error {
|
||||
// Call transform(req)
|
||||
result, err := e.transform(goja.Undefined(), reqObj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("JavaScript transform error: %w", err)
|
||||
return types.NewScriptExecutionError("JavaScript", err)
|
||||
}
|
||||
|
||||
// Update RequestData from the returned object
|
||||
if err := e.objectToRequestData(result, req); err != nil {
|
||||
return fmt.Errorf("failed to parse transform result: %w", err)
|
||||
return types.NewScriptExecutionError("JavaScript", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -80,7 +86,6 @@ func (e *JsEngine) requestDataToObject(req *RequestData) goja.Value {
|
||||
obj := e.runtime.NewObject()
|
||||
|
||||
_ = obj.Set("method", req.Method)
|
||||
_ = obj.Set("url", req.URL)
|
||||
_ = obj.Set("path", req.Path)
|
||||
_ = obj.Set("body", req.Body)
|
||||
|
||||
@@ -111,12 +116,12 @@ func (e *JsEngine) requestDataToObject(req *RequestData) goja.Value {
|
||||
// objectToRequestData updates RequestData from a JavaScript object.
|
||||
func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error {
|
||||
if val == nil || goja.IsUndefined(val) || goja.IsNull(val) {
|
||||
return errors.New("transform function must return an object")
|
||||
return types.ErrScriptTransformReturnObject
|
||||
}
|
||||
|
||||
obj := val.ToObject(e.runtime)
|
||||
if obj == nil {
|
||||
return errors.New("transform function must return an object")
|
||||
return types.ErrScriptTransformReturnObject
|
||||
}
|
||||
|
||||
// Method
|
||||
@@ -124,11 +129,6 @@ func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error {
|
||||
req.Method = v.String()
|
||||
}
|
||||
|
||||
// URL
|
||||
if v := obj.Get("url"); v != nil && !goja.IsUndefined(v) {
|
||||
req.URL = v.String()
|
||||
}
|
||||
|
||||
// Path
|
||||
if v := obj.Get("path"); v != nil && !goja.IsUndefined(v) {
|
||||
req.Path = v.String()
|
||||
@@ -159,7 +159,7 @@ func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error {
|
||||
|
||||
// stringSliceToArray converts a Go []string to a JavaScript array.
|
||||
func (e *JsEngine) stringSliceToArray(values []string) *goja.Object {
|
||||
ifaces := make([]interface{}, len(values))
|
||||
ifaces := make([]any, len(values))
|
||||
for i, v := range values {
|
||||
ifaces[i] = v
|
||||
}
|
||||
@@ -181,7 +181,7 @@ func (e *JsEngine) objectToStringSliceMap(obj *goja.Object) map[string][]string
|
||||
}
|
||||
|
||||
// Check if it's an array
|
||||
if arr, ok := v.Export().([]interface{}); ok {
|
||||
if arr, ok := v.Export().([]any); ok {
|
||||
var values []string
|
||||
for _, item := range arr {
|
||||
if s, ok := item.(string); ok {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package script
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// LuaEngine implements the Engine interface using gopher-lua.
|
||||
@@ -20,23 +20,27 @@ type LuaEngine struct {
|
||||
// Example Lua script:
|
||||
//
|
||||
// function transform(req)
|
||||
// req.headers["X-Custom"] = "value"
|
||||
// req.headers["X-Custom"] = {"value"}
|
||||
// return req
|
||||
// end
|
||||
//
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptTransformMissing
|
||||
// - types.ScriptExecutionError
|
||||
func NewLuaEngine(scriptContent string) (*LuaEngine, error) {
|
||||
L := lua.NewState()
|
||||
|
||||
// Execute the script to define the transform function
|
||||
if err := L.DoString(scriptContent); err != nil {
|
||||
L.Close()
|
||||
return nil, fmt.Errorf("failed to execute Lua script: %w", err)
|
||||
return nil, types.NewScriptExecutionError("Lua", err)
|
||||
}
|
||||
|
||||
// Get the transform function
|
||||
transform := L.GetGlobal("transform")
|
||||
if transform.Type() != lua.LTFunction {
|
||||
L.Close()
|
||||
return nil, errors.New("script must define a global 'transform' function")
|
||||
return nil, types.ErrScriptTransformMissing
|
||||
}
|
||||
|
||||
return &LuaEngine{
|
||||
@@ -46,6 +50,8 @@ func NewLuaEngine(scriptContent string) (*LuaEngine, error) {
|
||||
}
|
||||
|
||||
// Transform executes the Lua transform function with the given request data.
|
||||
// It can return the following errors:
|
||||
// - types.ScriptExecutionError
|
||||
func (e *LuaEngine) Transform(req *RequestData) error {
|
||||
// Convert RequestData to Lua table
|
||||
reqTable := e.requestDataToTable(req)
|
||||
@@ -54,7 +60,7 @@ func (e *LuaEngine) Transform(req *RequestData) error {
|
||||
e.state.Push(e.transform)
|
||||
e.state.Push(reqTable)
|
||||
if err := e.state.PCall(1, 1, nil); err != nil {
|
||||
return fmt.Errorf("lua transform error: %w", err)
|
||||
return types.NewScriptExecutionError("Lua", err)
|
||||
}
|
||||
|
||||
// Get the result
|
||||
@@ -62,7 +68,7 @@ func (e *LuaEngine) Transform(req *RequestData) error {
|
||||
e.state.Pop(1)
|
||||
|
||||
if result.Type() != lua.LTTable {
|
||||
return fmt.Errorf("transform function must return a table, got %s", result.Type())
|
||||
return types.NewScriptExecutionError("Lua", fmt.Errorf("transform function must return a table, got %s", result.Type()))
|
||||
}
|
||||
|
||||
// Update RequestData from the returned table
|
||||
@@ -84,7 +90,6 @@ func (e *LuaEngine) requestDataToTable(req *RequestData) *lua.LTable {
|
||||
t := L.NewTable()
|
||||
|
||||
t.RawSetString("method", lua.LString(req.Method))
|
||||
t.RawSetString("url", lua.LString(req.URL))
|
||||
t.RawSetString("path", lua.LString(req.Path))
|
||||
t.RawSetString("body", lua.LString(req.Body))
|
||||
|
||||
@@ -131,11 +136,6 @@ func (e *LuaEngine) tableToRequestData(t *lua.LTable, req *RequestData) {
|
||||
req.Method = string(v.(lua.LString))
|
||||
}
|
||||
|
||||
// URL
|
||||
if v := t.RawGetString("url"); v.Type() == lua.LTString {
|
||||
req.URL = string(v.(lua.LString))
|
||||
}
|
||||
|
||||
// Path
|
||||
if v := t.RawGetString("path"); v.Type() == lua.LTString {
|
||||
req.Path = string(v.(lua.LString))
|
||||
|
||||
@@ -2,14 +2,14 @@ package script
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.aykhans.me/sarin/internal/types"
|
||||
)
|
||||
|
||||
// RequestData represents the request data passed to scripts for transformation.
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
// Headers, Params, and Cookies use []string values to support multiple values per key.
|
||||
type RequestData struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Path string `json:"path"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Params map[string][]string `json:"params"`
|
||||
@@ -56,9 +55,13 @@ type Source struct {
|
||||
// - Escaped "@": strings starting with "@@" (literal "@" at start, returns string without first @)
|
||||
// - File reference: "@/path/to/file" or "@./relative/path"
|
||||
// - URL reference: "@http://..." or "@https://..."
|
||||
//
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ScriptLoadError
|
||||
func LoadSource(ctx context.Context, source string, engineType EngineType) (*Source, error) {
|
||||
if source == "" {
|
||||
return nil, errors.New("script source cannot be empty")
|
||||
return nil, types.ErrScriptEmpty
|
||||
}
|
||||
|
||||
var content string
|
||||
@@ -77,7 +80,7 @@ func LoadSource(ctx context.Context, source string, engineType EngineType) (*Sou
|
||||
content, err = readFile(ref)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load script from %q: %w", ref, err)
|
||||
return nil, types.NewScriptLoadError(ref, err)
|
||||
}
|
||||
default:
|
||||
// Inline script
|
||||
@@ -91,12 +94,15 @@ func LoadSource(ctx context.Context, source string, engineType EngineType) (*Sou
|
||||
}
|
||||
|
||||
// LoadSources loads multiple script sources.
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ScriptLoadError
|
||||
func LoadSources(ctx context.Context, sources []string, engineType EngineType) ([]*Source, error) {
|
||||
loaded := make([]*Source, 0, len(sources))
|
||||
for i, src := range sources {
|
||||
for _, src := range sources {
|
||||
source, err := LoadSource(ctx, src, engineType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("script[%d]: %w", i, err)
|
||||
return nil, err
|
||||
}
|
||||
loaded = append(loaded, source)
|
||||
}
|
||||
@@ -106,6 +112,12 @@ func LoadSources(ctx context.Context, sources []string, engineType EngineType) (
|
||||
// ValidateScript validates a script source by loading it and checking syntax.
|
||||
// It loads the script (from file/URL/inline), parses it, and verifies
|
||||
// that a 'transform' function is defined.
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ErrScriptTransformMissing
|
||||
// - types.ScriptLoadError
|
||||
// - types.ScriptExecutionError
|
||||
// - types.ScriptUnknownEngineError
|
||||
func ValidateScript(ctx context.Context, source string, engineType EngineType) error {
|
||||
// Load the script source
|
||||
src, err := LoadSource(ctx, source, engineType)
|
||||
@@ -121,7 +133,7 @@ func ValidateScript(ctx context.Context, source string, engineType EngineType) e
|
||||
case EngineTypeJavaScript:
|
||||
engine, err = NewJsEngine(src.Content)
|
||||
default:
|
||||
return fmt.Errorf("unknown engine type: %s", engineType)
|
||||
return types.NewScriptUnknownEngineError(string(engineType))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -134,56 +146,67 @@ func ValidateScript(ctx context.Context, source string, engineType EngineType) e
|
||||
}
|
||||
|
||||
// ValidateScripts validates multiple script sources.
|
||||
// It can return the following errors:
|
||||
// - types.ErrScriptEmpty
|
||||
// - types.ErrScriptTransformMissing
|
||||
// - types.ScriptLoadError
|
||||
// - types.ScriptExecutionError
|
||||
// - types.ScriptUnknownEngineError
|
||||
func ValidateScripts(ctx context.Context, sources []string, engineType EngineType) error {
|
||||
for i, src := range sources {
|
||||
for _, src := range sources {
|
||||
if err := ValidateScript(ctx, src, engineType); err != nil {
|
||||
return fmt.Errorf("script[%d]: %w", i, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchURL downloads content from an HTTP/HTTPS URL.
|
||||
// It can return the following errors:
|
||||
// - types.HTTPFetchError
|
||||
// - types.HTTPStatusError
|
||||
func fetchURL(ctx context.Context, url string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch: %w", err)
|
||||
return "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d %s", resp.StatusCode, resp.Status)
|
||||
return "", types.NewHTTPStatusError(url, resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return "", types.NewHTTPFetchError(url, err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// readFile reads content from a local file.
|
||||
// It can return the following errors:
|
||||
// - types.FileReadError
|
||||
func readFile(path string) (string, error) {
|
||||
if !filepath.IsAbs(path) {
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get working directory: %w", err)
|
||||
return "", types.NewFileReadError(path, err)
|
||||
}
|
||||
path = filepath.Join(pwd, path)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return "", types.NewFileReadError(path, err)
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
|
||||
@@ -6,16 +6,12 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// General
|
||||
ErrNoError = errors.New("no error (internal)")
|
||||
|
||||
// CLI
|
||||
ErrCLINoArgs = errors.New("CLI expects arguments but received none")
|
||||
)
|
||||
|
||||
// ======================================== General ========================================
|
||||
|
||||
var (
|
||||
ErrNoError = errors.New("no error (internal)")
|
||||
)
|
||||
|
||||
type FieldParseError struct {
|
||||
Field string
|
||||
Value string
|
||||
@@ -131,8 +127,147 @@ func (e UnmarshalError) Unwrap() error {
|
||||
return e.error
|
||||
}
|
||||
|
||||
// ======================================== General I/O ========================================
|
||||
|
||||
type FileReadError struct {
|
||||
Path string
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewFileReadError(path string, err error) FileReadError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return FileReadError{path, err}
|
||||
}
|
||||
|
||||
func (e FileReadError) Error() string {
|
||||
return fmt.Sprintf("failed to read file %s: %v", e.Path, e.Err)
|
||||
}
|
||||
|
||||
func (e FileReadError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type HTTPFetchError struct {
|
||||
URL string
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewHTTPFetchError(url string, err error) HTTPFetchError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return HTTPFetchError{url, err}
|
||||
}
|
||||
|
||||
func (e HTTPFetchError) Error() string {
|
||||
return fmt.Sprintf("failed to fetch %s: %v", e.URL, e.Err)
|
||||
}
|
||||
|
||||
func (e HTTPFetchError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type HTTPStatusError struct {
|
||||
URL string
|
||||
StatusCode int
|
||||
Status string
|
||||
}
|
||||
|
||||
func NewHTTPStatusError(url string, statusCode int, status string) HTTPStatusError {
|
||||
return HTTPStatusError{url, statusCode, status}
|
||||
}
|
||||
|
||||
func (e HTTPStatusError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d %s (url: %s)", e.StatusCode, e.Status, e.URL)
|
||||
}
|
||||
|
||||
type URLParseError struct {
|
||||
URL string
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewURLParseError(url string, err error) URLParseError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return URLParseError{url, err}
|
||||
}
|
||||
|
||||
func (e URLParseError) Error() string {
|
||||
return fmt.Sprintf("invalid URL %q: %v", e.URL, e.Err)
|
||||
}
|
||||
|
||||
func (e URLParseError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ======================================== Template ========================================
|
||||
|
||||
var (
|
||||
ErrFileCacheNotInitialized = errors.New("file cache is not initialized")
|
||||
ErrFormDataOddArgs = errors.New("body_FormData requires an even number of arguments (key-value pairs)")
|
||||
)
|
||||
|
||||
type TemplateParseError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewTemplateParseError(err error) TemplateParseError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return TemplateParseError{err}
|
||||
}
|
||||
|
||||
func (e TemplateParseError) Error() string {
|
||||
return "template parse error: " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e TemplateParseError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type TemplateRenderError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewTemplateRenderError(err error) TemplateRenderError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return TemplateRenderError{err}
|
||||
}
|
||||
|
||||
func (e TemplateRenderError) Error() string {
|
||||
return "template rendering: " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e TemplateRenderError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ======================================== YAML ========================================
|
||||
|
||||
type YAMLFormatError struct {
|
||||
Detail string
|
||||
}
|
||||
|
||||
func NewYAMLFormatError(detail string) YAMLFormatError {
|
||||
return YAMLFormatError{detail}
|
||||
}
|
||||
|
||||
func (e YAMLFormatError) Error() string {
|
||||
return e.Detail
|
||||
}
|
||||
|
||||
// ======================================== CLI ========================================
|
||||
|
||||
var (
|
||||
ErrCLINoArgs = errors.New("CLI expects arguments but received none")
|
||||
)
|
||||
|
||||
type CLIUnexpectedArgsError struct {
|
||||
Args []string
|
||||
}
|
||||
@@ -168,6 +303,61 @@ func (e ConfigFileReadError) Unwrap() error {
|
||||
|
||||
// ======================================== Proxy ========================================
|
||||
|
||||
type ProxyUnsupportedSchemeError struct {
|
||||
Scheme string
|
||||
}
|
||||
|
||||
func NewProxyUnsupportedSchemeError(scheme string) ProxyUnsupportedSchemeError {
|
||||
return ProxyUnsupportedSchemeError{scheme}
|
||||
}
|
||||
|
||||
func (e ProxyUnsupportedSchemeError) Error() string {
|
||||
return "unsupported proxy scheme: " + e.Scheme
|
||||
}
|
||||
|
||||
type ProxyParseError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewProxyParseError(err error) ProxyParseError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return ProxyParseError{err}
|
||||
}
|
||||
|
||||
func (e ProxyParseError) Error() string {
|
||||
return "failed to parse proxy URL: " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e ProxyParseError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type ProxyConnectError struct {
|
||||
Status string
|
||||
}
|
||||
|
||||
func NewProxyConnectError(status string) ProxyConnectError {
|
||||
return ProxyConnectError{status}
|
||||
}
|
||||
|
||||
func (e ProxyConnectError) Error() string {
|
||||
return "proxy CONNECT failed: " + e.Status
|
||||
}
|
||||
|
||||
type ProxyResolveError struct {
|
||||
Host string
|
||||
}
|
||||
|
||||
func NewProxyResolveError(host string) ProxyResolveError {
|
||||
return ProxyResolveError{host}
|
||||
}
|
||||
|
||||
func (e ProxyResolveError) Error() string {
|
||||
return "no IP addresses found for host: " + e.Host
|
||||
}
|
||||
|
||||
type ProxyDialError struct {
|
||||
Proxy string
|
||||
Err error
|
||||
@@ -187,3 +377,86 @@ func (e ProxyDialError) Error() string {
|
||||
func (e ProxyDialError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// ======================================== Script ========================================
|
||||
|
||||
var (
|
||||
ErrScriptEmpty = errors.New("script cannot be empty")
|
||||
ErrScriptSourceEmpty = errors.New("script source cannot be empty after @")
|
||||
ErrScriptTransformMissing = errors.New("script must define a global 'transform' function")
|
||||
ErrScriptTransformReturnObject = errors.New("transform function must return an object")
|
||||
ErrScriptURLNoHost = errors.New("script URL must have a host")
|
||||
)
|
||||
|
||||
type ScriptLoadError struct {
|
||||
Source string
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewScriptLoadError(source string, err error) ScriptLoadError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return ScriptLoadError{source, err}
|
||||
}
|
||||
|
||||
func (e ScriptLoadError) Error() string {
|
||||
return fmt.Sprintf("failed to load script from %q: %v", e.Source, e.Err)
|
||||
}
|
||||
|
||||
func (e ScriptLoadError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type ScriptExecutionError struct {
|
||||
EngineType string
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewScriptExecutionError(engineType string, err error) ScriptExecutionError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return ScriptExecutionError{engineType, err}
|
||||
}
|
||||
|
||||
func (e ScriptExecutionError) Error() string {
|
||||
return fmt.Sprintf("%s script error: %v", e.EngineType, e.Err)
|
||||
}
|
||||
|
||||
func (e ScriptExecutionError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type ScriptChainError struct {
|
||||
EngineType string
|
||||
Index int
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewScriptChainError(engineType string, index int, err error) ScriptChainError {
|
||||
if err == nil {
|
||||
err = ErrNoError
|
||||
}
|
||||
return ScriptChainError{engineType, index, err}
|
||||
}
|
||||
|
||||
func (e ScriptChainError) Error() string {
|
||||
return fmt.Sprintf("%s script[%d]: %v", e.EngineType, e.Index, e.Err)
|
||||
}
|
||||
|
||||
func (e ScriptChainError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type ScriptUnknownEngineError struct {
|
||||
EngineType string
|
||||
}
|
||||
|
||||
func NewScriptUnknownEngineError(engineType string) ScriptUnknownEngineError {
|
||||
return ScriptUnknownEngineError{engineType}
|
||||
}
|
||||
|
||||
func (e ScriptUnknownEngineError) Error() string {
|
||||
return "unknown engine type: " + e.EngineType
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
@@ -17,6 +16,9 @@ func (proxies *Proxies) Append(proxy ...Proxy) {
|
||||
*proxies = append(*proxies, proxy...)
|
||||
}
|
||||
|
||||
// Parse parses a raw proxy string and appends it to the list.
|
||||
// It can return the following errors:
|
||||
// - ProxyParseError
|
||||
func (proxies *Proxies) Parse(rawValue string) error {
|
||||
parsedProxy, err := ParseProxy(rawValue)
|
||||
if err != nil {
|
||||
@@ -27,10 +29,13 @@ func (proxies *Proxies) Parse(rawValue string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseProxy parses a raw proxy URL string into a Proxy.
|
||||
// It can return the following errors:
|
||||
// - ProxyParseError
|
||||
func ParseProxy(rawValue string) (*Proxy, error) {
|
||||
urlParsed, err := url.Parse(rawValue)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
|
||||
return nil, NewProxyParseError(err)
|
||||
}
|
||||
|
||||
proxyParsed := Proxy(*urlParsed)
|
||||
|
||||
Reference in New Issue
Block a user