From d783bf8847b745990d717f98c3fa4846e8f8afe0 Mon Sep 17 00:00:00 2001 From: Mo Kweon Date: Sun, 13 Jun 2021 14:56:01 -0700 Subject: [PATCH] feat: add MethodGet(methodID) - It will be used to query each method --- internal/testutils/testutils.go | 6 +++++ method_get.go | 26 ++++++++++++++++++++ method_get_test.go | 43 +++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 method_get.go create mode 100644 method_get_test.go diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index a49f0bc..956892e 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -12,3 +12,9 @@ func MustExtractAPITokenFromEnv() string { } return apiToken } + + +// ToStringPtr returns a pointer to the given string. +func ToStringPtr(s string) *string { + return &s +} \ No newline at end of file diff --git a/method_get.go b/method_get.go new file mode 100644 index 0000000..3e21b05 --- /dev/null +++ b/method_get.go @@ -0,0 +1,26 @@ +package paperswithcode_go + +import ( + "encoding/json" + "github.com/codingpot/paperswithcode-go/v2/models" +) + +// MethodGet returns a method in a paper. +// See https://paperswithcode-client.readthedocs.io/en/latest/api/client.html#paperswithcode.client.PapersWithCodeClient.method_list +func (c *Client) MethodGet(methodID string) (*models.Method, error) { + url := c.baseURL + "/methods/" + methodID + + response, err := c.httpClient.Get(url) + if err != nil { + return nil, err + } + + var result models.Method + + err = json.NewDecoder(response.Body).Decode(&result) + if err != nil { + return nil, err + } + + return &result, nil +} \ No newline at end of file diff --git a/method_get_test.go b/method_get_test.go new file mode 100644 index 0000000..8fe764e --- /dev/null +++ b/method_get_test.go @@ -0,0 +1,43 @@ +package paperswithcode_go + +import ( + "github.com/codingpot/paperswithcode-go/v2/internal/testutils" + "github.com/codingpot/paperswithcode-go/v2/models" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestClient_MethodGet(t *testing.T) { + tests := []struct { + name string + methodID string + want *models.Method + wantErr bool + }{ + { + name: "With a correct methodID, it returns a method", + methodID: "multi-head-attention", + want: &models.Method{ + ID: "multi-head-attention", + Name: "Multi-Head Attention", + FullName: "Multi-Head Attention", + Description: "**Multi-head Attention** is a module for attention mechanisms which runs through an attention mechanism several times in parallel. The independent attention outputs are then concatenated and linearly transformed into the expected dimension. Intuitively, multiple attention heads allows for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies). \r\n\r\n$$ \\text{MultiHead}\\left(\\textbf{Q}, \\textbf{K}, \\textbf{V}\\right) = \\left[\\text{head}\\_{1},\\dots,\\text{head}\\_{h}\\right]\\textbf{W}_{0}$$\r\n\r\n$$\\text{where} \\text{ head}\\_{i} = \\text{Attention} \\left(\\textbf{Q}\\textbf{W}\\_{i}^{Q}, \\textbf{K}\\textbf{W}\\_{i}^{K}, \\textbf{V}\\textbf{W}\\_{i}^{V} \\right) $$\r\n\r\nAbove $\\textbf{W}$ are all learnable parameter matrices.\r\n\r\nNote that [scaled dot-product attention](https://paperswithcode.com/method/scaled) is most commonly used in this module, although in principle it can be swapped out for other types of attention mechanism.\r\n\r\nSource: [Lilian Weng](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms)", + Paper: testutils.ToStringPtr("attention-is-all-you-need"), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewClient() + got, err := c.MethodGet(tt.methodID) + if tt.wantErr { + assert.Error(t, err) + } else { + + assert.NoError(t, err) + } + assert.Equal(t, tt.want, got) + }) + } +} \ No newline at end of file