Browse Source

Add authentication to all required endpoints

pull/2/head
Tovi Jaeschke-Rogers 3 years ago
parent
commit
d584d40a52
14 changed files with 353 additions and 113 deletions
  1. +53
    -0
      Api/Auth/ChangePassword.go
  2. +3
    -2
      Api/Auth/Login.go
  3. +30
    -2
      Api/Auth/Session.go
  4. +33
    -17
      Api/JsonSerialization/VerifyJson.go
  5. +11
    -11
      Api/PostImages.go
  6. +39
    -22
      Api/Posts.go
  7. +34
    -3
      Api/Posts_test.go
  8. +46
    -28
      Api/Users.go
  9. +90
    -14
      Api/Users_test.go
  10. +1
    -1
      Database/Users.go
  11. +4
    -4
      Util/PostHelper.go
  12. +4
    -4
      Util/PostImageHelper.go
  13. +1
    -1
      Util/ReturnJson.go
  14. +4
    -4
      Util/UserHelper.go

+ 53
- 0
Api/Auth/ChangePassword.go View File

@ -0,0 +1,53 @@
package Auth
import (
"encoding/json"
"net/http"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models"
)
type ChangePassword struct {
Password string `json:"password"`
ConfirmPassword string `json:"confirm_password"`
}
func UpdatePassword(w http.ResponseWriter, r *http.Request) {
var (
changePasswd ChangePassword
userData Models.User
err error
)
userData, err = CheckCookieCurrentUser(w, r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
err = json.NewDecoder(r.Body).Decode(&changePasswd)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if changePasswd.Password != changePasswd.ConfirmPassword {
w.WriteHeader(http.StatusBadRequest)
return
}
userData.Password, err = HashPassword(changePasswd.Password)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
err = Database.UpdateUser(userData.ID.String(), &userData)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}

+ 3
- 2
Api/Auth/Login.go View File

@ -51,8 +51,9 @@ func Login(w http.ResponseWriter, r *http.Request) {
expiresAt = time.Now().Add(1 * time.Hour) expiresAt = time.Now().Add(1 * time.Hour)
Sessions[sessionToken.String()] = Session{ Sessions[sessionToken.String()] = Session{
Username: userData.Email,
Expiry: expiresAt,
UserID: userData.ID.String(),
Email: userData.Email,
Expiry: expiresAt,
} }
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{


+ 30
- 2
Api/Auth/Session.go View File

@ -4,6 +4,9 @@ import (
"errors" "errors"
"net/http" "net/http"
"time" "time"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util"
) )
var ( var (
@ -11,8 +14,9 @@ var (
) )
type Session struct { type Session struct {
Username string
Expiry time.Time
UserID string
Email string
Expiry time.Time
} }
func (s Session) IsExpired() bool { func (s Session) IsExpired() bool {
@ -49,3 +53,27 @@ func CheckCookie(r *http.Request) (Session, error) {
return userSession, nil return userSession, nil
} }
func CheckCookieCurrentUser(w http.ResponseWriter, r *http.Request) (Models.User, error) {
var (
userSession Session
userData Models.User
err error
)
userSession, err = CheckCookie(r)
if err != nil {
return userData, err
}
userData, err = Util.GetUserById(w, r)
if err != nil {
return userData, err
}
if userData.ID.String() != userSession.UserID {
return userData, errors.New("Is not current user")
}
return userData, nil
}

+ 33
- 17
Api/JsonSerialization/VerifyJson.go View File

@ -7,7 +7,11 @@ import (
// isIntegerType returns whether the type is an integer and if it's unsigned. // isIntegerType returns whether the type is an integer and if it's unsigned.
// See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L328 // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L328
func isIntegerType(t reflect.Type) (yes bool, unsigned bool) {
func isIntegerType(t reflect.Type) (bool, bool) {
var (
yes bool
unsigned bool
)
switch t.Kind() { switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
yes = true yes = true
@ -16,19 +20,22 @@ func isIntegerType(t reflect.Type) (yes bool, unsigned bool) {
unsigned = true unsigned = true
} }
return
return yes, unsigned
} }
// isFloatType returns true if the type is a floating point. Note that this doesn't // isFloatType returns true if the type is a floating point. Note that this doesn't
// care about the value -- unmarshaling the number "0" gives a float, not an int. // care about the value -- unmarshaling the number "0" gives a float, not an int.
// See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L319 // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L319
func isFloatType(t reflect.Type) (yes bool) {
func isFloatType(t reflect.Type) bool {
var (
yes bool
)
switch t.Kind() { switch t.Kind() {
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
yes = true yes = true
} }
return
return yes
} }
// CanConvert returns whether value v is convertible to type t. // CanConvert returns whether value v is convertible to type t.
@ -38,10 +45,20 @@ func isFloatType(t reflect.Type) (yes bool) {
// Modified due to not handling slices (DefaultCanConvert fails on PhotoUrls and Tags) // Modified due to not handling slices (DefaultCanConvert fails on PhotoUrls and Tags)
// See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L191 // See: https://github.com/Kangaroux/go-map-schema/blob/master/schema.go#L191
func CanConvert(t reflect.Type, v reflect.Value) bool { func CanConvert(t reflect.Type, v reflect.Value) bool {
isPtr := t.Kind() == reflect.Ptr
isStruct := t.Kind() == reflect.Struct
isArray := t.Kind() == reflect.Array
dstType := t
var (
isPtr bool
isStruct bool
isArray bool
dstType reflect.Type
dstInt bool
unsigned bool
f float64
srcInt bool
)
isPtr = t.Kind() == reflect.Ptr
isStruct = t.Kind() == reflect.Struct
isArray = t.Kind() == reflect.Array
dstType = t
// Check if v is a nil value. // Check if v is a nil value.
if !v.IsValid() || (v.CanAddr() && v.IsNil()) { if !v.IsValid() || (v.CanAddr() && v.IsNil()) {
@ -72,20 +89,19 @@ func CanConvert(t reflect.Type, v reflect.Value) bool {
} }
// Handle converting to an integer type. // Handle converting to an integer type.
if dstInt, unsigned := isIntegerType(dstType); dstInt {
dstInt, unsigned = isIntegerType(dstType)
if dstInt {
if isFloatType(v.Type()) { if isFloatType(v.Type()) {
f := v.Float()
f = v.Float()
if math.Trunc(f) != f {
return false
} else if unsigned && f < 0 {
return false
}
} else if srcInt, _ := isIntegerType(v.Type()); srcInt {
if unsigned && v.Int() < 0 {
if math.Trunc(f) != f || unsigned && f < 0 {
return false return false
} }
} }
srcInt, _ = isIntegerType(v.Type())
if srcInt && unsigned && v.Int() < 0 {
return false
}
} }
return true return true


+ 11
- 11
Api/PostImages.go View File

@ -30,10 +30,10 @@ func createPostImage(w http.ResponseWriter, r *http.Request) {
err error err error
) )
postID, err = getPostId(r)
postID, err = Util.GetPostId(r)
if err != nil { if err != nil {
log.Printf("Error encountered getting id\n") log.Printf("Error encountered getting id\n")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -42,7 +42,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) {
err = r.ParseMultipartForm(20 << 20) err = r.ParseMultipartForm(20 << 20)
if err != nil { if err != nil {
log.Printf("Error encountered parsing multipart form: %s\n", err.Error()) log.Printf("Error encountered parsing multipart form: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -54,7 +54,7 @@ func createPostImage(w http.ResponseWriter, r *http.Request) {
file, err = fileHeader.Open() file, err = fileHeader.Open()
if err != nil { if err != nil {
log.Printf("Error encountered while post image upload: %s\n", err.Error()) log.Printf("Error encountered while post image upload: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
defer file.Close() defer file.Close()
@ -62,14 +62,14 @@ func createPostImage(w http.ResponseWriter, r *http.Request) {
fileBytes, err = ioutil.ReadAll(file) fileBytes, err = ioutil.ReadAll(file)
if err != nil { if err != nil {
log.Printf("Error encountered while post image upload: %s\n", err.Error()) log.Printf("Error encountered while post image upload: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
fileObject, err = Util.WriteFile(fileBytes, "image") fileObject, err = Util.WriteFile(fileBytes, "image")
if err != nil { if err != nil {
log.Printf("Error encountered while post image upload: %s\n", err.Error()) log.Printf("Error encountered while post image upload: %s\n", err.Error())
JsonReturn(w, 415, "Invalid filetype")
Util.JsonReturn(w, 415, "Invalid filetype")
return return
} }
@ -83,19 +83,19 @@ func createPostImage(w http.ResponseWriter, r *http.Request) {
err = Database.CreatePostImage(&postImage) err = Database.CreatePostImage(&postImage)
if err != nil { if err != nil {
log.Printf("Error encountered while creating post_image record: %s\n", err.Error()) log.Printf("Error encountered while creating post_image record: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
} }
postData, err = getPostById(w, r)
postData, err = Util.GetPostById(w, r)
if err != nil { if err != nil {
return return
} }
returnJson, err = json.MarshalIndent(postData, "", " ") returnJson, err = json.MarshalIndent(postData, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -110,7 +110,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) {
err error err error
) )
postImageData, err = getPostImageById(w, r)
postImageData, err = Util.GetPostImageById(w, r)
if err != nil { if err != nil {
return return
} }
@ -118,7 +118,7 @@ func deletePostImage(w http.ResponseWriter, r *http.Request) {
err = Database.DeletePostImage(&postImageData) err = Database.DeletePostImage(&postImageData)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }


+ 39
- 22
Api/Posts.go View File

@ -8,9 +8,11 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/Auth"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util"
) )
func getPosts(w http.ResponseWriter, r *http.Request) { func getPosts(w http.ResponseWriter, r *http.Request) {
@ -27,27 +29,27 @@ func getPosts(w http.ResponseWriter, r *http.Request) {
page, err = strconv.Atoi(values.Get("page")) page, err = strconv.Atoi(values.Get("page"))
if err != nil { if err != nil {
log.Println("Could not parse page url argument") log.Println("Could not parse page url argument")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
page, err = strconv.Atoi(values.Get("pageSize")) page, err = strconv.Atoi(values.Get("pageSize"))
if err != nil { if err != nil {
log.Println("Could not parse pageSize url argument") log.Println("Could not parse pageSize url argument")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
posts, err = Database.GetPosts(page, pageSize) posts, err = Database.GetPosts(page, pageSize)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
returnJson, err = json.MarshalIndent(posts, "", " ") returnJson, err = json.MarshalIndent(posts, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -66,13 +68,13 @@ func getFrontPagePosts(w http.ResponseWriter, r *http.Request) {
posts, err = Database.GetFrontPagePosts() posts, err = Database.GetFrontPagePosts()
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
returnJson, err = json.MarshalIndent(posts, "", " ") returnJson, err = json.MarshalIndent(posts, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -88,14 +90,14 @@ func getPost(w http.ResponseWriter, r *http.Request) {
err error err error
) )
postData, err = getPostById(w, r)
postData, err = Util.GetPostById(w, r)
if err != nil { if err != nil {
return return
} }
returnJson, err = json.MarshalIndent(postData, "", " ") returnJson, err = json.MarshalIndent(postData, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -112,12 +114,16 @@ func createPost(w http.ResponseWriter, r *http.Request) {
err error err error
) )
// TODO: Add auth
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
requestBody, err = ioutil.ReadAll(r.Body) requestBody, err = ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Printf("Error encountered reading POST body: %s\n", err.Error()) log.Printf("Error encountered reading POST body: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -129,21 +135,20 @@ func createPost(w http.ResponseWriter, r *http.Request) {
"audios", "audios",
}, false) }, false)
if err != nil { if err != nil {
panic(err)
log.Printf("Invalid data provided to posts API: %s\n", err.Error()) log.Printf("Invalid data provided to posts API: %s\n", err.Error())
JsonReturn(w, 405, "Invalid data")
Util.JsonReturn(w, 405, "Invalid data")
return return
} }
err = Database.CreatePost(&postData) err = Database.CreatePost(&postData)
if err != nil { if err != nil {
JsonReturn(w, 405, "Invalid data")
Util.JsonReturn(w, 405, "Invalid data")
} }
returnJson, err = json.MarshalIndent(postData, "", " ") returnJson, err = json.MarshalIndent(postData, "", " ")
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -161,38 +166,44 @@ func updatePost(w http.ResponseWriter, r *http.Request) {
err error err error
) )
id, err = getPostId(r)
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
id, err = Util.GetPostId(r)
if err != nil { if err != nil {
log.Printf("Error encountered getting id\n") log.Printf("Error encountered getting id\n")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
requestBody, err = ioutil.ReadAll(r.Body) requestBody, err = ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Printf("Error encountered reading POST body: %s\n", err.Error()) log.Printf("Error encountered reading POST body: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
postData, err = JsonSerialization.DeserializePost(requestBody, []string{}, true) postData, err = JsonSerialization.DeserializePost(requestBody, []string{}, true)
if err != nil { if err != nil {
log.Printf("Invalid data provided to posts API: %s\n", err.Error()) log.Printf("Invalid data provided to posts API: %s\n", err.Error())
JsonReturn(w, 405, "Invalid data")
Util.JsonReturn(w, 405, "Invalid data")
return return
} }
postData, err = Database.UpdatePost(id, &postData) postData, err = Database.UpdatePost(id, &postData)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
returnJson, err = json.MarshalIndent(postData, "", " ") returnJson, err = json.MarshalIndent(postData, "", " ")
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -207,7 +218,13 @@ func deletePost(w http.ResponseWriter, r *http.Request) {
err error err error
) )
postData, err = getPostById(w, r)
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
postData, err = Util.GetPostById(w, r)
if err != nil { if err != nil {
return return
} }
@ -215,7 +232,7 @@ func deletePost(w http.ResponseWriter, r *http.Request) {
err = Database.DeletePost(&postData) err = Database.DeletePost(&postData)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }


+ 34
- 3
Api/Posts_test.go View File

@ -155,7 +155,11 @@ func Test_createPost(t *testing.T) {
defer ts.Close() defer ts.Close()
userData, err := createTestUser(true)
c, u, err := login()
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
return
}
postJson := ` postJson := `
{ {
@ -171,14 +175,25 @@ func Test_createPost(t *testing.T) {
} }
` `
postJson = fmt.Sprintf(postJson, userData.ID.String())
postJson = fmt.Sprintf(postJson, u.ID.String())
req, err := http.NewRequest("POST", ts.URL+"/post", strings.NewReader(postJson))
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
return
}
req.AddCookie(c)
res, err := http.Post(ts.URL+"/post", "application/json", strings.NewReader(postJson))
res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
return
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode)
return
} }
postData := new(Models.Post) postData := new(Models.Post)
@ -204,6 +219,12 @@ func Test_deletePost(t *testing.T) {
defer ts.Close() defer ts.Close()
c, _, err := login()
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
return
}
postData, err := createTestPost() postData, err := createTestPost()
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
@ -220,6 +241,8 @@ func Test_deletePost(t *testing.T) {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
} }
req.AddCookie(c)
// Fetch Request // Fetch Request
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -242,6 +265,12 @@ func Test_updatePost(t *testing.T) {
defer ts.Close() defer ts.Close()
c, _, err := login()
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
return
}
postData, err := createTestPost() postData, err := createTestPost()
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
@ -265,6 +294,8 @@ func Test_updatePost(t *testing.T) {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
} }
req.AddCookie(c)
// Fetch Request // Fetch Request
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {


+ 46
- 28
Api/Users.go View File

@ -12,6 +12,7 @@ import (
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Api/JsonSerialization"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Database"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models" "git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Models"
"git.tovijaeschke.xyz/tovi/SuddenImpactRecords/Util"
) )
func getUsers(w http.ResponseWriter, r *http.Request) { func getUsers(w http.ResponseWriter, r *http.Request) {
@ -23,32 +24,37 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
err error err error
) )
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
values = r.URL.Query() values = r.URL.Query()
page, err = strconv.Atoi(values.Get("page")) page, err = strconv.Atoi(values.Get("page"))
if err != nil { if err != nil {
log.Println("Could not parse page url argument") log.Println("Could not parse page url argument")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
page, err = strconv.Atoi(values.Get("pageSize")) page, err = strconv.Atoi(values.Get("pageSize"))
if err != nil { if err != nil {
log.Println("Could not parse pageSize url argument") log.Println("Could not parse pageSize url argument")
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
users, err = Database.GetUsers(page, pageSize) users, err = Database.GetUsers(page, pageSize)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
returnJson, err = json.MarshalIndent(users, "", " ") returnJson, err = json.MarshalIndent(users, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -64,14 +70,20 @@ func getUser(w http.ResponseWriter, r *http.Request) {
err error err error
) )
userData, err = getUserById(w, r)
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
userData, err = Util.GetUserById(w, r)
if err != nil { if err != nil {
return return
} }
returnJson, err = json.MarshalIndent(userData, "", " ") returnJson, err = json.MarshalIndent(userData, "", " ")
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -90,7 +102,7 @@ func createUser(w http.ResponseWriter, r *http.Request) {
requestBody, err = ioutil.ReadAll(r.Body) requestBody, err = ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Printf("Error encountered reading POST body: %s\n", err.Error()) log.Printf("Error encountered reading POST body: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -100,30 +112,30 @@ func createUser(w http.ResponseWriter, r *http.Request) {
}, false) }, false)
if err != nil { if err != nil {
log.Printf("Invalid data provided to user API: %s\n", err.Error()) log.Printf("Invalid data provided to user API: %s\n", err.Error())
JsonReturn(w, 405, "Invalid data")
Util.JsonReturn(w, 405, "Invalid data")
return return
} }
err = Database.CheckUniqueEmail(userData.Email) err = Database.CheckUniqueEmail(userData.Email)
if err != nil { if err != nil {
JsonReturn(w, 405, "invalid_email")
Util.JsonReturn(w, 405, "invalid_email")
return return
} }
if userData.Password != userData.ConfirmPassword { if userData.Password != userData.ConfirmPassword {
JsonReturn(w, 405, "invalid_password")
Util.JsonReturn(w, 405, "invalid_password")
return return
} }
userData.Password, err = Auth.HashPassword(userData.Password) userData.Password, err = Auth.HashPassword(userData.Password)
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
err = Database.CreateUser(&userData) err = Database.CreateUser(&userData)
if err != nil { if err != nil {
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -133,45 +145,44 @@ func createUser(w http.ResponseWriter, r *http.Request) {
func updateUser(w http.ResponseWriter, r *http.Request) { func updateUser(w http.ResponseWriter, r *http.Request) {
var ( var (
userData Models.User
requestBody []byte
returnJson []byte
id string
err error
currentUserData Models.User
userData Models.User
requestBody []byte
returnJson []byte
err error
) )
id, err = getUserId(r)
currentUserData, err = Auth.CheckCookieCurrentUser(w, r)
if err != nil { if err != nil {
log.Printf("Error encountered getting id\n")
JsonReturn(w, 500, "An error occured")
w.WriteHeader(http.StatusUnauthorized)
return return
} }
requestBody, err = ioutil.ReadAll(r.Body) requestBody, err = ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
log.Printf("Error encountered reading POST body: %s\n", err.Error()) log.Printf("Error encountered reading POST body: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
userData, err = JsonSerialization.DeserializeUser(requestBody, []string{}, true) userData, err = JsonSerialization.DeserializeUser(requestBody, []string{}, true)
if err != nil { if err != nil {
log.Printf("Invalid data provided to users API: %s\n", err.Error()) log.Printf("Invalid data provided to users API: %s\n", err.Error())
JsonReturn(w, 405, "Invalid data")
Util.JsonReturn(w, 405, "Invalid data")
return return
} }
err = Database.UpdateUser(id, &userData)
err = Database.UpdateUser(currentUserData.ID.String(), &userData)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
returnJson, err = json.MarshalIndent(userData, "", " ") returnJson, err = json.MarshalIndent(userData, "", " ")
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }
@ -186,15 +197,22 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
err error err error
) )
userData, err = getUserById(w, r)
_, err = Auth.CheckCookie(r)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
userData, err = Util.GetUserById(w, r)
if err != nil { if err != nil {
w.WriteHeader(http.StatusNotFound)
return return
} }
err = Database.DeleteUser(&userData) err = Database.DeleteUser(&userData)
if err != nil { if err != nil {
log.Printf("An error occured: %s\n", err.Error()) log.Printf("An error occured: %s\n", err.Error())
JsonReturn(w, 500, "An error occured")
Util.JsonReturn(w, 500, "An error occured")
return return
} }


+ 90
- 14
Api/Users_test.go View File

@ -2,6 +2,7 @@ package Api
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net/http" "net/http"
@ -68,6 +69,47 @@ func createTestUser(random bool) (Models.User, error) {
return userData, err return userData, err
} }
func login() (*http.Cookie, Models.User, error) {
var (
c *http.Cookie
u Models.User
)
r.HandleFunc("/admin/login", Auth.Login).Methods("POST")
ts := httptest.NewServer(r)
defer ts.Close()
u, err := createTestUser(true)
if err != nil {
return c, u, err
}
postJson := `
{
"email": "%s",
"password": "password"
}
`
postJson = fmt.Sprintf(postJson, u.Email)
res, err := http.Post(ts.URL+"/admin/login", "application/json", strings.NewReader(postJson))
if err != nil {
return c, u, err
}
if res.StatusCode != http.StatusOK {
return c, u, errors.New("Invalid res.StatusCode")
}
if len(res.Cookies()) != 1 {
return c, u, errors.New("Invalid cookies length")
}
return res.Cookies()[0], u, nil
}
func Test_getUser(t *testing.T) { func Test_getUser(t *testing.T) {
t.Log("Testing getUser...") t.Log("Testing getUser...")
@ -77,22 +119,31 @@ func Test_getUser(t *testing.T) {
defer ts.Close() defer ts.Close()
userData, err := createTestUser(false)
c, u, err := login()
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow() t.FailNow()
} }
res, err := http.Get(fmt.Sprintf(
req, err := http.NewRequest("GET", fmt.Sprintf(
"%s/user/%s", "%s/user/%s",
ts.URL, ts.URL,
userData.ID,
))
u.ID,
), nil)
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow()
}
req.AddCookie(c)
res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow() t.FailNow()
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode)
t.FailNow() t.FailNow()
@ -105,18 +156,18 @@ func Test_getUser(t *testing.T) {
t.FailNow() t.FailNow()
} }
if getUserData.Email != "email@email.com" {
t.Errorf("Expected email \"email@email.com\", recieved %s", getUserData.Email)
if getUserData.Email != u.Email {
t.Errorf("Expected email \"%s\", recieved %s", u.Email, getUserData.Email)
t.FailNow() t.FailNow()
} }
if getUserData.FirstName != "Hugh" {
t.Errorf("Expected email \"Hugh\", recieved %s", getUserData.FirstName)
if getUserData.FirstName != u.FirstName {
t.Errorf("Expected email \"%s\", recieved %s", u.FirstName, getUserData.FirstName)
t.FailNow() t.FailNow()
} }
if getUserData.LastName != "Mann" {
t.Errorf("Expected email \"Mann\", recieved %s", getUserData.LastName)
if getUserData.LastName != u.LastName {
t.Errorf("Expected email \"%s\", recieved %s", u.LastName, getUserData.LastName)
t.FailNow() t.FailNow()
} }
} }
@ -129,17 +180,31 @@ func Test_getUsers(t *testing.T) {
ts := httptest.NewServer(r) ts := httptest.NewServer(r)
defer ts.Close() defer ts.Close()
var err error
c, _, err := login()
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow()
}
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
createTestUser(true) createTestUser(true)
} }
res, err := http.Get(ts.URL + "/user?page=1&pageSize=10")
req, err := http.NewRequest("GET", ts.URL+"/user?page=1&pageSize=10", nil)
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow() t.FailNow()
} }
req.AddCookie(c)
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow()
}
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode) t.Errorf("Expected %d, recieved %d", http.StatusOK, res.StatusCode)
t.FailNow() t.FailNow()
@ -201,9 +266,10 @@ func Test_updateUser(t *testing.T) {
defer ts.Close() defer ts.Close()
userData, err := createTestUser(true)
c, u, err := login()
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow()
} }
email := fmt.Sprintf("%s@email.com", randString(16)) email := fmt.Sprintf("%s@email.com", randString(16))
@ -220,13 +286,15 @@ func Test_updateUser(t *testing.T) {
req, err := http.NewRequest("PUT", fmt.Sprintf( req, err := http.NewRequest("PUT", fmt.Sprintf(
"%s/user/%s", "%s/user/%s",
ts.URL, ts.URL,
userData.ID,
u.ID,
), strings.NewReader(postJson)) ), strings.NewReader(postJson))
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
} }
req.AddCookie(c)
// Fetch Request // Fetch Request
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
@ -266,6 +334,12 @@ func Test_deleteUser(t *testing.T) {
defer ts.Close() defer ts.Close()
c, _, err := login()
if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error())
t.FailNow()
}
userData, err := createTestUser(true) userData, err := createTestUser(true)
if err != nil { if err != nil {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
@ -282,6 +356,8 @@ func Test_deleteUser(t *testing.T) {
t.Errorf("Expected nil, recieved %s", err.Error()) t.Errorf("Expected nil, recieved %s", err.Error())
} }
req.AddCookie(c)
// Fetch Request // Fetch Request
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {


+ 1
- 1
Database/Users.go View File

@ -108,7 +108,7 @@ func UpdateUser(id string, userData *Models.User) error {
var ( var (
err error err error
) )
err = DB.Model(&Models.Post{}).
err = DB.Model(&Models.User{}).
Select("*"). Select("*").
Omit("id", "created_at", "updated_at", "deleted_at"). Omit("id", "created_at", "updated_at", "deleted_at").
Where("id = ?", id). Where("id = ?", id).


Api/PostHelper.go → Util/PostHelper.go View File


Api/PostImageHelper.go → Util/PostImageHelper.go View File


Api/ReturnJson.go → Util/ReturnJson.go View File


Api/UserHelper.go → Util/UserHelper.go View File


Loading…
Cancel
Save