rpcserver_test.go raw

   1  package wallet
   2  
   3  import (
   4  	"net/http"
   5  	"net/http/httptest"
   6  	"reflect"
   7  	"testing"
   8  	
   9  	"github.com/p9c/p9/pkg/qu"
  10  )
  11  
  12  func TestThrottle(t *testing.T) {
  13  	const threshold = 1
  14  	busy := qu.T()
  15  	srv := httptest.NewServer(
  16  		ThrottledFn(threshold,
  17  			func(w http.ResponseWriter, r *http.Request) {
  18  				<-busy
  19  			},
  20  		),
  21  	)
  22  	codes := make(chan int, 2)
  23  	for i := 0; i < cap(codes); i++ {
  24  		go func() {
  25  			res, e := http.Get(srv.URL)
  26  			if e != nil {
  27  				t.Fatal(e)
  28  			}
  29  			codes <- res.StatusCode
  30  		}()
  31  	}
  32  	got := make(map[int]int, cap(codes))
  33  	for i := 0; i < cap(codes); i++ {
  34  		got[<-codes]++
  35  		if i == 0 {
  36  			busy.Q()
  37  		}
  38  	}
  39  	want := map[int]int{200: 1, 429: 1}
  40  	if !reflect.DeepEqual(want, got) {
  41  		t.Fatalf("status codes: want: %v, got: %v", want, got)
  42  	}
  43  }
  44