aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexander Kavon <me+git@alexkavon.com>2024-01-22 17:56:22 -0500
committerAlexander Kavon <me+git@alexkavon.com>2024-01-22 17:56:22 -0500
commit66d84b2b49f55e6c652816466a5c3b4202234134 (patch)
tree710c7121d361e3232901d8b9538034e1409b7ec4
parent28e420d7f8ce5dc622c5e0e28a10654f48fcbca4 (diff)
added login authentication route logic, added compareSecretToHash, decodeHash, hashArgon2, renamed HashSecret to hashSecret
-rw-r--r--src/user/hooks.go8
-rw-r--r--src/user/routes.go16
-rw-r--r--src/user/secret.go122
3 files changed, 121 insertions, 25 deletions
diff --git a/src/user/hooks.go b/src/user/hooks.go
index 1552760..61e24bc 100644
--- a/src/user/hooks.go
+++ b/src/user/hooks.go
@@ -4,23 +4,21 @@ import (
"context"
validation "github.com/go-ozzo/ozzo-validation/v4"
- "github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/volatiletech/sqlboiler/v4/boil"
"gitlab.com/alexkavon/newsstand/src/models"
)
func init() {
- models.AddUserHook(boil.BeforeInsertHook, validate)
+ models.AddUserHook(boil.BeforeInsertHook, validateNew)
// should always be last
models.AddUserHook(boil.BeforeInsertHook, hashSecretBeforeInsert)
}
-func validate(ctx context.Context, exec boil.ContextExecutor, u *models.User) error {
+func validateNew(ctx context.Context, exec boil.ContextExecutor, u *models.User) error {
// validate user
err := validation.ValidateStruct(u,
validation.Field(&u.Username, validation.Required, validation.Length(3, 50)),
validation.Field(&u.Secret, validation.Required, validation.Length(8, 128)),
- validation.Field(&u.Email, validation.Required, is.Email),
)
if err != nil {
return err
@@ -30,7 +28,7 @@ func validate(ctx context.Context, exec boil.ContextExecutor, u *models.User) er
}
func hashSecretBeforeInsert(ctx context.Context, exec boil.ContextExecutor, u *models.User) error {
- hashed, err := HashSecret(u.Secret)
+ hashed, err := hashSecret(u.Secret)
if err != nil {
return err
}
diff --git a/src/user/routes.go b/src/user/routes.go
index 7cbd3fb..5399418 100644
--- a/src/user/routes.go
+++ b/src/user/routes.go
@@ -92,10 +92,26 @@ func LoginForm(s *server.Server) http.HandlerFunc {
func Login(s *server.Server) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
// look up the user from the db
+ user, err := models.Users(models.UserWhere.Username.EQ(r.PostFormValue("username"))).One(r.Context(), s.Db.ToSqlDb())
+ if err != nil {
+ log.Fatal(err)
+ }
+
// hash the form secret
// compare form hash to db hash
+ valid, err := compareSecretToHash(r.PostFormValue("secret"), user.Secret)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if !valid {
+ log.Fatal("Incorrect login credentials TODO resolve with compareSecretToHash err")
+ }
+
// login or dont
+ sessions.NewSession(w, sessions.SessionValues{"uid": user.ID, "username": user.Username})
+ http.Redirect(w, r, "/u/me", http.StatusSeeOther)
}
}
diff --git a/src/user/secret.go b/src/user/secret.go
index b6382fe..a777072 100644
--- a/src/user/secret.go
+++ b/src/user/secret.go
@@ -2,45 +2,127 @@ package user
import (
"crypto/rand"
+ "crypto/subtle"
"encoding/base64"
+ "errors"
"fmt"
+ "strings"
"golang.org/x/crypto/argon2"
)
-func HashSecret(secret string) (string, error) {
- hashconf := &struct {
- memory uint32
- iterations uint32
- parallelism uint8
- keyLength uint32
- saltLength uint32
- }{64 * 1024, 3, 2, 12, 16}
- salt := make([]byte, hashconf.saltLength)
+var (
+ errInvalidHash = errors.New("provided hash is wrong format")
+ errHashesNotEqual = errors.New("secret is not equal to encoded hash")
+ errIncorrectVersion = errors.New("incorrect version of Argon2")
+)
+
+type hashconf struct {
+ memory uint32
+ iterations uint32
+ parallelism uint8
+ keyLength uint32
+ saltLength uint32
+}
+
+func hashSecret(secret string) (string, error) {
+ hc := &hashconf{
+ memory: 64 * 1024,
+ iterations: 3,
+ parallelism: 2,
+ keyLength: 12,
+ saltLength: 16,
+ }
+ salt := make([]byte, hc.saltLength)
_, err := rand.Read(salt)
if err != nil {
return "", err
}
- hash := argon2.IDKey(
- []byte(secret),
- salt,
- hashconf.iterations,
- hashconf.memory,
- hashconf.parallelism,
- hashconf.keyLength,
- )
+ hash := hashArgon2(secret, salt, hc)
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
encodedHash := fmt.Sprintf(
"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
argon2.Version,
- hashconf.memory,
- hashconf.iterations,
- hashconf.parallelism,
+ hc.memory,
+ hc.iterations,
+ hc.parallelism,
b64Salt,
b64Hash,
)
return encodedHash, nil
}
+
+func compareSecretToHash(secret, encoded string) (bool, error) {
+ // decode the encoded hash
+ hc, salt, comparehash, err := decodeHash(encoded)
+ if err != nil {
+ return false, err
+ }
+
+ // encode the secret
+ verifyhash := hashArgon2(secret, salt, hc)
+
+ // compare the hashes using constant time comparison
+ // to prevent timing attacks. if not equal, then return false
+ if subtle.ConstantTimeCompare(comparehash, verifyhash) != 1 {
+ return false, errHashesNotEqual
+ }
+
+ return true, nil
+}
+
+func hashArgon2(secret string, salt []byte, hc *hashconf) []byte {
+ hash := argon2.IDKey(
+ []byte(secret),
+ salt,
+ hc.iterations,
+ hc.memory,
+ hc.parallelism,
+ hc.keyLength,
+ )
+ return hash
+}
+
+func decodeHash(encoded string) (hc *hashconf, salt, decodedhash []byte, err error) {
+ params := strings.Split(encoded, "$")
+ // check we have enough params
+ if len(params) != 6 {
+ return nil, nil, nil, errInvalidHash
+ }
+
+ // check the argon2 version matches
+ var version int
+ _, err = fmt.Sscanf(params[2], "v=%d", &version)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ if version != argon2.Version {
+ return nil, nil, nil, errIncorrectVersion
+ }
+
+ // parse hashconf params to be returned
+ hc = &hashconf{}
+ _, err = fmt.Sscanf(params[3], "m=%d,t=%d,p=%d", &hc.memory, &hc.iterations, &hc.parallelism)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ // decode the salt
+ salt, err = base64.RawStdEncoding.Strict().DecodeString(params[4])
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ hc.saltLength = uint32(len(salt))
+
+ // decode the hash
+ decodedhash, err = base64.RawStdEncoding.Strict().DecodeString(params[5])
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ hc.keyLength = uint32(len(decodedhash))
+
+ return hc, salt, decodedhash, nil
+}