diff --git a/custom_errors/errors.go b/custom_errors/errors.go index 6bc7ab2..890d5a2 100644 --- a/custom_errors/errors.go +++ b/custom_errors/errors.go @@ -10,6 +10,8 @@ import ( var ( ErrInvalidJSON = errors.New("invalid JSON file") ErrInvalidFile = errors.New("invalid file") + ErrInterrupt = errors.New("interrupted") + ErrNoInternet = errors.New("no internet connection") ) func As(err error, target any) bool { diff --git a/main.go b/main.go index 14fa5cb..faa6a9d 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,12 @@ package main import ( + "context" "net/url" "os" + "os/signal" "strings" + "syscall" "time" "github.com/aykhans/dodo/config" @@ -79,6 +82,24 @@ func main() { } dodoConf.Print() - responses := requests.Run(dodoConf) + ctx, cancel := context.WithCancel(context.Background()) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigChan + cancel() + }() + + responses, err := requests.Run(ctx, dodoConf) + if err != nil { + if customerrors.Is(err, customerrors.ErrInterrupt) { + utils.PrintlnC(utils.Colors.Yellow, err.Error()) + return + } else if customerrors.Is(err, customerrors.ErrNoInternet) { + utils.PrintAndExit("No internet connection") + } + panic(err) + } + responses.Print() } diff --git a/requests/requests.go b/requests/requests.go index d9aa41e..6711fb5 100644 --- a/requests/requests.go +++ b/requests/requests.go @@ -1,6 +1,7 @@ package requests import ( + "context" "fmt" "math/rand" "net/url" @@ -9,6 +10,7 @@ import ( "time" "github.com/aykhans/dodo/config" + customerrors "github.com/aykhans/dodo/custom_errors" "github.com/aykhans/dodo/readers" "github.com/aykhans/dodo/utils" "github.com/jedib0t/go-pretty/v6/progress" @@ -91,17 +93,27 @@ func (respones *Responses) Print() { } // Run executes the HTTP requests based on the provided request configuration. -// It returns the Responses type, which contains the responses received from all the requests. -func Run(requestConfig *config.RequestConfig) Responses { - if !checkConnection() { - utils.PrintAndExit("No internet connection") +// It checks for internet connection and returns an error if there is no connection. +// If the context is canceled while checking proxies, it returns the ErrInterrupt. +// If the context is canceled while sending requests, it returns the response objects obtained so far. +func Run(ctx context.Context, requestConfig *config.RequestConfig) (Responses, error) { + checkConnectionCtx, checkConnectionCtxCancel := context.WithTimeout(ctx, 8*time.Second) + if !checkConnection(checkConnectionCtx) { + checkConnectionCtxCancel() + return nil, customerrors.ErrNoInternet } + checkConnectionCtxCancel() + clientFunc := getClientFunc( + ctx, requestConfig.Timeout, requestConfig.Proxies, requestConfig.GetValidDodosCountForRequests(), requestConfig.URL, ) + if clientFunc == nil { + return nil, customerrors.ErrInterrupt + } request := newRequest( requestConfig.URL, @@ -113,14 +125,18 @@ func Run(requestConfig *config.RequestConfig) Responses { ) defer fasthttp.ReleaseRequest(request) responses := releaseDodos( + ctx, request, requestConfig.Timeout, clientFunc, requestConfig.GetValidDodosCountForRequests(), requestConfig.RequestCount, ) + if ctx.Err() != nil && len(responses) == 0 { + return nil, customerrors.ErrInterrupt + } - return responses + return responses, nil } // releaseDodos sends multiple HTTP requests concurrently using multiple "dodos" (goroutines). @@ -128,6 +144,7 @@ func Run(requestConfig *config.RequestConfig) Responses { // dodosCount as the number of goroutines to be used, and requestCount as the total number of requests to be sent. // It returns the responses received from all the requests. func releaseDodos( + ctx context.Context, mainRequest *fasthttp.Request, timeout time.Duration, clientFunc ClientFunc, @@ -136,14 +153,17 @@ func releaseDodos( ) Responses { var ( wg sync.WaitGroup + streamWG sync.WaitGroup requestCountPerDodo int ) - wg.Add(dodosCount + 1) // +1 for progress tracker + wg.Add(dodosCount) + streamWG.Add(1) responses := make([][]Response, dodosCount) countSlice := make([]int, dodosCount) - go streamProgress(&wg, requestCount, "Dodos Working🔥", &countSlice) + streamCtx, streamCtxCancel := context.WithCancel(context.Background()) + go streamProgress(streamCtx, &streamWG, requestCount, "Dodos Working🔥", &countSlice) for i := 0; i < dodosCount; i++ { if i+1 == dodosCount { @@ -157,6 +177,7 @@ func releaseDodos( mainRequest.CopyTo(dodoSpecificRequest) go sendRequest( + ctx, dodoSpecificRequest, timeout, &responses[i], @@ -167,6 +188,8 @@ func releaseDodos( ) } wg.Wait() + streamCtxCancel() + streamWG.Wait() return utils.Flatten(responses) } @@ -175,6 +198,7 @@ func releaseDodos( // For each request, it acquires a response object, gets a client, and measures the time taken to complete the request. // If an error occurs during the request, the error is recorded in the responseData slice. func sendRequest( + ctx context.Context, request *fasthttp.Request, timeout time.Duration, responseData *[]Response, @@ -187,6 +211,10 @@ func sendRequest( defer wg.Done() for range requestCount { + if ctx.Err() != nil { + return + } + func() { defer func() { *counter++ }() @@ -221,6 +249,7 @@ func sendRequest( // If the user chooses to continue, it returns a ClientFunc with a shared client or a randomized client. // If there are no proxies available, it returns a ClientFunc with a shared client. func getClientFunc( + ctx context.Context, timeout time.Duration, proxies []config.Proxy, dodosCount int, @@ -229,8 +258,11 @@ func getClientFunc( isTLS := URL.Scheme == "https" if len(proxies) > 0 { activeProxyClients := getActiveProxyClients( - proxies, timeout, dodosCount, URL, + ctx, proxies, timeout, dodosCount, URL, ) + if ctx.Err() != nil { + return nil + } activeProxyClientsCount := len(activeProxyClients) var yesOrNoMessage string if activeProxyClientsCount == 0 { @@ -287,6 +319,7 @@ func getClientFunc( // Once all goroutines have completed, the function waits for them to finish and // returns a flattened slice of active proxy clients. func getActiveProxyClients( + ctx context.Context, proxies []config.Proxy, timeout time.Duration, dodosCount int, @@ -295,12 +328,17 @@ func getActiveProxyClients( activeProxyClientsArray := make([][]fasthttp.HostClient, dodosCount) proxiesCount := len(proxies) - var wg sync.WaitGroup - wg.Add(dodosCount + 1) // +1 for progress tracker + var ( + wg sync.WaitGroup + streamWG sync.WaitGroup + ) + wg.Add(dodosCount) + streamWG.Add(1) var proxiesSlice []config.Proxy countSlice := make([]int, dodosCount) - go streamProgress(&wg, proxiesCount, "Searching for active proxies🌐", &countSlice) + streamCtx, streamCtxCancel := context.WithCancel(context.Background()) + go streamProgress(streamCtx, &streamWG, proxiesCount, "Searching for active proxies🌐", &countSlice) for i := 0; i < dodosCount; i++ { if i+1 == dodosCount { @@ -309,6 +347,7 @@ func getActiveProxyClients( proxiesSlice = proxies[i*proxiesCount/dodosCount : (i+1)*proxiesCount/dodosCount] } go findActiveProxyClients( + ctx, proxiesSlice, timeout, &activeProxyClientsArray[i], @@ -318,6 +357,8 @@ func getActiveProxyClients( ) } wg.Wait() + streamCtxCancel() + streamWG.Wait() return utils.Flatten(activeProxyClientsArray) } @@ -326,6 +367,7 @@ func getActiveProxyClients( // It also increments the count for each successful request. // The function is designed to be used as a concurrent operation, and it uses the WaitGroup to wait for all goroutines to finish. func findActiveProxyClients( + ctx context.Context, proxies []config.Proxy, timeout time.Duration, activeProxyClients *[]fasthttp.HostClient, @@ -341,6 +383,10 @@ func findActiveProxyClients( request.Header.SetMethod("GET") for _, proxy := range proxies { + if ctx.Err() != nil { + return + } + func() { defer func() { *count++ }() @@ -520,6 +566,7 @@ func setRequestBody(req *fasthttp.Request, body string) { // The function runs in a separate goroutine and updates the progress bar until all items are processed. // Once all items are processed, it marks the progress bar as done and stops rendering. func streamProgress( + ctx context.Context, wg *sync.WaitGroup, total int, message string, @@ -543,11 +590,21 @@ func streamProgress( totalCount += count } dodosTracker.SetValue(int64(totalCount)) + + if ctx.Err() != nil { + fmt.Printf("\r") + dodosTracker.MarkAsErrored() + time.Sleep(time.Millisecond * 300) + pw.Stop() + return + } + if totalCount == total { break } time.Sleep(time.Millisecond * 200) } + fmt.Printf("\r") dodosTracker.MarkAsDone() time.Sleep(time.Millisecond * 300) pw.Stop() @@ -555,14 +612,25 @@ func streamProgress( // checkConnection checks the internet connection by making requests to different websites. // It returns true if the connection is successful, otherwise false. -func checkConnection() bool { - _, _, err := fasthttp.Get(nil, "https://www.google.com") - if err != nil { - _, _, err = fasthttp.Get(nil, "https://www.bing.com") +func checkConnection(ctx context.Context) bool { + ch := make(chan bool) + go func() { + _, _, err := fasthttp.Get(nil, "https://www.google.com") if err != nil { - _, _, err = fasthttp.Get(nil, "https://www.yahoo.com") - return err == nil + _, _, err = fasthttp.Get(nil, "https://www.bing.com") + if err != nil { + _, _, err = fasthttp.Get(nil, "https://www.yahoo.com") + ch <- err == nil + } + ch <- true } + ch <- true + }() + + select { + case <-ctx.Done(): + return false + case res := <-ch: + return res } - return true }