SimpleTutorialHosting/internal/auth/samlAuth.go
Mitchell Thompson 56635fc145 first commit
2024-10-27 07:12:56 -04:00

126 lines
2.8 KiB
Go

package auth
import (
"SimpleTutorialHosting/internal/models"
"context"
"crypto/rsa"
"errors"
"fmt"
"github.com/crewjam/saml/samlsp"
"net/http"
"net/url"
)
// SAMLAuthenticator authenticates HTTP requests against a SAML identity
// provider via a crewjam/saml samlsp middleware and maps the user's group
// membership onto application roles.
type SAMLAuthenticator struct {
	// middleware implements the SAML service-provider endpoints and the
	// session checks used by Authenticate.
	middleware *samlsp.Middleware

	// adminGroup and viewerGroup are the group names that grant the admin
	// and viewer roles respectively.
	adminGroup, viewerGroup string
}
// NewSAMLAuthenticator builds a SAMLAuthenticator from cfg.
//
// cfg.RootURL must be an absolute URL; the service-provider endpoints are
// derived from it. cfg.IDPMetadata is parsed as the identity provider's
// metadata document. cfg.KeyPair must hold an RSA private key and a parsed
// leaf certificate, which are used to sign authentication requests.
// IdP-initiated logins are accepted.
//
// It returns an error if cfg is nil, if the root URL or IdP metadata is
// invalid, if the key pair is unusable, or if samlsp rejects the options.
func NewSAMLAuthenticator(cfg *models.SAMLConfig) (*SAMLAuthenticator, error) {
	if cfg == nil {
		return nil, errors.New("saml config is nil")
	}

	rootURL, err := url.Parse(cfg.RootURL)
	if err != nil {
		return nil, fmt.Errorf("invalid root url: %w", err)
	}
	// url.Parse accepts "" and relative references; neither can produce
	// metadata/ACS URLs an identity provider could reach.
	if !rootURL.IsAbs() || rootURL.Host == "" {
		return nil, fmt.Errorf("invalid root url %q: must be an absolute url", cfg.RootURL)
	}

	idpMetadata, err := samlsp.ParseMetadata(cfg.IDPMetadata)
	if err != nil {
		return nil, fmt.Errorf("invalid idp metadata: %w", err)
	}

	// The options need an *rsa.PrivateKey. A bare type assertion would panic
	// on an ECDSA/Ed25519 key or a missing key, so report it as an error.
	key, ok := cfg.KeyPair.PrivateKey.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("invalid saml key pair: private key must be *rsa.PrivateKey, got %T", cfg.KeyPair.PrivateKey)
	}
	// Requests are signed (SignRequest below), which needs the certificate.
	// NOTE(review): if KeyPair is a tls.Certificate, Leaf is left nil by
	// some loaders (e.g. tls.X509KeyPair before Go 1.23) — confirm how it is built.
	if cfg.KeyPair.Leaf == nil {
		return nil, errors.New("invalid saml key pair: leaf certificate is nil")
	}

	opts := samlsp.Options{
		URL:               *rootURL,
		Key:               key,
		Certificate:       cfg.KeyPair.Leaf,
		IDPMetadata:       idpMetadata,
		SignRequest:       true,
		AllowIDPInitiated: true,
	}
	middleware, err := samlsp.New(opts)
	if err != nil {
		return nil, fmt.Errorf("creating saml middleware: %w", err)
	}

	return &SAMLAuthenticator{
		middleware:  middleware,
		adminGroup:  cfg.AdminGroup,
		viewerGroup: cfg.ViewerGroup,
	}, nil
}
// GetMiddleware returns a router serving the SAML service-provider endpoints:
//
//   - /saml/metadata: this service provider's metadata document.
//   - /saml/acs: the assertion consumer service the IdP posts responses to.
//   - /saml/sso: a login entry point. Without a session it starts the SAML
//     flow; once a session exists it redirects to "/".
//
// The router only handles these paths; it does not protect other routes
// (wrap those with Authenticate).
func (a *SAMLAuthenticator) GetMiddleware() http.Handler {
	mux := http.NewServeMux()

	mux.HandleFunc("/saml/metadata", a.middleware.ServeMetadata)
	mux.HandleFunc("/saml/acs", a.middleware.ServeACS)

	// samlsp.Middleware.ServeHTTP only dispatches the metadata and ACS paths
	// and answers 404 for anything else, so it cannot serve /saml/sso.
	// RequireAccount starts the IdP flow when there is no session. After the
	// ACS sends the user back here, a session exists, and the user is
	// redirected to the site root instead of restarting the flow.
	mux.Handle("/saml/sso", a.middleware.RequireAccount(http.RedirectHandler("/", http.StatusFound)))

	return mux
}
// Authenticate wraps next so that it only runs for requests with a SAML
// session. The samlsp RequireAccount wrapper handles requests without a
// session. If a session cannot be mapped to a user by GetUser, the request
// is answered with 401 Unauthorized. Otherwise the user is stored in the
// request context under UserContextKey before next is called.
func (a *SAMLAuthenticator) Authenticate(next http.Handler) http.Handler {
	withUser := func(w http.ResponseWriter, r *http.Request) {
		user, err := a.GetUser(r)
		if err != nil {
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), UserContextKey, user)))
	}
	return a.middleware.RequireAccount(http.HandlerFunc(withUser))
}
// GetUser maps the SAML session stored in r's context onto a models.User.
//
// The role is taken from the multi-valued "groups" attribute. Membership in
// the configured admin group yields models.RoleAdmin and takes precedence;
// otherwise membership in the viewer group yields models.RoleViewer. The
// "email" attribute becomes the user ID and "displayName" the username.
//
// It returns an error when the context has no session, when no group maps to
// a role, or when the email or displayName attribute is missing.
func (a *SAMLAuthenticator) GetUser(r *http.Request) (*models.User, error) {
	session := samlsp.SessionFromContext(r.Context())
	if session == nil {
		return nil, errors.New("no session in context")
	}

	// Read the attribute map directly: samlsp.AttributeFromContext returns
	// only the first value of an attribute, which would ignore every group
	// after the first. A session without attributes is treated as having none.
	var attrs samlsp.Attributes
	if withAttrs, ok := session.(samlsp.SessionWithAttributes); ok {
		attrs = withAttrs.GetAttributes()
	}

	var role models.Role
	for _, group := range attrs["groups"] {
		// Skip empty values so an unset (empty) group name in the config can
		// never match and silently grant that role.
		if group == "" {
			continue
		}
		if group == a.adminGroup {
			role = models.RoleAdmin
			break
		}
		if group == a.viewerGroup {
			role = models.RoleViewer
		}
	}
	if role == "" {
		return nil, errors.New("user has no valid role")
	}

	email := attrs.Get("email")
	if email == "" {
		return nil, errors.New("no email attribute")
	}
	displayName := attrs.Get("displayName")
	if displayName == "" {
		return nil, errors.New("no displayName attribute")
	}

	return &models.User{
		ID:       email,
		Username: displayName,
		Role:     role,
	}, nil
}