126 lines
2.8 KiB
Go
126 lines
2.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"SimpleTutorialHosting/internal/models"
|
|
"context"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/crewjam/saml/samlsp"
|
|
"net/http"
|
|
"net/url"
|
|
)
|
|
|
|
type SAMLAuthenticator struct {
|
|
middleware *samlsp.Middleware
|
|
adminGroup string
|
|
viewerGroup string
|
|
}
|
|
|
|
func NewSAMLAuthenticator(cfg *models.SAMLConfig) (*SAMLAuthenticator, error) {
|
|
if cfg == nil {
|
|
return nil, errors.New("saml config is nil")
|
|
}
|
|
|
|
rootURL, err := url.Parse(cfg.RootURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid root url: %w", err)
|
|
}
|
|
|
|
idpMetadata, err := samlsp.ParseMetadata(cfg.IDPMetadata)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid idp metadata: %w", err)
|
|
}
|
|
|
|
opts := samlsp.Options{
|
|
URL: *rootURL,
|
|
Key: cfg.KeyPair.PrivateKey.(*rsa.PrivateKey),
|
|
Certificate: cfg.KeyPair.Leaf,
|
|
IDPMetadata: idpMetadata,
|
|
SignRequest: true,
|
|
AllowIDPInitiated: true,
|
|
}
|
|
|
|
middleware, err := samlsp.New(opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &SAMLAuthenticator{
|
|
middleware: middleware,
|
|
adminGroup: cfg.AdminGroup,
|
|
viewerGroup: cfg.ViewerGroup,
|
|
}, nil
|
|
}
|
|
|
|
func (a *SAMLAuthenticator) GetMiddleware() http.Handler {
|
|
mux := http.NewServeMux()
|
|
|
|
// Handle SAML routes
|
|
mux.HandleFunc("/saml/metadata", func(w http.ResponseWriter, r *http.Request) {
|
|
a.middleware.ServeMetadata(w, r)
|
|
})
|
|
|
|
mux.HandleFunc("/saml/acs", func(w http.ResponseWriter, r *http.Request) {
|
|
a.middleware.ServeACS(w, r)
|
|
})
|
|
|
|
mux.HandleFunc("/saml/sso", func(w http.ResponseWriter, r *http.Request) {
|
|
a.middleware.ServeHTTP(w, r)
|
|
})
|
|
|
|
return mux
|
|
}
|
|
|
|
func (a *SAMLAuthenticator) Authenticate(next http.Handler) http.Handler {
|
|
return a.middleware.RequireAccount(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user, err := a.GetUser(r)
|
|
if err != nil {
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), UserContextKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}),
|
|
)
|
|
}
|
|
|
|
func (a *SAMLAuthenticator) GetUser(r *http.Request) (*models.User, error) {
|
|
session := samlsp.SessionFromContext(r.Context())
|
|
if session == nil {
|
|
return nil, errors.New("no session in context")
|
|
}
|
|
|
|
groups := samlsp.AttributeFromContext(r.Context(), "groups")
|
|
|
|
var role models.Role
|
|
switch groups {
|
|
case a.adminGroup:
|
|
role = models.RoleAdmin
|
|
case a.viewerGroup:
|
|
role = models.RoleViewer
|
|
}
|
|
|
|
if role == "" {
|
|
return nil, errors.New("user has no valid role")
|
|
}
|
|
|
|
email := samlsp.AttributeFromContext(r.Context(), "email")
|
|
if email == "" {
|
|
return nil, errors.New("no email attribute")
|
|
}
|
|
|
|
displayName := samlsp.AttributeFromContext(r.Context(), "displayName")
|
|
if displayName == "" {
|
|
return nil, errors.New("no displayName attribute")
|
|
}
|
|
|
|
return &models.User{
|
|
ID: email,
|
|
Username: displayName,
|
|
Role: role,
|
|
}, nil
|
|
}
|