cd9e614a by astaxie

plugins: basic auth & cors

1 parent c1234e7c
1 // Beego (http://beego.me/) 1 // Copyright 2014 beego Author. All Rights Reserved.
2 // 2 //
3 // @description beego is an open-source, high-performance web framework for the Go programming language. 3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
4 // 6 //
5 // @link http://github.com/astaxie/beego for the canonical source repository 7 // http://www.apache.org/licenses/LICENSE-2.0
6 // 8 //
7 // @license http://github.com/astaxie/beego/blob/master/LICENSE 9 // Unless required by applicable law or agreed to in writing, software
8 // 10 // distributed under the License is distributed on an "AS IS" BASIS,
9 // @authors astaxie 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 package auth 12 // See the License for the specific language governing permissions and
13 // limitations under the License.
11 14
12 // Example: 15 // Package auth provides handlers to enable basic auth support.
13 // func SecretAuth(username, password string) bool { 16 // Simple Usage:
14 // if username == "astaxie" && password == "helloBeego" { 17 // import(
15 // return true 18 // "github.com/astaxie/beego"
19 // "github.com/astaxie/beego/plugins/auth"
20 // )
21 //
22 // func main(){
23 // // authenticate every request
24 // beego.InsertFilter("*", beego.BeforeRouter,auth.Basic("username","secretpassword"))
25 // beego.Run()
16 // } 26 // }
17 // return false 27 //
28 //
29 // Advanced Usage:
30 // func SecretAuth(username, password string) bool {
31 // return username == "astaxie" && password == "helloBeego"
18 // } 32 // }
19 // authPlugin := auth.NewBasicAuthenticator(SecretAuth, "My Realm") 33 // authPlugin := auth.NewBasicAuthenticator(SecretAuth, "Authorization Required")
20 // beego.InsertFilter("*", beego.BeforeRouter,authPlugin) 34 // beego.InsertFilter("*", beego.BeforeRouter,authPlugin)
35 package auth
21 36
22 import ( 37 import (
23 "encoding/base64" 38 "encoding/base64"
...@@ -28,6 +43,15 @@ import ( ...@@ -28,6 +43,15 @@ import (
28 "github.com/astaxie/beego/context" 43 "github.com/astaxie/beego/context"
29 ) 44 )
30 45
46 var defaultRealm = "Authorization Required"
47
48 func Basic(username string, password string) beego.FilterFunc {
49 secrets := func(user, pass string) bool {
50 return user == username && pass == password
51 }
52 return NewBasicAuthenticator(secrets, defaultRealm)
53 }
54
31 func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc { 55 func NewBasicAuthenticator(secrets SecretProvider, Realm string) beego.FilterFunc {
32 return func(ctx *context.Context) { 56 return func(ctx *context.Context) {
33 a := &BasicAuth{Secrets: secrets, Realm: Realm} 57 a := &BasicAuth{Secrets: secrets, Realm: Realm}
...@@ -44,13 +68,10 @@ type BasicAuth struct { ...@@ -44,13 +68,10 @@ type BasicAuth struct {
44 Realm string 68 Realm string
45 } 69 }
46 70
47 /* 71 //Checks the username/password combination from the request. Returns
48 Checks the username/password combination from the request. Returns 72 //either an empty string (authentication failed) or the name of the
49 either an empty string (authentication failed) or the name of the 73 //authenticated user.
50 authenticated user. 74 //Supports MD5 and SHA1 password entries
51
52 Supports MD5 and SHA1 password entries
53 */
54 func (a *BasicAuth) CheckAuth(r *http.Request) string { 75 func (a *BasicAuth) CheckAuth(r *http.Request) string {
55 s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) 76 s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
56 if len(s) != 2 || s[0] != "Basic" { 77 if len(s) != 2 || s[0] != "Basic" {
...@@ -72,10 +93,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string { ...@@ -72,10 +93,8 @@ func (a *BasicAuth) CheckAuth(r *http.Request) string {
72 return "" 93 return ""
73 } 94 }
74 95
75 /* 96 //http.Handler for BasicAuth which initiates the authentication process
76 http.Handler for BasicAuth which initiates the authentication process 97 //(or requires reauthentication).
77 (or requires reauthentication).
78 */
79 func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) { 98 func (a *BasicAuth) RequireAuth(w http.ResponseWriter, r *http.Request) {
80 w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`) 99 w.Header().Set("WWW-Authenticate", `Basic realm="`+a.Realm+`"`)
81 w.WriteHeader(401) 100 w.WriteHeader(401)
......
1 // Copyright 2014 beego Author. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Package cors provides handlers to enable CORS support.
16 // Usage
17 //
18 // import (
19 // "github.com/astaxie/beego"
20 // "github.com/astaxie/beego/plugins/cors"
21 // )
22
23 //func main() {
24 // // CORS for https://foo.* origins, allowing:
25 // // - PUT and PATCH methods
26 // // - Origin header
27 // // - Credentials share
28 // beego.InsertFilter("*", beego.BeforeRouter,cors.Allow(&cors.Options{
29 // AllowOrigins: []string{"https://*.foo.com"},
30 // AllowMethods: []string{"PUT", "PATCH"},
31 // AllowHeaders: []string{"Origin"},
32 // ExposeHeaders: []string{"Content-Length"},
33 // AllowCredentials: true,
34 // }))
35 // beego.Run()
36 //}
37 package cors
38
39 import (
40 "net/http"
41 "regexp"
42 "strconv"
43 "strings"
44 "time"
45
46 "github.com/astaxie/beego"
47 "github.com/astaxie/beego/context"
48 )
49
50 const (
51 headerAllowOrigin = "Access-Control-Allow-Origin"
52 headerAllowCredentials = "Access-Control-Allow-Credentials"
53 headerAllowHeaders = "Access-Control-Allow-Headers"
54 headerAllowMethods = "Access-Control-Allow-Methods"
55 headerExposeHeaders = "Access-Control-Expose-Headers"
56 headerMaxAge = "Access-Control-Max-Age"
57
58 headerOrigin = "Origin"
59 headerRequestMethod = "Access-Control-Request-Method"
60 headerRequestHeaders = "Access-Control-Request-Headers"
61 )
62
63 var (
64 defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}
65 // Regex patterns are generated from AllowOrigins. These are used and generated internally.
66 allowOriginPatterns = []string{}
67 )
68
69 // Options represents Access Control options.
70 type Options struct {
71 // If set, all origins are allowed.
72 AllowAllOrigins bool
73 // A list of allowed origins. Wild cards and FQDNs are supported.
74 AllowOrigins []string
75 // If set, allows to share auth credentials such as cookies.
76 AllowCredentials bool
77 // A list of allowed HTTP methods.
78 AllowMethods []string
79 // A list of allowed HTTP headers.
80 AllowHeaders []string
81 // A list of exposed HTTP headers.
82 ExposeHeaders []string
83 // Max age of the CORS headers.
84 MaxAge time.Duration
85 }
86
87 // Header converts options into CORS headers.
88 func (o *Options) Header(origin string) (headers map[string]string) {
89 headers = make(map[string]string)
90 // if origin is not allowed, don't extend the headers
91 // with CORS headers.
92 if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
93 return
94 }
95
96 // add allow origin
97 if o.AllowAllOrigins {
98 headers[headerAllowOrigin] = "*"
99 } else {
100 headers[headerAllowOrigin] = origin
101 }
102
103 // add allow credentials
104 headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
105
106 // add allow methods
107 if len(o.AllowMethods) > 0 {
108 headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
109 }
110
111 // add allow headers
112 if len(o.AllowHeaders) > 0 {
113 headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
114 }
115
116 // add exposed header
117 if len(o.ExposeHeaders) > 0 {
118 headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
119 }
120 // add a max age header
121 if o.MaxAge > time.Duration(0) {
122 headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
123 }
124 return
125 }
126
127 // PreflightHeader converts options into CORS headers for a preflight response.
128 func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
129 headers = make(map[string]string)
130 if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
131 return
132 }
133 // verify if requested method is allowed
134 for _, method := range o.AllowMethods {
135 if method == rMethod {
136 headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
137 break
138 }
139 }
140
141 // verify if requested headers are allowed
142 var allowed []string
143 for _, rHeader := range strings.Split(rHeaders, ",") {
144 rHeader = strings.TrimSpace(rHeader)
145 lookupLoop:
146 for _, allowedHeader := range o.AllowHeaders {
147 if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
148 allowed = append(allowed, rHeader)
149 break lookupLoop
150 }
151 }
152 }
153
154 headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
155 // add allow origin
156 if o.AllowAllOrigins {
157 headers[headerAllowOrigin] = "*"
158 } else {
159 headers[headerAllowOrigin] = origin
160 }
161
162 // add allowed headers
163 if len(allowed) > 0 {
164 headers[headerAllowHeaders] = strings.Join(allowed, ",")
165 }
166
167 // add exposed headers
168 if len(o.ExposeHeaders) > 0 {
169 headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
170 }
171 // add a max age header
172 if o.MaxAge > time.Duration(0) {
173 headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
174 }
175 return
176 }
177
178 // IsOriginAllowed looks up if the origin matches one of the patterns
179 // generated from Options.AllowOrigins patterns.
180 func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
181 for _, pattern := range allowOriginPatterns {
182 allowed, _ = regexp.MatchString(pattern, origin)
183 if allowed {
184 return
185 }
186 }
187 return
188 }
189
190 // Allow enables CORS for requests those match the provided options.
191 func Allow(opts *Options) beego.FilterFunc {
192 // Allow default headers if nothing is specified.
193 if len(opts.AllowHeaders) == 0 {
194 opts.AllowHeaders = defaultAllowHeaders
195 }
196
197 for _, origin := range opts.AllowOrigins {
198 pattern := regexp.QuoteMeta(origin)
199 pattern = strings.Replace(pattern, "\\*", ".*", -1)
200 pattern = strings.Replace(pattern, "\\?", ".", -1)
201 allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
202 }
203
204 return func(ctx *context.Context) {
205 var (
206 origin = ctx.Input.Header(headerOrigin)
207 requestedMethod = ctx.Input.Header(headerRequestMethod)
208 requestedHeaders = ctx.Input.Header(headerRequestHeaders)
209 // additional headers to be added
210 // to the response.
211 headers map[string]string
212 )
213
214 if ctx.Input.Method() == "OPTIONS" &&
215 (requestedMethod != "" || requestedHeaders != "") {
216 headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
217 for key, value := range headers {
218 ctx.Output.Header(key, value)
219 }
220 ctx.Output.SetStatus(http.StatusOK)
221 return
222 }
223 headers = opts.Header(origin)
224
225 for key, value := range headers {
226 ctx.Output.Header(key, value)
227 }
228 }
229 }
1 // Copyright 2014 beego Author. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Package cors provides handlers to enable CORS support.
16 package cors
17
18 import (
19 "net/http"
20 "net/http/httptest"
21 "strings"
22 "testing"
23 "time"
24
25 "github.com/astaxie/beego"
26 "github.com/astaxie/beego/context"
27 )
28
29 type HttpHeaderGuardRecorder struct {
30 *httptest.ResponseRecorder
31 savedHeaderMap http.Header
32 }
33
34 func NewRecorder() *HttpHeaderGuardRecorder {
35 return &HttpHeaderGuardRecorder{httptest.NewRecorder(), nil}
36 }
37
38 func (gr *HttpHeaderGuardRecorder) WriteHeader(code int) {
39 gr.ResponseRecorder.WriteHeader(code)
40 gr.savedHeaderMap = gr.ResponseRecorder.Header()
41 }
42
43 func (gr *HttpHeaderGuardRecorder) Header() http.Header {
44 if gr.savedHeaderMap != nil {
45 // headers were written. clone so we don't get updates
46 clone := make(http.Header)
47 for k, v := range gr.savedHeaderMap {
48 clone[k] = v
49 }
50 return clone
51 } else {
52 return gr.ResponseRecorder.Header()
53 }
54 }
55
56 func Test_AllowAll(t *testing.T) {
57 recorder := httptest.NewRecorder()
58 handler := beego.NewControllerRegister()
59 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
60 AllowAllOrigins: true,
61 }))
62 handler.Any("/foo", func(ctx *context.Context) {
63 ctx.Output.SetStatus(500)
64 })
65 r, _ := http.NewRequest("PUT", "/foo", nil)
66 handler.ServeHTTP(recorder, r)
67
68 if recorder.HeaderMap.Get(headerAllowOrigin) != "*" {
69 t.Errorf("Allow-Origin header should be *")
70 }
71 }
72
73 func Test_AllowRegexMatch(t *testing.T) {
74 recorder := httptest.NewRecorder()
75 handler := beego.NewControllerRegister()
76 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
77 AllowOrigins: []string{"https://aaa.com", "https://*.foo.com"},
78 }))
79 handler.Any("/foo", func(ctx *context.Context) {
80 ctx.Output.SetStatus(500)
81 })
82 origin := "https://bar.foo.com"
83 r, _ := http.NewRequest("PUT", "/foo", nil)
84 r.Header.Add("Origin", origin)
85 handler.ServeHTTP(recorder, r)
86
87 headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
88 if headerValue != origin {
89 t.Errorf("Allow-Origin header should be %v, found %v", origin, headerValue)
90 }
91 }
92
93 func Test_AllowRegexNoMatch(t *testing.T) {
94 recorder := httptest.NewRecorder()
95 handler := beego.NewControllerRegister()
96 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
97 AllowOrigins: []string{"https://*.foo.com"},
98 }))
99 handler.Any("/foo", func(ctx *context.Context) {
100 ctx.Output.SetStatus(500)
101 })
102 origin := "https://ww.foo.com.evil.com"
103 r, _ := http.NewRequest("PUT", "/foo", nil)
104 r.Header.Add("Origin", origin)
105 handler.ServeHTTP(recorder, r)
106
107 headerValue := recorder.HeaderMap.Get(headerAllowOrigin)
108 if headerValue != "" {
109 t.Errorf("Allow-Origin header should not exist, found %v", headerValue)
110 }
111 }
112
113 func Test_OtherHeaders(t *testing.T) {
114 recorder := httptest.NewRecorder()
115 handler := beego.NewControllerRegister()
116 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
117 AllowAllOrigins: true,
118 AllowCredentials: true,
119 AllowMethods: []string{"PATCH", "GET"},
120 AllowHeaders: []string{"Origin", "X-whatever"},
121 ExposeHeaders: []string{"Content-Length", "Hello"},
122 MaxAge: 5 * time.Minute,
123 }))
124 handler.Any("/foo", func(ctx *context.Context) {
125 ctx.Output.SetStatus(500)
126 })
127 r, _ := http.NewRequest("PUT", "/foo", nil)
128 handler.ServeHTTP(recorder, r)
129
130 credentialsVal := recorder.HeaderMap.Get(headerAllowCredentials)
131 methodsVal := recorder.HeaderMap.Get(headerAllowMethods)
132 headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
133 exposedHeadersVal := recorder.HeaderMap.Get(headerExposeHeaders)
134 maxAgeVal := recorder.HeaderMap.Get(headerMaxAge)
135
136 if credentialsVal != "true" {
137 t.Errorf("Allow-Credentials is expected to be true, found %v", credentialsVal)
138 }
139
140 if methodsVal != "PATCH,GET" {
141 t.Errorf("Allow-Methods is expected to be PATCH,GET; found %v", methodsVal)
142 }
143
144 if headersVal != "Origin,X-whatever" {
145 t.Errorf("Allow-Headers is expected to be Origin,X-whatever; found %v", headersVal)
146 }
147
148 if exposedHeadersVal != "Content-Length,Hello" {
149 t.Errorf("Expose-Headers are expected to be Content-Length,Hello. Found %v", exposedHeadersVal)
150 }
151
152 if maxAgeVal != "300" {
153 t.Errorf("Max-Age is expected to be 300, found %v", maxAgeVal)
154 }
155 }
156
157 func Test_DefaultAllowHeaders(t *testing.T) {
158 recorder := httptest.NewRecorder()
159 handler := beego.NewControllerRegister()
160 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
161 AllowAllOrigins: true,
162 }))
163 handler.Any("/foo", func(ctx *context.Context) {
164 ctx.Output.SetStatus(500)
165 })
166
167 r, _ := http.NewRequest("PUT", "/foo", nil)
168 handler.ServeHTTP(recorder, r)
169
170 headersVal := recorder.HeaderMap.Get(headerAllowHeaders)
171 if headersVal != "Origin,Accept,Content-Type,Authorization" {
172 t.Errorf("Allow-Headers is expected to be Origin,Accept,Content-Type,Authorization; found %v", headersVal)
173 }
174 }
175
176 func Test_Preflight(t *testing.T) {
177 recorder := NewRecorder()
178 handler := beego.NewControllerRegister()
179 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
180 AllowAllOrigins: true,
181 AllowMethods: []string{"PUT", "PATCH"},
182 AllowHeaders: []string{"Origin", "X-whatever", "X-CaseSensitive"},
183 }))
184
185 handler.Any("/foo", func(ctx *context.Context) {
186 ctx.Output.SetStatus(200)
187 })
188
189 r, _ := http.NewRequest("OPTIONS", "/foo", nil)
190 r.Header.Add(headerRequestMethod, "PUT")
191 r.Header.Add(headerRequestHeaders, "X-whatever, x-casesensitive")
192 handler.ServeHTTP(recorder, r)
193
194 headers := recorder.Header()
195 methodsVal := headers.Get(headerAllowMethods)
196 headersVal := headers.Get(headerAllowHeaders)
197 originVal := headers.Get(headerAllowOrigin)
198
199 if methodsVal != "PUT,PATCH" {
200 t.Errorf("Allow-Methods is expected to be PUT,PATCH, found %v", methodsVal)
201 }
202
203 if !strings.Contains(headersVal, "X-whatever") {
204 t.Errorf("Allow-Headers is expected to contain X-whatever, found %v", headersVal)
205 }
206
207 if !strings.Contains(headersVal, "x-casesensitive") {
208 t.Errorf("Allow-Headers is expected to contain x-casesensitive, found %v", headersVal)
209 }
210
211 if originVal != "*" {
212 t.Errorf("Allow-Origin is expected to be *, found %v", originVal)
213 }
214
215 if recorder.Code != http.StatusOK {
216 t.Errorf("Status code is expected to be 200, found %d", recorder.Code)
217 }
218 }
219
220 func Benchmark_WithoutCORS(b *testing.B) {
221 recorder := httptest.NewRecorder()
222 handler := beego.NewControllerRegister()
223 beego.RunMode = "prod"
224 handler.Any("/foo", func(ctx *context.Context) {
225 ctx.Output.SetStatus(500)
226 })
227 b.ResetTimer()
228 for i := 0; i < 100; i++ {
229 r, _ := http.NewRequest("PUT", "/foo", nil)
230 handler.ServeHTTP(recorder, r)
231 }
232 }
233
234 func Benchmark_WithCORS(b *testing.B) {
235 recorder := httptest.NewRecorder()
236 handler := beego.NewControllerRegister()
237 beego.RunMode = "prod"
238 handler.InsertFilter("*", beego.BeforeRouter, Allow(&Options{
239 AllowAllOrigins: true,
240 AllowCredentials: true,
241 AllowMethods: []string{"PATCH", "GET"},
242 AllowHeaders: []string{"Origin", "X-whatever"},
243 MaxAge: 5 * time.Minute,
244 }))
245 handler.Any("/foo", func(ctx *context.Context) {
246 ctx.Output.SetStatus(500)
247 })
248 b.ResetTimer()
249 for i := 0; i < 100; i++ {
250 r, _ := http.NewRequest("PUT", "/foo", nil)
251 handler.ServeHTTP(recorder, r)
252 }
253 }
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!