Skip to content

Commit

Permalink
cors: ensure submitv2 enables CORS
Browse files Browse the repository at this point in the history
  • Loading branch information
parkr committed Apr 12, 2024
1 parent 7b7797a commit b03804b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 29 deletions.
2 changes: 1 addition & 1 deletion cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (c *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

func addCorsHeaders(w http.ResponseWriter, r *http.Request) {
w.Header().Set(CorsAccessControlAllowMethodsHeaderName, "GET")
w.Header().Set(CorsAccessControlAllowMethodsHeaderName, "GET, POST")
if allowCORSOrigin(r.Header.Get("Origin")) {
w.Header().Set(CorsAccessControlAllowOriginHeaderName, r.Header.Get("Origin"))
} else if allowCORSOrigin(r.Referer()) {
Expand Down
8 changes: 4 additions & 4 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestAddCorsHeaders_OriginRequestHeader_Success(t *testing.T) {

addCorsHeaders(recorder, request)

expectedAllowedMethods := "GET"
expectedAllowedMethods := "GET, POST"
actual := recorder.Header().Get(CorsAccessControlAllowMethodsHeaderName)
if actual != expectedAllowedMethods {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowMethodsHeaderName, expectedAllowedMethods, actual)
Expand All @@ -43,7 +43,7 @@ func TestAddCorsHeaders_RefererRequestHeader_Success(t *testing.T) {

addCorsHeaders(recorder, request)

expectedAllowedMethods := "GET"
expectedAllowedMethods := "GET, POST"
actual := recorder.Header().Get(CorsAccessControlAllowMethodsHeaderName)
if actual != expectedAllowedMethods {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowMethodsHeaderName, expectedAllowedMethods, actual)
Expand All @@ -68,7 +68,7 @@ func TestAddCorsHeaders_NeitherRequestHeader_Success(t *testing.T) {

addCorsHeaders(recorder, request)

expectedAllowedMethods := "GET"
expectedAllowedMethods := "GET, POST"
actual := recorder.Header().Get(CorsAccessControlAllowMethodsHeaderName)
if actual != expectedAllowedMethods {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowMethodsHeaderName, expectedAllowedMethods, actual)
Expand All @@ -93,7 +93,7 @@ func TestAddCorsHeaders_UnparseableRequestHeader_Success(t *testing.T) {

addCorsHeaders(recorder, request)

expectedAllowedMethods := "GET"
expectedAllowedMethods := "GET, POST"
actual := recorder.Header().Get(CorsAccessControlAllowMethodsHeaderName)
if actual != expectedAllowedMethods {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowMethodsHeaderName, expectedAllowedMethods, actual)
Expand Down
8 changes: 6 additions & 2 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,23 @@ func submitv2(w http.ResponseWriter, r *http.Request) {
path := r.FormValue("path")
if host == "" || path == "" {
jsv1.Error(w, http.StatusBadRequest, "missing param")
addCorsHeaders(w, r)
return
}
referer := url.URL{Host: host, Path: path}

req, err := http.NewRequest(http.MethodGet, "/ping.js", nil)
if err != nil {
jsv1.Error(w, http.StatusInternalServerError, "unable to rewrite")
addCorsHeaders(w, r)
return
}
req.Header.Set("Referer", referer.String())
req.Header.Set("User-Agent", r.Header.Get("User-Agent"))
req.Header.Set("X-Forwarded-For", r.RemoteAddr)

addCorsHeaders(w, r)

log.Printf("forwarding v2 to v1")

pingv1(w, req)
Expand Down Expand Up @@ -245,8 +249,8 @@ func buildHandler() *http.ServeMux {
mux.HandleFunc("/_health", health)
mux.HandleFunc("/ping", ping)
mux.HandleFunc("/ping.js", ping)
mux.HandleFunc("/submit", submitv2)
mux.HandleFunc("/submit.js", submitv2)
mux.Handle("/submit", &corsHandler{submitv2})
mux.Handle("/submit.js", &corsHandler{submitv2})
mux.Handle("/counts", &corsHandler{counts})
mux.Handle("/all", &corsHandler{all})
return mux
Expand Down
38 changes: 16 additions & 22 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,7 @@ func TestCountsOptionsPreflight(t *testing.T) {
status, http.StatusNoContent)
}

expectedAllowedHosts := "https://example.org"
actual := recorder.Header().Get("Access-Control-Allow-Origin")
if actual != expectedAllowedHosts {
t.Errorf("expected Access-Control-Allow-Origin: %v, got: %v", expectedAllowedHosts, actual)
}

expectedAllowedMethods := "GET"
actual = recorder.Header().Get("Access-Control-Allow-Methods")
if actual != expectedAllowedMethods {
t.Errorf("expected Access-Control-Allow-Methods: %v, got: %v", expectedAllowedMethods, actual)
}
verifyCorsHeaders(t, recorder, "example.org")
}

func TestCountsMissingParam(t *testing.T) {
Expand Down Expand Up @@ -311,17 +301,7 @@ func TestAllOptionsPreflight(t *testing.T) {
status, http.StatusNoContent)
}

expectedAllowedHosts := "https://example.org"
actual := recorder.Header().Get("Access-Control-Allow-Origin")
if actual != expectedAllowedHosts {
t.Errorf("expected Access-Control-Allow-Origin: %v, got: %v", expectedAllowedHosts, actual)
}

expectedAllowedMethods := "GET"
actual = recorder.Header().Get("Access-Control-Allow-Methods")
if actual != expectedAllowedMethods {
t.Errorf("expected Access-Control-Allow-Methods: %v, got: %v", expectedAllowedMethods, actual)
}
verifyCorsHeaders(t, recorder, "example.org")
}

func TestAllHost(t *testing.T) {
Expand Down Expand Up @@ -391,3 +371,17 @@ func TestAllPath(t *testing.T) {
firstElement, expected)
}
}

func verifyCorsHeaders(t *testing.T, recorder *httptest.ResponseRecorder, origin string) {
expectedAllowedHosts := "https://" + origin
actual := recorder.Header().Get(CorsAccessControlAllowOriginHeaderName)
if actual != expectedAllowedHosts {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowOriginHeaderName, expectedAllowedHosts, actual)
}

expectedAllowedMethods := "GET, POST"
actual = recorder.Header().Get(CorsAccessControlAllowMethodsHeaderName)
if actual != expectedAllowedMethods {
t.Errorf("expected %s: %v, got: %v", CorsAccessControlAllowMethodsHeaderName, expectedAllowedMethods, actual)
}
}
18 changes: 18 additions & 0 deletions pingv2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ func TestSubmitV2_MissingHost(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "go test client")
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -162,6 +163,8 @@ func TestSubmitV2_MissingHost(t *testing.T) {
t.Errorf("submitv2 body is not expected string %q, got: %v",
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")
}

func TestSubmitV2_MissingPath(t *testing.T) {
Expand All @@ -173,6 +176,7 @@ func TestSubmitV2_MissingPath(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "go test client")
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -188,6 +192,8 @@ func TestSubmitV2_MissingPath(t *testing.T) {
t.Errorf("submitv2 body is not expected string %q, got: %v",
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")
}

func TestSubmitV2_InvalidHost(t *testing.T) {
Expand All @@ -199,6 +205,7 @@ func TestSubmitV2_InvalidHost(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "go test client")
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -214,6 +221,8 @@ func TestSubmitV2_InvalidHost(t *testing.T) {
t.Errorf("submitv2 body is not expected string %q, got: %v",
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")
}

func TestSubmitV2_UnauthorizedHost(t *testing.T) {
Expand All @@ -225,6 +234,7 @@ func TestSubmitV2_UnauthorizedHost(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "go test client")
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -240,6 +250,8 @@ func TestSubmitV2_UnauthorizedHost(t *testing.T) {
t.Errorf("submitv2 body is not expected string %q, got: %v",
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")
}

func TestSubmitV2_MissingUserAgent(t *testing.T) {
Expand All @@ -251,6 +263,7 @@ func TestSubmitV2_MissingUserAgent(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "") // missing
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -266,6 +279,8 @@ func TestSubmitV2_MissingUserAgent(t *testing.T) {
t.Errorf("submitv2 body is not expected string %q, got: %v",
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")
}

func TestSubmitV2_Success(t *testing.T) {
Expand All @@ -287,6 +302,7 @@ func TestSubmitV2_Success(t *testing.T) {
}
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.Header.Set("User-Agent", "go test client")
request.Header.Set("Referer", "https://example.org")

recorder := httptest.NewRecorder()
handler := buildHandler()
Expand All @@ -303,6 +319,8 @@ func TestSubmitV2_Success(t *testing.T) {
expected, recorder.Body.String())
}

verifyCorsHeaders(t, recorder, "example.org")

visitCountEnd, _ := analytics.ViewsForHostPath(db, "example.org", "/TestSubmitV2_Success")

if visitCountEnd != visitCountStart+1 {
Expand Down

0 comments on commit b03804b

Please sign in to comment.