Skip to content

Commit

Permalink
Merge pull request #14 from codingpot/feat-add-method-get
Browse files Browse the repository at this point in the history
feat: add MethodGet(methodID)
  • Loading branch information
kkweon authored Jul 1, 2021
2 parents 725a9d7 + d783bf8 commit 9e58638
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
6 changes: 6 additions & 0 deletions internal/testutils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ func MustExtractAPITokenFromEnv() string {
}
return apiToken
}


// ToStringPtr returns a pointer to the given string.
func ToStringPtr(s string) *string {
return &s
}
26 changes: 26 additions & 0 deletions method_get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package paperswithcode_go

import (
"encoding/json"
"github.com/codingpot/paperswithcode-go/v2/models"
)

// MethodGet returns a method in a paper.
// See https://paperswithcode-client.readthedocs.io/en/latest/api/client.html#paperswithcode.client.PapersWithCodeClient.method_list
func (c *Client) MethodGet(methodID string) (*models.Method, error) {
url := c.baseURL + "/methods/" + methodID

response, err := c.httpClient.Get(url)
if err != nil {
return nil, err
}

var result models.Method

err = json.NewDecoder(response.Body).Decode(&result)
if err != nil {
return nil, err
}

return &result, nil
}
43 changes: 43 additions & 0 deletions method_get_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package paperswithcode_go

import (
"github.com/codingpot/paperswithcode-go/v2/internal/testutils"
"github.com/codingpot/paperswithcode-go/v2/models"
"github.com/stretchr/testify/assert"
"testing"
)

func TestClient_MethodGet(t *testing.T) {
tests := []struct {
name string
methodID string
want *models.Method
wantErr bool
}{
{
name: "With a correct methodID, it returns a method",
methodID: "multi-head-attention",
want: &models.Method{
ID: "multi-head-attention",
Name: "Multi-Head Attention",
FullName: "Multi-Head Attention",
Description: "**Multi-head Attention** is a module for attention mechanisms which runs through an attention mechanism several times in parallel. The independent attention outputs are then concatenated and linearly transformed into the expected dimension. Intuitively, multiple attention heads allows for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies). \r\n\r\n$$ \\text{MultiHead}\\left(\\textbf{Q}, \\textbf{K}, \\textbf{V}\\right) = \\left[\\text{head}\\_{1},\\dots,\\text{head}\\_{h}\\right]\\textbf{W}_{0}$$\r\n\r\n$$\\text{where} \\text{ head}\\_{i} = \\text{Attention} \\left(\\textbf{Q}\\textbf{W}\\_{i}^{Q}, \\textbf{K}\\textbf{W}\\_{i}^{K}, \\textbf{V}\\textbf{W}\\_{i}^{V} \\right) $$\r\n\r\nAbove $\\textbf{W}$ are all learnable parameter matrices.\r\n\r\nNote that [scaled dot-product attention](https://paperswithcode.com/method/scaled) is most commonly used in this module, although in principle it can be swapped out for other types of attention mechanism.\r\n\r\nSource: [Lilian Weng](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms)",
Paper: testutils.ToStringPtr("attention-is-all-you-need"),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewClient()
got, err := c.MethodGet(tt.methodID)
if tt.wantErr {
assert.Error(t, err)
} else {

assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 9e58638

Please sign in to comment.