plugins: basic auth & cors
Showing
3 changed files
with
525 additions
and
24 deletions
| 1 | // Beego (http://beego.me/) | 1 | // Copyright 2014 beego Author. All Rights Reserved. |
| 2 | // | 2 | // |
| 3 | // @description beego is an open-source, high-performance web framework for the Go programming language. | 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. | ||
| 5 | // You may obtain a copy of the License at | ||
| 4 | // | 6 | // |
| 5 | // @link http://github.com/astaxie/beego for the canonical source repository | 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | // | 8 | // |
| 7 | // @license http://github.com/astaxie/beego/blob/master/LICENSE | 9 | // Unless required by applicable law or agreed to in writing, software |
| 8 | // | 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | // @authors astaxie | 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | package auth | 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. | ||
| 11 | 14 | ||
| 12 | // Example: | 15 | // Package auth provides handlers to enable basic auth support. |
| 13 | // func SecretAuth(username, password string) bool { | 16 | // Simple Usage: |
| 14 | // if username == "astaxie" && password == "helloBeego" { | 17 | // import( |
| 15 | // return true | 18 | // "github.com/astaxie/beego" |
| 19 | // "github.com/astaxie/beego/plugins/auth" | ||
| 20 | // ) | ||
| 21 | // | ||
| 22 | // func main(){ | ||
| 23 | // // authenticate every request | ||
| 24 | // beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword")) | ||
| 25 | // beego.Run() | ||
| 16 | // } | 26 | // } |
| 17 | // return false | 27 | // |
| 28 | // | ||
| 29 | // Advanced Usage: | ||
| 30 | // func SecretAuth(username, password string) bool { | ||
| 31 | // return username == "astaxie" && password == "helloBeego" | ||
| 18 | // } | 32 | // } |
| 19 | // authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm") | 33 | // authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required") |
| 20 | // beego.InsertFilter("*", beego.BeforeRouter,authPlugin) | 34 | // beego.InsertFilter("*", beego.BeforeRouter,authPlugin) |
| 35 | package auth | ||
| 21 | 36 | ||
| 22 | import ( | 37 | import ( |
| 23 | "encoding/base64" | 38 | "encoding/base64" |
| ... | @@ -28,6 +43,15 @@ import ( | ... | @@ -28,6 +43,15 @@ import ( |
| 28 | "github.com/astaxie/beego/context" | 43 | "github.com/astaxie/beego/context" |
| 29 | ) | 44 | ) |
| 30 | 45 | ||
| 46 | var defaultRealm = "Authorization Required" | ||
| 47 | |||
| 48 | func Basic(username string, password string) beego.FilterFunc { | ||
| 49 | secrets := func(user, pass string) bool { | ||
| 50 | return user == username && pass == password | ||
| 51 | } | ||
| 52 | return NewBasicAuthenticator(secrets, defaultRealm) | ||
| 53 | } | ||
| 54 | |||
| 31 | func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { | 55 | func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { |
| 32 | return func(ctx *context.Context) { | 56 | return func(ctx *context.Context) { |
| 33 | a := &BasicAuth{Secrets: secrets, Realm: Realm} | 57 | a := &BasicAuth{Secrets: secrets, Realm: Realm} |
| ... | @@ -44,13 +68,10 @@ type BasicAuth struct { | ... | @@ -44,13 +68,10 @@ type BasicAuth struct { |
| 44 | Realm string | 68 | Realm string |
| 45 | } | 69 | } |
| 46 | 70 | ||
| 47 | /* | 71 | //Checks the username/password combination from the request. Returns |
| 48 | Checks the username/password combination from the request. Returns | 72 | //either an empty string (authentication failed) or the name of the |
| 49 | either an empty string (authentication failed) or the name of the | 73 | //authenticated user. |
| 50 | authenticated user. | 74 | //Supports MD5 and SHA1 password entries |
| 51 | |||
| 52 | Supports MD5 and SHA1 password entries | ||
| 53 | */ | ||
| 54 | func (a *BasicAuth) CheckAuth(r *http.Request) string { | 75 | func (a *BasicAuth) CheckAuth(r *http.Request) string { |
| 55 | s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) | 76 | s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) |
| 56 | if len(s) != 2 || s[0] != "Basic" { | 77 | if len(s) != 2 || s[0] != "Basic" { |
| ... | @@ -72,10 +93,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string { | ... | @@ -72,10 +93,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string { |
| 72 | return "" | 93 | return "" |
| 73 | } | 94 | } |
| 74 | 95 | ||
| 75 | /* | 96 | //http.Handler for BasicAuth which initiates the authentication process |
| 76 | http.Handler for BasicAuth which initiates the authentication process | 97 | //(or requires reauthentication). |
| 77 | (or requires reauthentication). | ||
| 78 | */ | ||
| 79 | func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { | 98 | func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { |
| 80 | w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) | 99 | w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) |
| 81 | w.WriteHeader(401) | 100 | w.WriteHeader(401) | ... | ... |
plugins/cors/cors.go
0 → 100644
| 1 | // Copyright 2014 beego Author. All Rights Reserved. | ||
| 2 | // | ||
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 4 | // you may not use this file except in compliance with the License. | ||
| 5 | // You may obtain a copy of the License at | ||
| 6 | // | ||
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 | ||
| 8 | // | ||
| 9 | // Unless required by applicable law or agreed to in writing, software | ||
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, | ||
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 12 | // See the License for the specific language governing permissions and | ||
| 13 | // limitations under the License. | ||
| 14 | |||
| 15 | // Package cors provides handlers to enable CORS support. | ||
| 16 | // Usage | ||
| 17 | // | ||
| 18 | // import ( | ||
| 19 | // "github.com/astaxie/beego" | ||
| 20 | // "github.com/astaxie/beego/plugins/cors" | ||
| 21 | // ) | ||
| 22 | |||
| 23 | //func main() { | ||
| 24 | // // CORS for https://foo.* origins, allowing: | ||
| 25 | // // - PUT and PATCH methods | ||
| 26 | // // - Origin header | ||
| 27 | // // - Credentials share | ||
| 28 | // beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{ | ||
| 29 | // AllowOrigins: []string{"https://*.foo.com"}, | ||
| 30 | // AllowMethods: []string{"PUT", "PATCH"}, | ||
| 31 | // AllowHeaders: []string{"Origin"}, | ||
| 32 | // ExposeHeaders: []string{"Content-Length"}, | ||
| 33 | // AllowCredentials: true, | ||
| 34 | // })) | ||
| 35 | // beego.Run() | ||
| 36 | //} | ||
| 37 | package cors | ||
| 38 | |||
| 39 | import ( | ||
| 40 | "net/http" | ||
| 41 | "regexp" | ||
| 42 | "strconv" | ||
| 43 | "strings" | ||
| 44 | "time" | ||
| 45 | |||
| 46 | "github.com/astaxie/beego" | ||
| 47 | "github.com/astaxie/beego/context" | ||
| 48 | ) | ||
| 49 | |||
| 50 | const ( | ||
| 51 | headerAllowOrigin = "Access-Control-Allow-Origin" | ||
| 52 | headerAllowCredentials = "Access-Control-Allow-Credentials" | ||
| 53 | headerAllowHeaders = "Access-Control-Allow-Headers" | ||
| 54 | headerAllowMethods = "Access-Control-Allow-Methods" | ||
| 55 | headerExposeHeaders = "Access-Control-Expose-Headers" | ||
| 56 | headerMaxAge = "Access-Control-Max-Age" | ||
| 57 | |||
| 58 | headerOrigin = "Origin" | ||
| 59 | headerRequestMethod = "Access-Control-Request-Method" | ||
| 60 | headerRequestHeaders = "Access-Control-Request-Headers" | ||
| 61 | ) | ||
| 62 | |||
| 63 | var ( | ||
| 64 | defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} | ||
| 65 | // Regex patterns are generated from AllowOrigins. These are used and generated internally. | ||
| 66 | allowOriginPatterns = []string{} | ||
| 67 | ) | ||
| 68 | |||
| 69 | // Options represents Access Control options. | ||
| 70 | type Options struct { | ||
| 71 | // If set, all origins are allowed. | ||
| 72 | AllowAllOrigins bool | ||
| 73 | // A list of allowed origins. Wild cards and FQDNs are supported. | ||
| 74 | AllowOrigins []string | ||
| 75 | // If set, allows to share auth credentials such as cookies. | ||
| 76 | AllowCredentials bool | ||
| 77 | // A list of allowed HTTP methods. | ||
| 78 | AllowMethods []string | ||
| 79 | // A list of allowed HTTP headers. | ||
| 80 | AllowHeaders []string | ||
| 81 | // A list of exposed HTTP headers. | ||
| 82 | ExposeHeaders []string | ||
| 83 | // Max age of the CORS headers. | ||
| 84 | MaxAge time.Duration | ||
| 85 | } | ||
| 86 | |||
| 87 | // Header converts options into CORS headers. | ||
| 88 | func (o *Options) Header(origin string) (headers map[string]string) { | ||
| 89 | headers = make(map[string]string) | ||
| 90 | // if origin is not allowed, don't extend the headers | ||
| 91 | // with CORS headers. | ||
| 92 | if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { | ||
| 93 | return | ||
| 94 | } | ||
| 95 | |||
| 96 | // add allow origin | ||
| 97 | if o.AllowAllOrigins { | ||
| 98 | headers[headerAllowOrigin] = "*" | ||
| 99 | } else { | ||
| 100 | headers[headerAllowOrigin] = origin | ||
| 101 | } | ||
| 102 | |||
| 103 | // add allow credentials | ||
| 104 | headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) | ||
| 105 | |||
| 106 | // add allow methods | ||
| 107 | if len(o.AllowMethods) > 0 { | ||
| 108 | headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") | ||
| 109 | } | ||
| 110 | |||
| 111 | // add allow headers | ||
| 112 | if len(o.AllowHeaders) > 0 { | ||
| 113 | headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") | ||
| 114 | } | ||
| 115 | |||
| 116 | // add exposed header | ||
| 117 | if len(o.ExposeHeaders) > 0 { | ||
| 118 | headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") | ||
| 119 | } | ||
| 120 | // add a max age header | ||
| 121 | if o.MaxAge > time.Duration(0) { | ||
| 122 | headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) | ||
| 123 | } | ||
| 124 | return | ||
| 125 | } | ||
| 126 | |||
| 127 | // PreflightHeader converts options into CORS headers for a preflight response. | ||
| 128 | func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { | ||
| 129 | headers = make(map[string]string) | ||
| 130 | if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { | ||
| 131 | return | ||
| 132 | } | ||
| 133 | // verify if requested method is allowed | ||
| 134 | for _, method := range o.AllowMethods { | ||
| 135 | if method == rMethod { | ||
| 136 | headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") | ||
| 137 | break | ||
| 138 | } | ||
| 139 | } | ||
| 140 | |||
| 141 | // verify if requested headers are allowed | ||
| 142 | var allowed []string | ||
| 143 | for _, rHeader := range strings.Split(rHeaders, ",") { | ||
| 144 | rHeader = strings.TrimSpace(rHeader) | ||
| 145 | lookupLoop: | ||
| 146 | for _, allowedHeader := range o.AllowHeaders { | ||
| 147 | if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { | ||
| 148 | allowed = append(allowed, rHeader) | ||
| 149 | break lookupLoop | ||
| 150 | } | ||
| 151 | } | ||
| 152 | } | ||
| 153 | |||
| 154 | headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) | ||
| 155 | // add allow origin | ||
| 156 | if o.AllowAllOrigins { | ||
| 157 | headers[headerAllowOrigin] = "*" | ||
| 158 | } else { | ||
| 159 | headers[headerAllowOrigin] = origin | ||
| 160 | } | ||
| 161 | |||
| 162 | // add allowed headers | ||
| 163 | if len(allowed) > 0 { | ||
| 164 | headers[headerAllowHeaders] = strings.Join(allowed, ",") | ||
| 165 | } | ||
| 166 | |||
| 167 | // add exposed headers | ||
| 168 | if len(o.ExposeHeaders) > 0 { | ||
| 169 | headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") | ||
| 170 | } | ||
| 171 | // add a max age header | ||
| 172 | if o.MaxAge > time.Duration(0) { | ||
| 173 | headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) | ||
| 174 | } | ||
| 175 | return | ||
| 176 | } | ||
| 177 | |||
| 178 | // IsOriginAllowed looks up if the origin matches one of the patterns | ||
| 179 | // generated from Options.AllowOrigins patterns. | ||
| 180 | func (o *Options) IsOriginAllowed(origin string) (allowed bool) { | ||
| 181 | for _, pattern := range allowOriginPatterns { | ||
| 182 | allowed, _ = regexp.MatchString(pattern, origin) | ||
| 183 | if allowed { | ||
| 184 | return | ||
| 185 | } | ||
| 186 | } | ||
| 187 | return | ||
| 188 | } | ||
| 189 | |||
| 190 | // Allow enables CORS for requests those match the provided options. | ||
| 191 | func Allow(opts *Options) beego.FilterFunc { | ||
| 192 | // Allow default headers if nothing is specified. | ||
| 193 | if len(opts.AllowHeaders) == 0 { | ||
| 194 | opts.AllowHeaders = defaultAllowHeaders | ||
| 195 | } | ||
| 196 | |||
| 197 | for _, origin := range opts.AllowOrigins { | ||
| 198 | pattern := regexp.QuoteMeta(origin) | ||
| 199 | pattern = strings.Replace(pattern, "\\*", ".*", -1) | ||
| 200 | pattern = strings.Replace(pattern, "\\?", ".", -1) | ||
| 201 | allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") | ||
| 202 | } | ||
| 203 | |||
| 204 | return func(ctx *context.Context) { | ||
| 205 | var ( | ||
| 206 | origin = ctx.Input.Header(headerOrigin) | ||
| 207 | requestedMethod = ctx.Input.Header(headerRequestMethod) | ||
| 208 | requestedHeaders = ctx.Input.Header(headerRequestHeaders) | ||
| 209 | // additional headers to be added | ||
| 210 | // to the response. | ||
| 211 | headers map[string]string | ||
| 212 | ) | ||
| 213 | |||
| 214 | if ctx.Input.Method() == "OPTIONS" && | ||
| 215 | (requestedMethod != "" || requestedHeaders != "") { | ||
| 216 | headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) | ||
| 217 | for key, value := range headers { | ||
| 218 | ctx.Output.Header(key, value) | ||
| 219 | } | ||
| 220 | ctx.Output.SetStatus(http.StatusOK) | ||
| 221 | return | ||
| 222 | } | ||
| 223 | headers = opts.Header(origin) | ||
| 224 | |||
| 225 | for key, value := range headers { | ||
| 226 | ctx.Output.Header(key, value) | ||
| 227 | } | ||
| 228 | } | ||
| 229 | } |
plugins/cors/cors_test.go
0 → 100644
| 1 | // Copyright 2014 beego Author. All Rights Reserved. | ||
| 2 | // | ||
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 4 | // you may not use this file except in compliance with the License. | ||
| 5 | // You may obtain a copy of the License at | ||
| 6 | // | ||
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 | ||
| 8 | // | ||
| 9 | // Unless required by applicable law or agreed to in writing, software | ||
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, | ||
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 12 | // See the License for the specific language governing permissions and | ||
| 13 | // limitations under the License. | ||
| 14 | |||
| 15 | // Package cors provides handlers to enable CORS support. | ||
| 16 | package cors | ||
| 17 | |||
| 18 | import ( | ||
| 19 | "net/http" | ||
| 20 | "net/http/httptest" | ||
| 21 | "strings" | ||
| 22 | "testing" | ||
| 23 | "time" | ||
| 24 | |||
| 25 | "github.com/astaxie/beego" | ||
| 26 | "github.com/astaxie/beego/context" | ||
| 27 | ) | ||
| 28 | |||
| 29 | type HttpHeaderGuardRecorder struct { | ||
| 30 | *httptest.ResponseRecorder | ||
| 31 | savedHeaderMap http.Header | ||
| 32 | } | ||
| 33 | |||
| 34 | func NewRecorder() *HttpHeaderGuardRecorder { | ||
| 35 | return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil} | ||
| 36 | } | ||
| 37 | |||
| 38 | func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) { | ||
| 39 | gr.ResponseRecorder.WriteHeader(code) | ||
| 40 | gr.savedHeaderMap = gr.ResponseRecorder.Header() | ||
| 41 | } | ||
| 42 | |||
| 43 | func (gr *HttpHeaderGuardRecorder) Header() http.Header { | ||
| 44 | if gr.savedHeaderMap != nil { | ||
| 45 | // headers were written. clone so we don't get updates | ||
| 46 | clone := make(http.Header) | ||
| 47 | for k, v := range gr.savedHeaderMap { | ||
| 48 | clone[k] = v | ||
| 49 | } | ||
| 50 | return clone | ||
| 51 | } else { | ||
| 52 | return gr.ResponseRecorder.Header() | ||
| 53 | } | ||
| 54 | } | ||
| 55 | |||
| 56 | func Test_AllowAll(t *testing.T) { | ||
| 57 | recorder := httptest.NewRecorder() | ||
| 58 | handler := beego.NewControllerRegister() | ||
| 59 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 60 | AllowAllOrigins: true, | ||
| 61 | })) | ||
| 62 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 63 | ctx.Output.SetStatus(500) | ||
| 64 | }) | ||
| 65 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 66 | handler.ServeHTTP(recorder, r) | ||
| 67 | |||
| 68 | if recorder.HeaderMap.Get(headerAllowOrigin) != "*" { | ||
| 69 | t.Errorf("Allow-Origin header should be *") | ||
| 70 | } | ||
| 71 | } | ||
| 72 | |||
| 73 | func Test_AllowRegexMatch(t *testing.T) { | ||
| 74 | recorder := httptest.NewRecorder() | ||
| 75 | handler := beego.NewControllerRegister() | ||
| 76 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 77 | AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"}, | ||
| 78 | })) | ||
| 79 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 80 | ctx.Output.SetStatus(500) | ||
| 81 | }) | ||
| 82 | origin := "https://bar.foo.com" | ||
| 83 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 84 | r.Header.Add("Origin", origin) | ||
| 85 | handler.ServeHTTP(recorder, r) | ||
| 86 | |||
| 87 | headerValue := recorder.HeaderMap.Get(headerAllowOrigin) | ||
| 88 | if headerValue != origin { | ||
| 89 | t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue) | ||
| 90 | } | ||
| 91 | } | ||
| 92 | |||
| 93 | func Test_AllowRegexNoMatch(t *testing.T) { | ||
| 94 | recorder := httptest.NewRecorder() | ||
| 95 | handler := beego.NewControllerRegister() | ||
| 96 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 97 | AllowOrigins: []string{"https://*.foo.com"}, | ||
| 98 | })) | ||
| 99 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 100 | ctx.Output.SetStatus(500) | ||
| 101 | }) | ||
| 102 | origin := "https://ww.foo.com.evil.com" | ||
| 103 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 104 | r.Header.Add("Origin", origin) | ||
| 105 | handler.ServeHTTP(recorder, r) | ||
| 106 | |||
| 107 | headerValue := recorder.HeaderMap.Get(headerAllowOrigin) | ||
| 108 | if headerValue != "" { | ||
| 109 | t.Errorf("Allow-Origin header should not exist, found %v", headerValue) | ||
| 110 | } | ||
| 111 | } | ||
| 112 | |||
| 113 | func Test_OtherHeaders(t *testing.T) { | ||
| 114 | recorder := httptest.NewRecorder() | ||
| 115 | handler := beego.NewControllerRegister() | ||
| 116 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 117 | AllowAllOrigins: true, | ||
| 118 | AllowCredentials: true, | ||
| 119 | AllowMethods: []string{"PATCH", "GET"}, | ||
| 120 | AllowHeaders: []string{"Origin", "X-whatever"}, | ||
| 121 | ExposeHeaders: []string{"Content-Length", "Hello"}, | ||
| 122 | MaxAge: 5 * time.Minute, | ||
| 123 | })) | ||
| 124 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 125 | ctx.Output.SetStatus(500) | ||
| 126 | }) | ||
| 127 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 128 | handler.ServeHTTP(recorder, r) | ||
| 129 | |||
| 130 | credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials) | ||
| 131 | methodsVal := recorder.HeaderMap.Get(headerAllowMethods) | ||
| 132 | headersVal := recorder.HeaderMap.Get(headerAllowHeaders) | ||
| 133 | exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders) | ||
| 134 | maxAgeVal := recorder.HeaderMap.Get(headerMaxAge) | ||
| 135 | |||
| 136 | if credentialsVal != "true" { | ||
| 137 | t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal) | ||
| 138 | } | ||
| 139 | |||
| 140 | if methodsVal != "PATCH,GET" { | ||
| 141 | t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal) | ||
| 142 | } | ||
| 143 | |||
| 144 | if headersVal != "Origin,X-whatever" { | ||
| 145 | t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal) | ||
| 146 | } | ||
| 147 | |||
| 148 | if exposedHeadersVal != "Content-Length,Hello" { | ||
| 149 | t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal) | ||
| 150 | } | ||
| 151 | |||
| 152 | if maxAgeVal != "300" { | ||
| 153 | t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal) | ||
| 154 | } | ||
| 155 | } | ||
| 156 | |||
| 157 | func Test_DefaultAllowHeaders(t *testing.T) { | ||
| 158 | recorder := httptest.NewRecorder() | ||
| 159 | handler := beego.NewControllerRegister() | ||
| 160 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 161 | AllowAllOrigins: true, | ||
| 162 | })) | ||
| 163 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 164 | ctx.Output.SetStatus(500) | ||
| 165 | }) | ||
| 166 | |||
| 167 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 168 | handler.ServeHTTP(recorder, r) | ||
| 169 | |||
| 170 | headersVal := recorder.HeaderMap.Get(headerAllowHeaders) | ||
| 171 | if headersVal != "Origin,Accept,Content-Type,Authorization" { | ||
| 172 | t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal) | ||
| 173 | } | ||
| 174 | } | ||
| 175 | |||
| 176 | func Test_Preflight(t *testing.T) { | ||
| 177 | recorder := NewRecorder() | ||
| 178 | handler := beego.NewControllerRegister() | ||
| 179 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 180 | AllowAllOrigins: true, | ||
| 181 | AllowMethods: []string{"PUT", "PATCH"}, | ||
| 182 | AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"}, | ||
| 183 | })) | ||
| 184 | |||
| 185 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 186 | ctx.Output.SetStatus(200) | ||
| 187 | }) | ||
| 188 | |||
| 189 | r, _ := http.NewRequest("OPTIONS", "/foo", nil) | ||
| 190 | r.Header.Add(headerRequestMethod, "PUT") | ||
| 191 | r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive") | ||
| 192 | handler.ServeHTTP(recorder, r) | ||
| 193 | |||
| 194 | headers := recorder.Header() | ||
| 195 | methodsVal := headers.Get(headerAllowMethods) | ||
| 196 | headersVal := headers.Get(headerAllowHeaders) | ||
| 197 | originVal := headers.Get(headerAllowOrigin) | ||
| 198 | |||
| 199 | if methodsVal != "PUT,PATCH" { | ||
| 200 | t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal) | ||
| 201 | } | ||
| 202 | |||
| 203 | if !strings.Contains(headersVal, "X-whatever") { | ||
| 204 | t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal) | ||
| 205 | } | ||
| 206 | |||
| 207 | if !strings.Contains(headersVal, "x-casesensitive") { | ||
| 208 | t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal) | ||
| 209 | } | ||
| 210 | |||
| 211 | if originVal != "*" { | ||
| 212 | t.Errorf("Allow-Origin is expected to be *, found %v", originVal) | ||
| 213 | } | ||
| 214 | |||
| 215 | if recorder.Code != http.StatusOK { | ||
| 216 | t.Errorf("Status code is expected to be 200, found %d", recorder.Code) | ||
| 217 | } | ||
| 218 | } | ||
| 219 | |||
| 220 | func Benchmark_WithoutCORS(b *testing.B) { | ||
| 221 | recorder := httptest.NewRecorder() | ||
| 222 | handler := beego.NewControllerRegister() | ||
| 223 | beego.RunMode = "prod" | ||
| 224 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 225 | ctx.Output.SetStatus(500) | ||
| 226 | }) | ||
| 227 | b.ResetTimer() | ||
| 228 | for i := 0; i < 100; i++ { | ||
| 229 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 230 | handler.ServeHTTP(recorder, r) | ||
| 231 | } | ||
| 232 | } | ||
| 233 | |||
| 234 | func Benchmark_WithCORS(b *testing.B) { | ||
| 235 | recorder := httptest.NewRecorder() | ||
| 236 | handler := beego.NewControllerRegister() | ||
| 237 | beego.RunMode = "prod" | ||
| 238 | handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{ | ||
| 239 | AllowAllOrigins: true, | ||
| 240 | AllowCredentials: true, | ||
| 241 | AllowMethods: []string{"PATCH", "GET"}, | ||
| 242 | AllowHeaders: []string{"Origin", "X-whatever"}, | ||
| 243 | MaxAge: 5 * time.Minute, | ||
| 244 | })) | ||
| 245 | handler.Any("/foo", func(ctx *context.Context) { | ||
| 246 | ctx.Output.SetStatus(500) | ||
| 247 | }) | ||
| 248 | b.ResetTimer() | ||
| 249 | for i := 0; i < 100; i++ { | ||
| 250 | r, _ := http.NewRequest("PUT", "/foo", nil) | ||
| 251 | handler.ServeHTTP(recorder, r) | ||
| 252 | } | ||
| 253 | } |
-
Please register or sign in to post a comment