diff --git a/internal/sarin/request.go b/internal/sarin/request.go index 870ebf9..f2088ca 100644 --- a/internal/sarin/request.go +++ b/internal/sarin/request.go @@ -17,7 +17,7 @@ import ( type RequestGenerator func(*fasthttp.Request) error -type RequestGeneratorWithData func(*fasthttp.Request, any) error +type requestDataGenerator func(*script.RequestData, any) error type valuesData struct { Values map[string]string @@ -59,13 +59,22 @@ func NewRequestGenerator( hasScripts := scriptTransformer != nil && !scriptTransformer.IsEmpty() + host := requestURL.Host + scheme := requestURL.Scheme + + reqData := &script.RequestData{ + Headers: make(map[string][]string), + Params: make(map[string][]string), + Cookies: make(map[string][]string), + } + var ( data valuesData path string err error ) return func(req *fasthttp.Request) error { - req.Header.SetHost(requestURL.Host) + resetRequestData(reqData) data, err = valuesGenerator() if err != nil { @@ -76,44 +85,39 @@ func NewRequestGenerator( if err != nil { return err } - req.SetRequestURI(path) + reqData.Path = path - if err = methodGenerator(req, data); err != nil { + if err = methodGenerator(reqData, data); err != nil { return err } bodyTemplateFuncMapData.ClearFormDataContenType() - if err = bodyGenerator(req, data); err != nil { + if err = bodyGenerator(reqData, data); err != nil { return err } - if err = headersGenerator(req, data); err != nil { + if err = headersGenerator(reqData, data); err != nil { return err } if bodyTemplateFuncMapData.GetFormDataContenType() != "" { - req.Header.Add("Content-Type", bodyTemplateFuncMapData.GetFormDataContenType()) + reqData.Headers["Content-Type"] = append(reqData.Headers["Content-Type"], bodyTemplateFuncMapData.GetFormDataContenType()) } - if err = paramsGenerator(req, data); err != nil { + if err = paramsGenerator(reqData, data); err != nil { return err } - if err = cookiesGenerator(req, data); err != nil { + if err = cookiesGenerator(reqData, data); err != nil { return err } - if requestURL.Scheme == "https" { - req.URI().SetScheme("https") - } - - // Apply script transformations if any if hasScripts { - reqData := script.RequestDataFromFastHTTP(req) if err = scriptTransformer.Transform(reqData); err != nil { return err } - script.ApplyToFastHTTP(reqData, req) } + applyRequestDataToFastHTTP(reqData, req, host, scheme) + return nil }, isPathGeneratorDynamic || isMethodGeneratorDynamic || @@ -124,50 +128,92 @@ func NewRequestGenerator( hasScripts } -func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { +func resetRequestData(reqData *script.RequestData) { + reqData.Method = "" + reqData.Path = "" + reqData.Body = "" + clear(reqData.Headers) + clear(reqData.Params) + clear(reqData.Cookies) +} + +func applyRequestDataToFastHTTP(reqData *script.RequestData, req *fasthttp.Request, host, scheme string) { + req.Header.SetHost(host) + req.SetRequestURI(reqData.Path) + req.Header.SetMethod(reqData.Method) + req.SetBody([]byte(reqData.Body)) + + for k, values := range reqData.Headers { + for _, v := range values { + req.Header.Add(k, v) + } + } + + for k, values := range reqData.Params { + for _, v := range values { + req.URI().QueryArgs().Add(k, v) + } + } + + if len(reqData.Cookies) > 0 { + cookieStrings := make([]string, 0, len(reqData.Cookies)) + for k, values := range reqData.Cookies { + for _, v := range values { + cookieStrings = append(cookieStrings, k+"="+v) + } + } + req.Header.Add("Cookie", strings.Join(cookieStrings, "; ")) + } + + if scheme == "https" { + req.URI().SetScheme("https") + } +} + +func NewMethodGeneratorFunc(localRand *rand.Rand, methods []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) { methodGenerator, isDynamic := buildStringSliceGenerator(localRand, methods, templateFunctions) var ( method string err error ) - return func(req *fasthttp.Request, data any) error { + return func(reqData *script.RequestData, data any) error { method, err = methodGenerator()(data) if err != nil { return err } - req.Header.SetMethod(method) + reqData.Method = method return nil }, isDynamic } -func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { +func NewBodyGeneratorFunc(localRand *rand.Rand, bodies []string, templateFunctions template.FuncMap) (requestDataGenerator, bool) { bodyGenerator, isDynamic := buildStringSliceGenerator(localRand, bodies, templateFunctions) var ( body string err error ) - return func(req *fasthttp.Request, data any) error { + return func(reqData *script.RequestData, data any) error { body, err = bodyGenerator()(data) if err != nil { return err } - req.SetBody([]byte(body)) + reqData.Body = body return nil }, isDynamic } -func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { +func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateFunctions template.FuncMap) (requestDataGenerator, bool) { generators, isDynamic := buildKeyValueGenerators(localRand, params, templateFunctions) var ( key, value string err error ) - return func(req *fasthttp.Request, data any) error { + return func(reqData *script.RequestData, data any) error { for _, gen := range generators { key, err = gen.Key(data) if err != nil { @@ -179,20 +225,20 @@ func NewParamsGeneratorFunc(localRand *rand.Rand, params types.Params, templateF return err } - req.URI().QueryArgs().Add(key, value) + reqData.Params[key] = append(reqData.Params[key], value) } return nil }, isDynamic } -func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { +func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templateFunctions template.FuncMap) (requestDataGenerator, bool) { generators, isDynamic := buildKeyValueGenerators(localRand, headers, templateFunctions) var ( key, value string err error ) - return func(req *fasthttp.Request, data any) error { + return func(reqData *script.RequestData, data any) error { for _, gen := range generators { key, err = gen.Key(data) if err != nil { @@ -204,41 +250,33 @@ func NewHeadersGeneratorFunc(localRand *rand.Rand, headers types.Headers, templa return err } - req.Header.Add(key, value) + reqData.Headers[key] = append(reqData.Headers[key], value) } return nil }, isDynamic } -func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (RequestGeneratorWithData, bool) { +func NewCookiesGeneratorFunc(localRand *rand.Rand, cookies types.Cookies, templateFunctions template.FuncMap) (requestDataGenerator, bool) { generators, isDynamic := buildKeyValueGenerators(localRand, cookies, templateFunctions) var ( key, value string err error ) - if len(generators) > 0 { - return func(req *fasthttp.Request, data any) error { - cookieStrings := make([]string, 0, len(generators)) - for _, gen := range generators { - key, err = gen.Key(data) - if err != nil { - return err - } - - value, err = gen.Value()(data) - if err != nil { - return err - } - - cookieStrings = append(cookieStrings, key+"="+value) + return func(reqData *script.RequestData, data any) error { + for _, gen := range generators { + key, err = gen.Key(data) + if err != nil { + return err } - req.Header.Add("Cookie", strings.Join(cookieStrings, "; ")) - return nil - }, isDynamic - } - return func(req *fasthttp.Request, data any) error { + value, err = gen.Value()(data) + if err != nil { + return err + } + + reqData.Cookies[key] = append(reqData.Cookies[key], value) + } return nil }, isDynamic } diff --git a/internal/script/chain.go b/internal/script/chain.go index 0871fcf..934c98c 100644 --- a/internal/script/chain.go +++ b/internal/script/chain.go @@ -1,7 +1,6 @@ package script import ( - "github.com/valyala/fasthttp" "go.aykhans.me/sarin/internal/types" ) @@ -106,83 +105,3 @@ func (t *Transformer) Close() { func (t *Transformer) IsEmpty() bool { return len(t.luaEngines) == 0 && len(t.jsEngines) == 0 } - -// RequestDataFromFastHTTP extracts RequestData from a fasthttp.Request. -func RequestDataFromFastHTTP(req *fasthttp.Request) *RequestData { - data := &RequestData{ - Method: string(req.Header.Method()), - URL: string(req.URI().FullURI()), - Path: string(req.URI().Path()), - Body: string(req.Body()), - Headers: make(map[string][]string), - Params: make(map[string][]string), - Cookies: make(map[string][]string), - } - - // Extract headers (supports multiple values per key) - req.Header.All()(func(key, value []byte) bool { - k := string(key) - data.Headers[k] = append(data.Headers[k], string(value)) - return true - }) - - // Extract query params (supports multiple values per key) - req.URI().QueryArgs().All()(func(key, value []byte) bool { - k := string(key) - data.Params[k] = append(data.Params[k], string(value)) - return true - }) - - // Extract cookies (supports multiple values per key) - req.Header.Cookies()(func(key, value []byte) bool { - k := string(key) - data.Cookies[k] = append(data.Cookies[k], string(value)) - return true - }) - - return data -} - -// ApplyToFastHTTP applies the modified RequestData back to a fasthttp.Request. -func ApplyToFastHTTP(data *RequestData, req *fasthttp.Request) { - // Method - req.Header.SetMethod(data.Method) - - // Path (preserve scheme and host) - req.URI().SetPath(data.Path) - - // Body - req.SetBody([]byte(data.Body)) - - // Clear and set headers (supports multiple values per key) - req.Header.All()(func(key, _ []byte) bool { - keyStr := string(key) - if keyStr != "Host" { - req.Header.Del(keyStr) - } - return true - }) - for k, values := range data.Headers { - if k != "Host" { // Don't overwrite Host - for _, v := range values { - req.Header.Add(k, v) - } - } - } - - // Clear and set query params (supports multiple values per key) - req.URI().QueryArgs().Reset() - for k, values := range data.Params { - for _, v := range values { - req.URI().QueryArgs().Add(k, v) - } - } - - // Clear and set cookies (supports multiple values per key) - req.Header.DelAllCookies() - for k, values := range data.Cookies { - for _, v := range values { - req.Header.SetCookie(k, v) - } - } -} diff --git a/internal/script/js.go b/internal/script/js.go index 4e3157f..cfaddfe 100644 --- a/internal/script/js.go +++ b/internal/script/js.go @@ -86,7 +86,6 @@ func (e *JsEngine) requestDataToObject(req *RequestData) goja.Value { obj := e.runtime.NewObject() _ = obj.Set("method", req.Method) - _ = obj.Set("url", req.URL) _ = obj.Set("path", req.Path) _ = obj.Set("body", req.Body) @@ -130,11 +129,6 @@ func (e *JsEngine) objectToRequestData(val goja.Value, req *RequestData) error { req.Method = v.String() } - // URL - if v := obj.Get("url"); v != nil && !goja.IsUndefined(v) { - req.URL = v.String() - } - // Path if v := obj.Get("path"); v != nil && !goja.IsUndefined(v) { req.Path = v.String() diff --git a/internal/script/lua.go b/internal/script/lua.go index 203ac95..918792a 100644 --- a/internal/script/lua.go +++ b/internal/script/lua.go @@ -90,7 +90,6 @@ func (e *LuaEngine) requestDataToTable(req *RequestData) *lua.LTable { t := L.NewTable() t.RawSetString("method", lua.LString(req.Method)) - t.RawSetString("url", lua.LString(req.URL)) t.RawSetString("path", lua.LString(req.Path)) t.RawSetString("body", lua.LString(req.Body)) @@ -137,11 +136,6 @@ func (e *LuaEngine) tableToRequestData(t *lua.LTable, req *RequestData) { req.Method = string(v.(lua.LString)) } - // URL - if v := t.RawGetString("url"); v.Type() == lua.LTString { - req.URL = string(v.(lua.LString)) - } - // Path if v := t.RawGetString("path"); v.Type() == lua.LTString { req.Path = string(v.(lua.LString)) diff --git a/internal/script/script.go b/internal/script/script.go index 607253b..b0e5f92 100644 --- a/internal/script/script.go +++ b/internal/script/script.go @@ -17,7 +17,6 @@ import ( // Headers, Params, and Cookies use []string values to support multiple values per key. type RequestData struct { Method string `json:"method"` - URL string `json:"url"` Path string `json:"path"` Headers map[string][]string `json:"headers"` Params map[string][]string `json:"params"`