diff options
| author | Alexander Kavon <me+git@alexkavon.com> | 2024-01-22 17:56:22 -0500 |
|---|---|---|
| committer | Alexander Kavon <me+git@alexkavon.com> | 2024-01-22 17:56:22 -0500 |
| commit | 66d84b2b49f55e6c652816466a5c3b4202234134 (patch) | |
| tree | 710c7121d361e3232901d8b9538034e1409b7ec4 | |
| parent | 28e420d7f8ce5dc622c5e0e28a10654f48fcbca4 (diff) | |
added login authentication route logic, added compareSecretToHash, decodeHash, hashArgon2, renamed HashSecret to hashSecret
| -rw-r--r-- | src/user/hooks.go | 8 | ||||
| -rw-r--r-- | src/user/routes.go | 16 | ||||
| -rw-r--r-- | src/user/secret.go | 122 |
3 files changed, 121 insertions, 25 deletions
diff --git a/src/user/hooks.go b/src/user/hooks.go index 1552760..61e24bc 100644 --- a/src/user/hooks.go +++ b/src/user/hooks.go @@ -4,23 +4,21 @@ import ( "context" validation "github.com/go-ozzo/ozzo-validation/v4" - "github.com/go-ozzo/ozzo-validation/v4/is" "github.com/volatiletech/sqlboiler/v4/boil" "gitlab.com/alexkavon/newsstand/src/models" ) func init() { - models.AddUserHook(boil.BeforeInsertHook, validate) + models.AddUserHook(boil.BeforeInsertHook, validateNew) // should always be last models.AddUserHook(boil.BeforeInsertHook, hashSecretBeforeInsert) } -func validate(ctx context.Context, exec boil.ContextExecutor, u *models.User) error { +func validateNew(ctx context.Context, exec boil.ContextExecutor, u *models.User) error { // validate user err := validation.ValidateStruct(u, validation.Field(&u.Username, validation.Required, validation.Length(3, 50)), validation.Field(&u.Secret, validation.Required, validation.Length(8, 128)), - validation.Field(&u.Email, validation.Required, is.Email), ) if err != nil { return err @@ -30,7 +28,7 @@ func validate(ctx context.Context, exec boil.ContextExecutor, u *models.User) er } func hashSecretBeforeInsert(ctx context.Context, exec boil.ContextExecutor, u *models.User) error { - hashed, err := HashSecret(u.Secret) + hashed, err := hashSecret(u.Secret) if err != nil { return err } diff --git a/src/user/routes.go b/src/user/routes.go index 7cbd3fb..5399418 100644 --- a/src/user/routes.go +++ b/src/user/routes.go @@ -92,10 +92,26 @@ func LoginForm(s *server.Server) http.HandlerFunc { func Login(s *server.Server) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() // look up the user from the db + user, err := models.Users(models.UserWhere.Username.EQ(r.PostFormValue("username"))).One(r.Context(), s.Db.ToSqlDb()) + if err != nil { + log.Fatal(err) + } + // hash the form secret // compare form hash to db hash + valid, err := compareSecretToHash(r.PostFormValue("secret"), user.Secret) + if err != nil { + log.Fatal(err) + } + if !valid { + log.Fatal("Incorrect login credentials TODO resolve with compareSecretToHash err") + } + // login or dont + sessions.NewSession(w, sessions.SessionValues{"uid": user.ID, "username": user.Username}) + http.Redirect(w, r, "/u/me", http.StatusSeeOther) } } diff --git a/src/user/secret.go b/src/user/secret.go index b6382fe..a777072 100644 --- a/src/user/secret.go +++ b/src/user/secret.go @@ -2,45 +2,127 @@ package user import ( "crypto/rand" + "crypto/subtle" "encoding/base64" + "errors" "fmt" + "strings" "golang.org/x/crypto/argon2" ) -func HashSecret(secret string) (string, error) { - hashconf := &struct { - memory uint32 - iterations uint32 - parallelism uint8 - keyLength uint32 - saltLength uint32 - }{64 * 1024, 3, 2, 12, 16} - salt := make([]byte, hashconf.saltLength) +var ( + errInvalidHash = errors.New("provided hash is wrong format") + errHashesNotEqual = errors.New("secret is not equal to encoded hash") + errIncorrectVersion = errors.New("incorrect version of Argon2") +) + +type hashconf struct { + memory uint32 + iterations uint32 + parallelism uint8 + keyLength uint32 + saltLength uint32 +} + +func hashSecret(secret string) (string, error) { + hc := &hashconf{ + memory: 64 * 1024, + iterations: 3, + parallelism: 2, + keyLength: 12, + saltLength: 16, + } + salt := make([]byte, hc.saltLength) _, err := rand.Read(salt) if err != nil { return "", err } - hash := argon2.IDKey( - []byte(secret), - salt, - hashconf.iterations, - hashconf.memory, - hashconf.parallelism, - hashconf.keyLength, - ) + hash := hashArgon2(secret, salt, hc) b64Salt := base64.RawStdEncoding.EncodeToString(salt) b64Hash := base64.RawStdEncoding.EncodeToString(hash) encodedHash := fmt.Sprintf( "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, - hashconf.memory, - hashconf.iterations, - hashconf.parallelism, + hc.memory, + hc.iterations, + hc.parallelism, b64Salt, b64Hash, ) return encodedHash, nil } + +func compareSecretToHash(secret, encoded string) (bool, error) { + // decode the encoded hash + hc, salt, comparehash, err := decodeHash(encoded) + if err != nil { + return false, err + } + + // encode the secret + verifyhash := hashArgon2(secret, salt, hc) + + // compare the hashes using constant time comparison + // to prevent timing attacks. if not equal, then return false + if subtle.ConstantTimeCompare(comparehash, verifyhash) != 1 { + return false, errHashesNotEqual + } + + return true, nil +} + +func hashArgon2(secret string, salt []byte, hc *hashconf) []byte { + hash := argon2.IDKey( + []byte(secret), + salt, + hc.iterations, + hc.memory, + hc.parallelism, + hc.keyLength, + ) + return hash +} + +func decodeHash(encoded string) (hc *hashconf, salt, decodedhash []byte, err error) { + params := strings.Split(encoded, "$") + // check we have enough params + if len(params) != 6 { + return nil, nil, nil, errInvalidHash + } + + // check the argon2 version matches + var version int + _, err = fmt.Sscanf(params[2], "v=%d", &version) + if err != nil { + return nil, nil, nil, err + } + if version != argon2.Version { + return nil, nil, nil, errIncorrectVersion + } + + // parse hashconf params to be returned + hc = &hashconf{} + _, err = fmt.Sscanf(params[3], "m=%d,t=%d,p=%d", &hc.memory, &hc.iterations, &hc.parallelism) + if err != nil { + return nil, nil, nil, err + } + + // decode the salt + salt, err = base64.RawStdEncoding.Strict().DecodeString(params[4]) + if err != nil { + return nil, nil, nil, err + } + hc.saltLength = uint32(len(salt)) + + // decode the hash + decodedhash, err = base64.RawStdEncoding.Strict().DecodeString(params[5]) + if err != nil { + return nil, nil, nil, err + } + hc.keyLength = uint32(len(decodedhash)) + + return hc, salt, decodedhash, nil +} |
