Skip to content

Commit

Permalink
feat: add MethodGet(methodID)
Browse files Browse the repository at this point in the history
- It will be used to query each method
  • Loading branch information
kkweon committed Jun 13, 2021
1 parent 87730c8 commit d783bf8
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
6 changes: 6 additions & 0 deletions internal/testutils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ func MustExtractAPITokenFromEnv() string {
}
return apiToken
}


// ToStringPtr returns a pointer to the given string.
func ToStringPtr(s string) *string {
return &s
}
26 changes: 26 additions & 0 deletions method_get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package paperswithcode_go

import (
"encoding/json"
"github.com/codingpot/paperswithcode-go/v2/models"
)

// MethodGet returns a method in a paper.
// See https://paperswithcode-client.readthedocs.io/en/latest/api/client.html#paperswithcode.client.PapersWithCodeClient.method_list
func (c *Client) MethodGet(methodID string) (*models.Method, error) {
url := c.baseURL + "/methods/" + methodID

response, err := c.httpClient.Get(url)
if err != nil {
return nil, err
}

var result models.Method

err = json.NewDecoder(response.Body).Decode(&result)
if err != nil {
return nil, err
}

return &result, nil
}
43 changes: 43 additions & 0 deletions method_get_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package paperswithcode_go

import (
"github.com/codingpot/paperswithcode-go/v2/internal/testutils"
"github.com/codingpot/paperswithcode-go/v2/models"
"github.com/stretchr/testify/assert"
"testing"
)

func TestClient_MethodGet(t *testing.T) {
tests := []struct {
name string
methodID string
want *models.Method
wantErr bool
}{
{
name: "With a correct methodID, it returns a method",
methodID: "multi-head-attention",
want: &models.Method{
ID: "multi-head-attention",
Name: "Multi-Head Attention",
FullName: "Multi-Head Attention",
Description: "**Multi-head Attention** is a module for attention mechanisms which runs through an attention mechanism several times in parallel. The independent attention outputs are then concatenated and linearly transformed into the expected dimension. Intuitively, multiple attention heads allows for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies). \r\n\r\n$$ \\text{MultiHead}\\left(\\textbf{Q}, \\textbf{K}, \\textbf{V}\\right) = \\left[\\text{head}\\_{1},\\dots,\\text{head}\\_{h}\\right]\\textbf{W}_{0}$$\r\n\r\n$$\\text{where} \\text{ head}\\_{i} = \\text{Attention} \\left(\\textbf{Q}\\textbf{W}\\_{i}^{Q}, \\textbf{K}\\textbf{W}\\_{i}^{K}, \\textbf{V}\\textbf{W}\\_{i}^{V} \\right) $$\r\n\r\nAbove $\\textbf{W}$ are all learnable parameter matrices.\r\n\r\nNote that [scaled dot-product attention](https://paperswithcode.com/method/scaled) is most commonly used in this module, although in principle it can be swapped out for other types of attention mechanism.\r\n\r\nSource: [Lilian Weng](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms)",
Paper: testutils.ToStringPtr("attention-is-all-you-need"),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := NewClient()
got, err := c.MethodGet(tt.methodID)
if tt.wantErr {
assert.Error(t, err)
} else {

assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit d783bf8

Please sign in to comment.