diff --git a/asreviewcontrib/semantic_clustering/semantic_clustering.py b/asreviewcontrib/semantic_clustering/semantic_clustering.py
index 18bf9ae..f9913aa 100644
--- a/asreviewcontrib/semantic_clustering/semantic_clustering.py
+++ b/asreviewcontrib/semantic_clustering/semantic_clustering.py
@@ -12,22 +12,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# import
+# import ASReview
+from tqdm import tqdm
 from asreview.data import ASReviewData
 
+# import numpy
+import numpy as np
 
-class SemanticClustering():
-    def __init__(self, data: ASReviewData):
-        self.data = data
+# import transformer autotokenizer and automodel
+from transformers import AutoTokenizer, AutoModel
 
-        # create ASReview data object
+# disable transformer warning
+from transformers import logging
+logging.set_verbosity_error()
 
+#import tqdm
 
-def load_data(ASReviewDataObject):
-    data = ASReviewDataObject.df[['title', 'abstract']].copy()
+def SemanticClustering(asreview_data_object):
+
+    # load data
+    print("Loading data...")
+    data = load_data(asreview_data_object)
+
+    # cut data for testing
+    data = data.iloc[:10, :]
+
+    # load scibert transformer
+    print("Loading scibert transformer...")
+    transformer = 'allenai/scibert_scivocab_uncased'
+
+    # load transformer and tokenizer
+    print("Loading tokenizer and model...")
+    tokenizer = AutoTokenizer.from_pretrained(transformer)
+    model = AutoModel.from_pretrained(transformer)
+
+    # tokenize abstracts and add to data
+    print("Tokenizing abstracts...")
+    data['tokenized'] = data['abstract'].apply(lambda x: tokenizer.encode(
+        x,
+        padding='longest',
+        add_special_tokens=True,
+        return_tensors="pt"))
+
+    print(data)
+
+
+def load_data(asreview_data_object):
+
+    # extract title and abstract, drop empty abstracts and reset index
+    data = asreview_data_object.df[['title', 'abstract']].copy()
     data['abstract'] = data['abstract'].replace('', np.nan, inplace=False)
     data.dropna(subset=['abstract'], inplace=True)
     data = data.reset_index(drop=True)
 
     return data
+
+
+if __name__ == "__main__":
+    filepath = "https://raw.githubusercontent.com/asreview/systematic-review-datasets/master/datasets/van_de_Schoot_2017/output/van_de_Schoot_2017.csv"
+    SemanticClustering(ASReviewData.from_file(filepath))
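
For context on what the new tokenization step sets up, the sketch below (not part of this diff) shows one way the tokenized abstracts could be passed through the loaded SciBERT model to obtain one embedding vector per abstract, which is what a later clustering step would consume. The batched tokenizer call, the mean-pooling strategy, and the example abstracts are illustrative assumptions rather than the plugin's confirmed implementation; note that `padding='longest'` only has an effect when a whole batch is tokenized together, as done here, not when abstracts are encoded one at a time.

```python
# Hypothetical sketch: turn a batch of abstracts into fixed-size SciBERT
# embeddings. Mean pooling over the attention mask is an assumption for
# illustration, not the plugin's confirmed approach.
import torch
from transformers import AutoTokenizer, AutoModel

transformer = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(transformer)
model = AutoModel.from_pretrained(transformer)

# made-up example abstracts, standing in for data['abstract']
abstracts = [
    "Posttraumatic stress disorder trajectories after major surgery.",
    "Latent growth mixture modeling of PTSD symptom development.",
]

# tokenize the whole batch at once so padding='longest' pads every
# abstract to the longest sequence in the batch
inputs = tokenizer(
    abstracts,
    padding='longest',
    truncation=True,
    max_length=512,
    return_tensors="pt")

# forward pass without gradients; last_hidden_state has shape
# (batch_size, sequence_length, hidden_size)
with torch.no_grad():
    outputs = model(**inputs)

# mean-pool over the token dimension, ignoring padding positions
mask = inputs['attention_mask'].unsqueeze(-1)
embeddings = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)
print(embeddings.shape)  # torch.Size([2, 768]) for SciBERT
```

Mean pooling is only one common choice; taking the `[CLS]` token's hidden state would also yield a fixed-size vector per abstract. Either way, the resulting matrix of embeddings is what a downstream clustering step (e.g. k-means over the vectors) would operate on.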