From 3b86715e879ff63c07f9dd90c06ecd5e7fb119a4 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Tue, 31 Dec 2024 16:11:11 -0500 Subject: [PATCH 1/2] [SPARKNLP-1106] Introducing DistilBertForMultipleChoice --- .../annotator/classifier_dl/__init__.py | 1 + .../classifier_dl/bert_for_multiple_choice.py | 4 +- .../distilbert_for_multiple_choice.py | 161 +++++++++++ python/sparknlp/internal/__init__.py | 9 + .../distilbert_for_multiple_choice_test.py | 76 +++++ .../ml/ai/DistilBertClassification.scala | 88 ++++++ .../dl/DistilBertForMultipleChoice.scala | 259 ++++++++++++++++++ .../DistilBertForMultipleChoiceTestSpec.scala | 85 ++++++ 8 files changed, 681 insertions(+), 2 deletions(-) create mode 100644 python/sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py create mode 100644 python/test/annotator/classifier_dl/distilbert_for_multiple_choice_test.py create mode 100644 src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoice.scala create mode 100644 src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoiceTestSpec.scala diff --git a/python/sparknlp/annotator/classifier_dl/__init__.py b/python/sparknlp/annotator/classifier_dl/__init__.py index 2b5e30fc3ff359..9ae264e7a864f1 100644 --- a/python/sparknlp/annotator/classifier_dl/__init__.py +++ b/python/sparknlp/annotator/classifier_dl/__init__.py @@ -55,3 +55,4 @@ from sparknlp.annotator.classifier_dl.albert_for_zero_shot_classification import * from sparknlp.annotator.classifier_dl.camembert_for_zero_shot_classification import * from sparknlp.annotator.classifier_dl.bert_for_multiple_choice import * +from sparknlp.annotator.classifier_dl.distilbert_for_multiple_choice import * \ No newline at end of file diff --git a/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py b/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py index 2c27f913e56fcc..045e8d64180b53 100644 --- a/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py +++ b/python/sparknlp/annotator/classifier_dl/bert_for_multiple_choice.py @@ -130,7 +130,7 @@ def loadSavedModel(folder, spark_session): Returns ------- - BertForQuestionAnswering + BertForMultipleChoice The restored model """ from sparknlp.internal import _BertMultipleChoiceLoader @@ -154,7 +154,7 @@ def pretrained(name="bert_base_uncased_multiple_choice", lang="en", remote_loc=N Returns ------- - BertForQuestionAnswering + BertForMultipleChoice The restored model """ from sparknlp.pretrained import ResourceDownloader diff --git a/python/sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py b/python/sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py new file mode 100644 index 00000000000000..f76aa3859c307e --- /dev/null +++ b/python/sparknlp/annotator/classifier_dl/distilbert_for_multiple_choice.py @@ -0,0 +1,161 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparknlp.common import * + +class DistilBertForMultipleChoice(AnnotatorModel, + HasCaseSensitiveProperties, + HasBatchedAnnotate, + HasEngine, + HasMaxSentenceLengthLimit): + """DistilBertForMultipleChoice can load DistilBert Models with a multiple choice classification head on top + (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> spanClassifier = DistilBertForMultipleChoice.pretrained() \\ + ... .setInputCols(["document_question", "document_context"]) \\ + ... .setOutputCol("answer") + + The default model is ``"bert_base_uncased_multiple_choice"``, if no name is + provided. + + For available pretrained models please see the `Models Hub + `__. + + To see which models are compatible and how to import them see + `Import Transformers into Spark NLP 🚀 + `_. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT, DOCUMENT`` ``CHUNK`` + ====================== ====================== + + Parameters + ---------- + batchSize + Batch size. Large values allows faster processing but requires more + memory, by default 8 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default + False + maxSentenceLength + Max sentence length to process, by default 512 + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = MultiDocumentAssembler() \\ + ... .setInputCols(["question", "context"]) \\ + ... .setOutputCols(["document_question", "document_context"]) + >>> questionAnswering = DistilBertForMultipleChoice.pretrained() \\ + ... .setInputCols(["document_question", "document_context"]) \\ + ... .setOutputCol("answer") \\ + ... .setCaseSensitive(False) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... questionAnswering + ... ]) + >>> data = spark.createDataFrame([["The Eiffel Tower is located in which country??", "Germany, France, Italy"]]).toDF("question", "context") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("answer.result").show(truncate=False) + +--------------------+ + |result | + +--------------------+ + |[France] | + +--------------------+ + """ + name = "DistilBertForMultipleChoice" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.CHUNK + + choicesDelimiter = Param(Params._dummy(), + "choicesDelimiter", + "Delimiter character use to split the choices", + TypeConverters.toString) + + def setChoicesDelimiter(self, value): + """Sets delimiter character use to split the choices + + Parameters + ---------- + value : string + Delimiter character use to split the choices + """ + return self._set(caseSensitive=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForMultipleChoice", + java_model=None): + super(DistilBertForMultipleChoice, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + batchSize=4, + maxSentenceLength=512, + caseSensitive=False, + choicesDelimiter = "," + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + DistilBertForMultipleChoice + The restored model + """ + from sparknlp.internal import _DistilBertMultipleChoiceLoader + jModel = _DistilBertMultipleChoiceLoader(folder, spark_session._jsparkSession)._java_obj + return DistilBertForMultipleChoice(java_model=jModel) + + @staticmethod + def pretrained(name="distilbert_base_uncased_multiple_choice", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default + "bert_base_uncased_multiple_choice" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + DistilBertForMultipleChoice + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(DistilBertForMultipleChoice, name, lang, remote_loc) \ No newline at end of file diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 4cb5321e8a8691..f3e0152aab2574 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -211,6 +211,15 @@ def __init__(self, path, jspark): ) +class _DistilBertMultipleChoiceLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_DistilBertMultipleChoiceLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.classifier.dl.DistilBertForMultipleChoice.loadSavedModel", + path, + jspark, + ) + + class _ElmoLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_ElmoLoader, self).__init__( diff --git a/python/test/annotator/classifier_dl/distilbert_for_multiple_choice_test.py b/python/test/annotator/classifier_dl/distilbert_for_multiple_choice_test.py new file mode 100644 index 00000000000000..15e3885767017b --- /dev/null +++ b/python/test/annotator/classifier_dl/distilbert_for_multiple_choice_test.py @@ -0,0 +1,76 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +class DistilBertForMultipleChoiceTestSetup(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.question = "The Eiffel Tower is located in which country?" + self.choices = "Germany, France, Italy" + + self.spark = SparkContextForTest.spark + empty_df = self.spark.createDataFrame([[""]]).toDF("text") + + document_assembler = MultiDocumentAssembler() \ + .setInputCols(["question", "context"]) \ + .setOutputCols(["document_question", "document_context"]) + + DistilBert_for_multiple_choice = DistilBertForMultipleChoice.pretrained() \ + .setInputCols(["document_question", "document_context"]) \ + .setOutputCol("answer") + + pipeline = Pipeline(stages=[document_assembler, DistilBert_for_multiple_choice]) + + self.pipeline_model = pipeline.fit(empty_df) + + +@pytest.mark.slow +class DistilBertForMultipleChoiceTest(DistilBertForMultipleChoiceTestSetup, unittest.TestCase): + + def setUp(self): + super().setUp() + self.data = self.spark.createDataFrame([[self.question, self.choices]]).toDF("question","context") + self.data.show(truncate=False) + + def test_run(self): + result_df = self.pipeline_model.transform(self.data) + result_df.show(truncate=False) + for row in result_df.collect(): + self.assertTrue(row["answer"][0].result != "") + + +@pytest.mark.slow +class LightDistilBertForMultipleChoiceTest(DistilBertForMultipleChoiceTestSetup, unittest.TestCase): + + def setUp(self): + super().setUp() + + def runTest(self): + light_pipeline = LightPipeline(self.pipeline_model) + annotations_result = light_pipeline.fullAnnotate(self.question,self.choices) + print(annotations_result) + for result in annotations_result: + self.assertTrue(result["answer"][0].result != "") + + result = light_pipeline.annotate(self.question,self.choices) + print(result) + self.assertTrue(result["answer"] != "") diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala index 2ae27f9510eaff..fecab2daf79372 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala @@ -484,6 +484,94 @@ private[johnsnowlabs] class DistilBertClassification( (startScores, endScores) } + override def tagSpanMultipleChoice(batch: Seq[Array[Int]]): Array[Float] = { + val logits = detectedEngine match { + case ONNX.name => computeLogitsMultipleChoiceWithOnnx(batch) + case Openvino.name => computeLogitsMultipleChoiceWithOv(batch) + } + + calculateSoftmax(logits) + } + + private def computeLogitsMultipleChoiceWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { + val sequenceLength = batch.head.length + val inputIds = Array(batch.map(x => x.map(_.toLong)).toArray) + val attentionMask = Array( + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + val tokenTypeIds = Array(batch.map(_ => Array.fill(sequenceLength)(0L)).toArray) + + val (ortSession, ortEnv) = onnxWrapper.get.getSession(onnxSessionOptions) + val tokenTensors = OnnxTensor.createTensor(ortEnv, inputIds) + val maskTensors = OnnxTensor.createTensor(ortEnv, attentionMask) + val segmentTensors = OnnxTensor.createTensor(ortEnv, tokenTypeIds) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors).asJava + + try { + val output = ortSession.run(inputs) + try { + + val logits = output + .get("logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + logits + } finally if (output != null) output.close() + } catch { + case e: Exception => + // Log the exception as a warning + println("Exception in computeLogitsMultipleChoiceWithOnnx: ", e) + // Rethrow the exception to propagate it further + throw e + } + } + + private def computeLogitsMultipleChoiceWithOv(batch: Seq[Array[Int]]): Array[Float] = { + val (numChoices, sequenceLength) = (batch.length, batch.head.length) + // batch_size, num_choices, sequence_length + val shape = Some(Array(1, numChoices, sequenceLength)) + val (tokenTensors, maskTensors, segmentTensors) = + PrepareEmbeddings.prepareOvLongBatchTensorsWithSegment( + batch, + sequenceLength, + numChoices, + sentencePadTokenId, + shape) + + val compiledModel = openvinoWrapper.get.getCompiledModel() + val inferRequest = compiledModel.create_infer_request() + inferRequest.set_tensor("input_ids", tokenTensors) + inferRequest.set_tensor("attention_mask", maskTensors) + + inferRequest.infer() + + try { + try { + val logits = inferRequest + .get_output_tensor() + .data() + + logits + } + } catch { + case e: Exception => + // Log the exception as a warning + logger.warn("Exception in computeLogitsMultipleChoiceWithOv", e) + // Rethrow the exception to propagate it further + throw e + } + } + def computeLogitsWithTF( batch: Seq[Array[Int]], maxSentenceLength: Int): (Array[Float], Array[Float]) = { diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoice.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoice.scala new file mode 100644 index 00000000000000..dedf00525a2360 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoice.scala @@ -0,0 +1,259 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.ml.ai.DistilBertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.openvino.{OpenvinoWrapper, ReadOpenvinoModel, WriteOpenvinoModel} +import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, Openvino} +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.{IntParam, Param} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession + +class DistilBertForMultipleChoice(override val uid: String) + extends AnnotatorModel[DistilBertForMultipleChoice] + with HasBatchedAnnotate[DistilBertForMultipleChoice] + with WriteOnnxModel + with WriteOpenvinoModel + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + def this() = this(Identifiable.randomUID("DistilBertForMultipleChoice")) + + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.DOCUMENT, AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.CHUNK + + /** Vocabulary used to encode the words to ids with WordPieceEncoder + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** @group setParam */ + def sentenceStartTokenId: Int = { + $$(vocabulary)("[CLS]") + } + + /** @group setParam */ + def sentenceEndTokenId: Int = { + $$(vocabulary)("[SEP]") + } + + /** Max sentence length to process (Default: `512`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "DistilBERT models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + val choicesDelimiter = + new Param[String](this, "choicesDelimiter", "Delimiter character use to split the choices") + + def setChoicesDelimiter(value: String): this.type = set(choicesDelimiter, value) + + private var _model: Option[Broadcast[DistilBertClassification]] = None + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], + openvinoWrapper: Option[OpenvinoWrapper], + ): DistilBertForMultipleChoice = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new DistilBertClassification( + tensorflowWrapper, + onnxWrapper, + openvinoWrapper, + sentenceStartTokenId, + sentenceEndTokenId, + tags = Map.empty[String, Int], + vocabulary = $$(vocabulary)))) + } + + this + } + + /** @group getParam */ + def getModelIfNotSet: DistilBertClassification = _model.get.value + + /** Whether to lowercase tokens or not (Default: `true`). + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = set(this.caseSensitive, value) + + setDefault( + batchSize -> 4, + maxSentenceLength -> 512, + caseSensitive -> false, + choicesDelimiter -> ",") + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations in batches that correspond to inputAnnotationCols generated by previous + * annotators if any + * @return + * any number of annotations processed for every batch of input annotations. Not necessary + * one to one relationship + * + * IMPORTANT: !MUST! return sequences of equal lengths !! IMPORTANT: !MUST! return sentences + * that belong to the same original row !! (challenging) + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + batchedAnnotations.map(annotations => { + if (annotations.nonEmpty) { + getModelIfNotSet.predictSpanMultipleChoice( + annotations, + $(choicesDelimiter), + $(maxSentenceLength), + $(caseSensitive)) + } else { + Seq.empty[Annotation] + } + }) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + + getEngine match { + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + "_distilbert_multiple_choice_classification", + DistilBertForMultipleChoice.onnxFile) + case Openvino.name => + writeOpenvinoModel( + path, + spark, + getModelIfNotSet.openvinoWrapper.get, + "openvino_model.xml", + DistilBertForMultipleChoice.openvinoFile) + } + } + +} + +trait ReadablePretrainedDistilBertForMultipleChoiceModel + extends ParamsAndFeaturesReadable[DistilBertForMultipleChoice] + with HasPretrained[DistilBertForMultipleChoice] { + override val defaultModelName: Some[String] = Some("distilbert_base_uncased_multiple_choice") + + /** Java compliant-overrides */ + override def pretrained(): DistilBertForMultipleChoice = super.pretrained() + + override def pretrained(name: String): DistilBertForMultipleChoice = super.pretrained(name) + + override def pretrained(name: String, lang: String): DistilBertForMultipleChoice = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): DistilBertForMultipleChoice = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadDistilBertForMultipleChoiceModel extends ReadOnnxModel with ReadOpenvinoModel { + this: ParamsAndFeaturesReadable[DistilBertForMultipleChoice] => + + override val onnxFile: String = "distilbert_mc_classification_onnx" + override val openvinoFile: String = "distilbert_mc_classification_openvino" + + def readModel(instance: DistilBertForMultipleChoice, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case ONNX.name => + val onnxWrapper = + readOnnxModel(path, spark, "distilbert_mc_classification_onnx") + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), None) + case Openvino.name => + val openvinoWrapper = readOpenvinoModel(path, spark, "distilbert_mc_classification_ov") + instance.setModelIfNotSet(spark, None, None, Some(openvinoWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): DistilBertForMultipleChoice = { + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + val annotatorModel = new DistilBertForMultipleChoice().setVocabulary(vocabs) + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case ONNX.name => + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), None) + case Openvino.name => + val ovWrapper: OpenvinoWrapper = + OpenvinoWrapper.read( + spark, + localModelPath, + zipped = false, + useBundle = true, + detectedEngine = detectedEngine) + annotatorModel + .setModelIfNotSet(spark, None, None, Some(ovWrapper)) + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +/** This is the companion object of [[DistilBertForMultipleChoice]]. Please refer to that class for the + * documentation. + */ +object DistilBertForMultipleChoice + extends ReadablePretrainedDistilBertForMultipleChoiceModel + with ReadDistilBertForMultipleChoiceModel \ No newline at end of file diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoiceTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoiceTestSpec.scala new file mode 100644 index 00000000000000..247b05d833b8dd --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForMultipleChoiceTestSpec.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package com.johnsnowlabs.nlp.annotators.classifier.dl + +import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations, MultiDocumentAssembler} +import com.johnsnowlabs.nlp.annotators.SparkSessionTest +import com.johnsnowlabs.nlp.base.LightPipeline +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.scalatest.flatspec.AnyFlatSpec + +class DistilBertForMultipleChoiceTestSpec extends AnyFlatSpec with SparkSessionTest { + + import spark.implicits._ + + lazy val pipelineModel = getDistilBertForMultipleChoicePipelineModel + + val testDataframe = + Seq(("The Eiffel Tower is located in which country?", "Germany, France, Italy")) + .toDF("question", "context") + + + "DistilBertForMultipleChoiceTestSpec" should "answer a multiple choice question" taggedAs SlowTest in { + val resultDf = pipelineModel.transform(testDataframe) + resultDf.show(truncate = false) + + val result = AssertAnnotations.getActualResult(resultDf, "answer") + result.foreach { annotation => + annotation.foreach(a => assert(a.result.nonEmpty)) + } + } + + it should "work with light pipeline fullAnnotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultFullAnnotate = lightPipeline.fullAnnotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultFullAnnotate") + + val answerAnnotation = resultFullAnnotate("answer").head.asInstanceOf[Annotation] + + assert(answerAnnotation.result.nonEmpty) + } + + it should "work with light pipeline annotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultAnnotate = lightPipeline.annotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultAnnotate") + + assert(resultAnnotate("answer").head.nonEmpty) + } + + private def getDistilBertForMultipleChoicePipelineModel: PipelineModel = { + val documentAssembler = new MultiDocumentAssembler() + .setInputCols("question", "context") + .setOutputCols("document_question", "document_context") + + val bertForMultipleChoice = DistilBertForMultipleChoice + .pretrained() + .setInputCols("document_question", "document_context") + .setOutputCol("answer") + + val pipeline = new Pipeline().setStages(Array(documentAssembler, bertForMultipleChoice)) + + pipeline.fit(emptyDataSet) + } + +} From 191c78bbb7ecce7b94c501350f52501cc5e6bcd5 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Thu, 2 Jan 2025 12:14:24 -0500 Subject: [PATCH 2/2] [SPARKNLP-1106] Adding notebook examples for DistilBertForMultipleChoice --- ...park_NLP_DistilBertForMultipleChoice.ipynb | 2751 ++++++++++++++++ ...park_NLP_DistilBertForMultipleChoice.ipynb | 2903 +++++++++++++++++ 2 files changed, 5654 insertions(+) create mode 100644 examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBertForMultipleChoice.ipynb create mode 100644 examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_DistilBertForMultipleChoice.ipynb diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBertForMultipleChoice.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBertForMultipleChoice.ipynb new file mode 100644 index 00000000000000..d3ff42a5bec59c --- /dev/null +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBertForMultipleChoice.ipynb @@ -0,0 +1,2751 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PAsu8UVGoLVf" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_DistilBertForMultipleChoice.ipynb)\n", + "\n", + "## Import ONNX DistilBertForMultipleChoice models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- ONNX support was introduced in `Spark NLP 5.0.0`, enabling high performance inference for models.\n", + "- `DistilBertForMultipleChoice` is only available since in `Spark NLP 5.6.0` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import BERT models trained/fine-tuned for question answering via `DistilBertForMultipleChoice` or `TFDistilBertForMultipleChoice`. These models are usually under `Multiple Choice` category and have `bert` in their labels\n", + "- Reference: [DistilBertForMultipleChoice](https://huggingface.co/docs/transformers/main/en/model_doc/distilbert#transformers.DistilBertForMultipleChoice)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OzijcdtQpOx9" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MlgoClMXpSg4" + }, + "source": [ + "- Let's install `transformers` package with the `onnx` extension and it's dependencies. You don't need `onnx` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cJWbob-kHICU", + "outputId": "b9a93019-f7a9-4f13-d727-502e8806d75c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/424.1 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", + "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m419.8/424.1 kB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.1/424.1 kB\u001b[0m \u001b[31m10.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/13.3 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", + "\u001b[2K \u001b[91m━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.1/13.3 MB\u001b[0m \u001b[31m124.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", + "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━\u001b[0m \u001b[32m11.4/13.3 MB\u001b[0m \u001b[31m183.8 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", + "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m195.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m110.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/212.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.7/212.7 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m65.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m455.8/455.8 kB\u001b[0m \u001b[31m38.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.0/16.0 MB\u001b[0m \u001b[31m102.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.5/55.5 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m21.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\n", + "grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.2 which is incompatible.\n", + "tensorflow 2.17.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q --upgrade transformers[onnx] optimum" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XtewR2xdOa5s" + }, + "source": [ + "- HuggingFace has an extension called Optimum which offers specialized model inference, including ONNX. We can use this to import and export ONNX models with `from_pretrained` and `save_pretrained`.\n", + "- We'll use the treained model above as an example and load it as a `ORTModelForMultipleChoice`, representing an ONNX model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 313, + "referenced_widgets": [ + "3fa8c62ddf034288bbe3c585e481582a", + "b942bc8a478b40c18b3afeb2ef549b82", + "74b1760ba60d4217be2b0a0540f9dc22", + "e6bd749383ad44c49d7d2555ed8f773a", + "35e316e7eedc4829a1c421b87e25cc6b", + "d10a939a7c294d51ac0970c1e6e67400", + "c501f85ca9104291b32766594abfb177", + "35c33ddbb5e64f2eaad2cf22bd2b312e", + "8e1edff76eb74812a4a57d8f9a7c66a5", + "f76dd41f539a42d883ea85d6fc249381", + "dc07286d1da74c1b8c80edf7b22ec105", + "5c1e19066ac2428f892903b335136e05", + "41bfa9e47892418981d2105b727c6651", + "10177825cf9746fe8f142462f7723aef", + "be47affd0e774f038ba4b935a7dc0fda", + "115e14b1a23046f6bde91b91368110cd", + "4c1c4ff45bdc4658b11d783c67d08f79", + "367b13b5545a4c65a5d8a827720271f6", + "895587a2608a49e7ad3d6132abb3a833", + "8a3001347f6e47afac38e3fa860a7aee", + "341bc46cd55148759af3deb3ad08a816", + "0ae4ff0010404177ae68147ac49ea7e6", + "2f44f043a48c4b79a08c8b8fab68be5c", + "b01be358bde44be1b233b147dfaaf9ad", + "b4e498dbede94fa49d93372f238b4966", + "c05e109ef7344f92a1fb33cd3d6b45e2", + "5bef47b4061f4d638796c17779e780dc", + "45757698c34b4db9ac25e7257cd15435", + "0c3da599f3304566920f6002e90a6961", + "d36768137ea04899b36b79c738a13fa6", + "2f9901d1aaf14f2c9bbed14e85e58ca1", + "19b1c94ebf85413da435ce89cad2d666", + "b851e1cbb82f446bb6a275afed0cc3b1", + "1e5cbd2c3f46415fbbc0ce27f4c48ce9", + "65331501cffe49428c3ee28e43432b74", + "1a9b4a12cc9443059cb57a411d1c2637", + "455d7eadeb44420a883b2066628640d8", + "b39e0fb3fbf44cba8fab21ddfac7a2cc", + "bdb846fc731e487ca1a72f55d7aa8d6e", + "9dca665841224956be77804fe1f90391", + "f23cd3265e034825b8cd3dd2d393d421", + "4c958a2b95b944c7829a9d2bc0cb77ac", + "9388c8035d2c47e08ab99040dcd00cf9", + "62f926c3d9f9461596521ba34e795b3a", + "d7b0df281fcb4e95abc2ac1f33503fb2", + "11fce16817ab484f953b590045ed3d5c", + "005cdb63a2ef454f8c7f401ec377e2df", + "a5d4aa7482434d43b0d856558d27664c", + "c8d906a71eb24d83bb823bac76af29cc", + "7d6ebb345cba4bcab22f9d1f7261d5c8", + "f4942df43b0d42e8ade2b3a953e76c86", + "1fb6829e34664a0cb82c56616160a4a7", + "4e3c9f8eb4964535a1ae416600f57e93", + "540d3b239d76402d8495be7e23351f07", + "fd869a4e8ec5472aa15916213b60f1cb", + "49307ab4126b42618a7c16fa81c22117", + "9a99cd70f2ff4e41ba7e60ea0142d45b", + "ab3833fcf88747c69da9c98ca60b3443", + "67bff7bbad194bef948cb8dfd985cd1b", + "6999f459998a4032ad29f074d5134847", + "cb1bef115ba54c549acc1056c3149ebd", + "3826c29474704eca93ed9a7d97865909", + "2999fb3ce9b041839ff03250080f0501", + "09369dd140294679ac321835d8a085e2", + "176a27c38a96473c8603b69e4009daf8", + "15a0cd6793ef4c4dbcbada0adada95a9" + ] + }, + "id": "Id33annImYM8", + "outputId": "a195cc94-133d-4705-cb1f-a3cd12fff670" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3fa8c62ddf034288bbe3c585e481582a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/574 [00:00 0, chunk -> 0, score -> 0.6048505}, []}]|\n", + "|[{chunk, 0, 6, Germany, {sentence -> 0, chunk -> 0, score -> 0.39137164}, []}] |\n", + "|[{chunk, 0, 5, Tiger, {sentence -> 0, chunk -> 0, score -> 0.2897997}, []}] |\n", + "|[{chunk, 0, 3, 90°C, {sentence -> 0, chunk -> 0, score -> 0.35916787}, []}] |\n", + "|[{chunk, 0, 6, Jupiter, {sentence -> 0, chunk -> 0, score -> 0.35939977}, []}] |\n", + "|[{chunk, 0, 7, English, {sentence -> 0, chunk -> 0, score -> 0.3640033}, []}] |\n", + "|[{chunk, 0, 11, The Mongols, {sentence -> 0, chunk -> 0, score -> 0.29171145}, []}] |\n", + "|[{chunk, 0, 6, Osmium, {sentence -> 0, chunk -> 0, score -> 0.4062368}, []}] |\n", + "|[{chunk, 0, 13, South America, {sentence -> 0, chunk -> 0, score -> 0.392063}, []}] |\n", + "|[{chunk, 0, 13, Pablo Picasso, {sentence -> 0, chunk -> 0, score -> 0.4129128}, []}] |\n", + "+----------------------------------------------------------------------------------------------------------+\n", + "\n" + ] + } + ], + "source": [ + "document_assembler = MultiDocumentAssembler() \\\n", + " .setInputCols([\"question\", \"choices\"]) \\\n", + " .setOutputCols([\"document_question\", \"document_choices\"])\n", + "\n", + "distilbert_for_multiple_choice = DistilBertForMultipleChoice() \\\n", + " .load(\"./{}_spark_nlp_onnx\".format(MODEL_NAME)) \\\n", + " .setInputCols([\"document_question\", \"document_choices\"])\\\n", + " .setOutputCol(\"answer\") \\\n", + " .setBatchSize(4)\n", + "\n", + "pipeline = Pipeline(stages=[document_assembler, distilbert_for_multiple_choice])\n", + "pipeline_model = pipeline.fit(testing_df)\n", + "\n", + "pipeline_df = pipeline_model.transform(testing_df)\n", + "\n", + "pipeline_df.select(\"answer\").show(truncate=False)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "005cdb63a2ef454f8c7f401ec377e2df": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1fb6829e34664a0cb82c56616160a4a7", + "max": 711396, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4e3c9f8eb4964535a1ae416600f57e93", + "value": 711396 + } + }, + "09369dd140294679ac321835d8a085e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "0ae4ff0010404177ae68147ac49ea7e6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0c3da599f3304566920f6002e90a6961": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "10177825cf9746fe8f142462f7723aef": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_895587a2608a49e7ad3d6132abb3a833", + "max": 267829484, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8a3001347f6e47afac38e3fa860a7aee", + "value": 267829484 + } + }, + "115e14b1a23046f6bde91b91368110cd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "11fce16817ab484f953b590045ed3d5c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7d6ebb345cba4bcab22f9d1f7261d5c8", + "placeholder": "​", + "style": "IPY_MODEL_f4942df43b0d42e8ade2b3a953e76c86", + "value": "tokenizer.json: 100%" + } + }, + "15a0cd6793ef4c4dbcbada0adada95a9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "176a27c38a96473c8603b69e4009daf8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "19b1c94ebf85413da435ce89cad2d666": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1a9b4a12cc9443059cb57a411d1c2637": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f23cd3265e034825b8cd3dd2d393d421", + "max": 231508, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4c958a2b95b944c7829a9d2bc0cb77ac", + "value": 231508 + } + }, + "1e5cbd2c3f46415fbbc0ce27f4c48ce9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_65331501cffe49428c3ee28e43432b74", + "IPY_MODEL_1a9b4a12cc9443059cb57a411d1c2637", + "IPY_MODEL_455d7eadeb44420a883b2066628640d8" + ], + "layout": "IPY_MODEL_b39e0fb3fbf44cba8fab21ddfac7a2cc" + } + }, + "1fb6829e34664a0cb82c56616160a4a7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2999fb3ce9b041839ff03250080f0501": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2f44f043a48c4b79a08c8b8fab68be5c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b01be358bde44be1b233b147dfaaf9ad", + "IPY_MODEL_b4e498dbede94fa49d93372f238b4966", + "IPY_MODEL_c05e109ef7344f92a1fb33cd3d6b45e2" + ], + "layout": "IPY_MODEL_5bef47b4061f4d638796c17779e780dc" + } + }, + "2f9901d1aaf14f2c9bbed14e85e58ca1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "341bc46cd55148759af3deb3ad08a816": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "35c33ddbb5e64f2eaad2cf22bd2b312e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "35e316e7eedc4829a1c421b87e25cc6b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "367b13b5545a4c65a5d8a827720271f6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3826c29474704eca93ed9a7d97865909": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3fa8c62ddf034288bbe3c585e481582a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b942bc8a478b40c18b3afeb2ef549b82", + "IPY_MODEL_74b1760ba60d4217be2b0a0540f9dc22", + "IPY_MODEL_e6bd749383ad44c49d7d2555ed8f773a" + ], + "layout": "IPY_MODEL_35e316e7eedc4829a1c421b87e25cc6b" + } + }, + "41bfa9e47892418981d2105b727c6651": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4c1c4ff45bdc4658b11d783c67d08f79", + "placeholder": "​", + "style": "IPY_MODEL_367b13b5545a4c65a5d8a827720271f6", + "value": "model.safetensors: 100%" + } + }, + "455d7eadeb44420a883b2066628640d8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9388c8035d2c47e08ab99040dcd00cf9", + "placeholder": "​", + "style": "IPY_MODEL_62f926c3d9f9461596521ba34e795b3a", + "value": " 232k/232k [00:00<00:00, 2.56MB/s]" + } + }, + "45757698c34b4db9ac25e7257cd15435": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "49307ab4126b42618a7c16fa81c22117": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_9a99cd70f2ff4e41ba7e60ea0142d45b", + "IPY_MODEL_ab3833fcf88747c69da9c98ca60b3443", + "IPY_MODEL_67bff7bbad194bef948cb8dfd985cd1b" + ], + "layout": "IPY_MODEL_6999f459998a4032ad29f074d5134847" + } + }, + "4c1c4ff45bdc4658b11d783c67d08f79": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4c958a2b95b944c7829a9d2bc0cb77ac": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4e3c9f8eb4964535a1ae416600f57e93": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "540d3b239d76402d8495be7e23351f07": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5bef47b4061f4d638796c17779e780dc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5c1e19066ac2428f892903b335136e05": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_41bfa9e47892418981d2105b727c6651", + "IPY_MODEL_10177825cf9746fe8f142462f7723aef", + "IPY_MODEL_be47affd0e774f038ba4b935a7dc0fda" + ], + "layout": "IPY_MODEL_115e14b1a23046f6bde91b91368110cd" + } + }, + "62f926c3d9f9461596521ba34e795b3a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "65331501cffe49428c3ee28e43432b74": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bdb846fc731e487ca1a72f55d7aa8d6e", + "placeholder": "​", + "style": "IPY_MODEL_9dca665841224956be77804fe1f90391", + "value": "vocab.txt: 100%" + } + }, + "67bff7bbad194bef948cb8dfd985cd1b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_176a27c38a96473c8603b69e4009daf8", + "placeholder": "​", + "style": "IPY_MODEL_15a0cd6793ef4c4dbcbada0adada95a9", + "value": " 125/125 [00:00<00:00, 10.1kB/s]" + } + }, + "6999f459998a4032ad29f074d5134847": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "74b1760ba60d4217be2b0a0540f9dc22": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_35c33ddbb5e64f2eaad2cf22bd2b312e", + "max": 574, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_8e1edff76eb74812a4a57d8f9a7c66a5", + "value": 574 + } + }, + "7d6ebb345cba4bcab22f9d1f7261d5c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "895587a2608a49e7ad3d6132abb3a833": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8a3001347f6e47afac38e3fa860a7aee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "8e1edff76eb74812a4a57d8f9a7c66a5": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "9388c8035d2c47e08ab99040dcd00cf9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9a99cd70f2ff4e41ba7e60ea0142d45b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cb1bef115ba54c549acc1056c3149ebd", + "placeholder": "​", + "style": "IPY_MODEL_3826c29474704eca93ed9a7d97865909", + "value": "special_tokens_map.json: 100%" + } + }, + "9dca665841224956be77804fe1f90391": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a5d4aa7482434d43b0d856558d27664c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_540d3b239d76402d8495be7e23351f07", + "placeholder": "​", + "style": "IPY_MODEL_fd869a4e8ec5472aa15916213b60f1cb", + "value": " 711k/711k [00:00<00:00, 15.3MB/s]" + } + }, + "ab3833fcf88747c69da9c98ca60b3443": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2999fb3ce9b041839ff03250080f0501", + "max": 125, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_09369dd140294679ac321835d8a085e2", + "value": 125 + } + }, + "b01be358bde44be1b233b147dfaaf9ad": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_45757698c34b4db9ac25e7257cd15435", + "placeholder": "​", + "style": "IPY_MODEL_0c3da599f3304566920f6002e90a6961", + "value": "tokenizer_config.json: 100%" + } + }, + "b39e0fb3fbf44cba8fab21ddfac7a2cc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b4e498dbede94fa49d93372f238b4966": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d36768137ea04899b36b79c738a13fa6", + "max": 1224, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2f9901d1aaf14f2c9bbed14e85e58ca1", + "value": 1224 + } + }, + "b851e1cbb82f446bb6a275afed0cc3b1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b942bc8a478b40c18b3afeb2ef549b82": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d10a939a7c294d51ac0970c1e6e67400", + "placeholder": "​", + "style": "IPY_MODEL_c501f85ca9104291b32766594abfb177", + "value": "config.json: 100%" + } + }, + "bdb846fc731e487ca1a72f55d7aa8d6e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "be47affd0e774f038ba4b935a7dc0fda": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_341bc46cd55148759af3deb3ad08a816", + "placeholder": "​", + "style": "IPY_MODEL_0ae4ff0010404177ae68147ac49ea7e6", + "value": " 268M/268M [00:06<00:00, 42.6MB/s]" + } + }, + "c05e109ef7344f92a1fb33cd3d6b45e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_19b1c94ebf85413da435ce89cad2d666", + "placeholder": "​", + "style": "IPY_MODEL_b851e1cbb82f446bb6a275afed0cc3b1", + "value": " 1.22k/1.22k [00:00<00:00, 99.0kB/s]" + } + }, + "c501f85ca9104291b32766594abfb177": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c8d906a71eb24d83bb823bac76af29cc": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "cb1bef115ba54c549acc1056c3149ebd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d10a939a7c294d51ac0970c1e6e67400": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d36768137ea04899b36b79c738a13fa6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d7b0df281fcb4e95abc2ac1f33503fb2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_11fce16817ab484f953b590045ed3d5c", + "IPY_MODEL_005cdb63a2ef454f8c7f401ec377e2df", + "IPY_MODEL_a5d4aa7482434d43b0d856558d27664c" + ], + "layout": "IPY_MODEL_c8d906a71eb24d83bb823bac76af29cc" + } + }, + "dc07286d1da74c1b8c80edf7b22ec105": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e6bd749383ad44c49d7d2555ed8f773a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f76dd41f539a42d883ea85d6fc249381", + "placeholder": "​", + "style": "IPY_MODEL_dc07286d1da74c1b8c80edf7b22ec105", + "value": " 574/574 [00:00<00:00, 48.8kB/s]" + } + }, + "f23cd3265e034825b8cd3dd2d393d421": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f4942df43b0d42e8ade2b3a953e76c86": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f76dd41f539a42d883ea85d6fc249381": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fd869a4e8ec5472aa15916213b60f1cb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_DistilBertForMultipleChoice.ipynb b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_DistilBertForMultipleChoice.ipynb new file mode 100644 index 00000000000000..018ee5807e529d --- /dev/null +++ b/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_DistilBertForMultipleChoice.ipynb @@ -0,0 +1,2903 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_V5XcDCnVgSi" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_DistilBertForMultipleChoice.ipynb)\n", + "\n", + "# Import OpenVINO DistilBertForMultipleChoice models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "This notebook provides a detailed walkthrough on optimizing and exporting DistilBertForMultipleChoice models from HuggingFace for use in Spark NLP, leveraging the various tools provided in the [Intel OpenVINO toolkit](https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html) ecosystem.\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- OpenVINO support was introduced in `Spark NLP 5.4.0`, enabling high performance inference for models. Please make sure you have upgraded to the latest Spark NLP release.\n", + "- You can import models for DistilBertForMultipleChoice from DistilBertForMultipleChoice and they have to be in `Multiple Choice` category." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aghasVppVgSk" + }, + "source": [ + "## 1. Export and Save the HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "be4HsTDMVgSk" + }, + "source": [ + "- Let's install `transformers` and `openvino` packages with other dependencies. You don't need `openvino` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock `transformers` on version `4.41.2`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-7L-2ZWUVgSl", + "outputId": "5d2d172b-5f02-4639-83fc-82b9e601dfe3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.8/43.8 kB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.1/9.1 MB\u001b[0m \u001b[31m61.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m84.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.7/38.7 MB\u001b[0m \u001b[31m52.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.7/215.7 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m21.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.1/424.1 kB\u001b[0m \u001b[31m38.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.0/16.0 MB\u001b[0m \u001b[31m102.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m18.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m18.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.1/13.1 MB\u001b[0m \u001b[31m101.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m66.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "google-ai-generativelanguage 0.6.10 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-api-core 2.19.2 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.19.5, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-aiplatform 1.74.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-bigquery-connection 1.17.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-bigquery-storage 2.27.0 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-bigtable 2.27.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-datastore 2.20.2 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-firestore 2.19.0 requires protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-functions 1.19.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-iam 2.17.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-language 2.16.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-pubsub 2.27.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-resource-manager 1.14.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "google-cloud-translate 3.19.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "googleapis-common-protos 1.66.0 requires protobuf!=3.20.0,!=3.20.1,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0.dev0,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "grpc-google-iam-v1 0.13.1 requires protobuf!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.20.1 which is incompatible.\n", + "grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.1 which is incompatible.\n", + "tensorflow 2.17.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.20.1 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.20.1 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q --upgrade transformers==4.41.2\n", + "!pip install -q --upgrade openvino==2024.1\n", + "!pip install -q --upgrade optimum-intel==1.17.0\n", + "!pip install -q --upgrade onnx==1.12.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vI7uz_6hVgSl" + }, + "source": [ + "[Optimum Intel](https://github.com/huggingface/optimum-intel?tab=readme-ov-file#openvino) is the interface between the Transformers library and the various model optimization and acceleration tools provided by Intel. HuggingFace models loaded with optimum-intel are automatically optimized for OpenVINO, while being compatible with the Transformers API.\n", + "- Normally, to load a HuggingFace model directly for inference/export, just replace the `AutoModelForXxx` class with the corresponding `OVModelForXxx` class. However, ForMultipleChoice is not yet available so we will use `openvino.convert_model()` after exporting ONNX model\n", + "- We'll use [irfanamal/bert_multiple_choice](https://huggingface.co/irfanamal/bert_multiple_choice) model from HuggingFace as an example\n", + "- We also need the `vocab.txt` saved from `AutoTokenizer`. This is the same for every model, these are assets (saved in `/assets`) needed for tokenization inside Spark NLP." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "TDapJ_09nqXQ", + "outputId": "3d8b2ec9-b2c9-4fe8-b3de-0654a43416ba" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: pip in /usr/local/lib/python3.10/dist-packages (24.1.2)\n", + "Collecting pip\n", + " Downloading pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)\n", + "Downloading pip-24.3.1-py3-none-any.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m22.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: pip\n", + " Attempting uninstall: pip\n", + " Found existing installation: pip 24.1.2\n", + " Uninstalling pip-24.1.2:\n", + " Successfully uninstalled pip-24.1.2\n", + "Successfully installed pip-24.3.1\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.3/13.3 MB\u001b[0m \u001b[31m113.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m115.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.1/10.1 MB\u001b[0m \u001b[31m151.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m60.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "grpcio-status 1.62.3 requires protobuf>=4.21.6, but you have protobuf 3.20.2 which is incompatible.\n", + "optimum-intel 1.17.0 requires transformers<4.42.0,>=4.36.0, but you have transformers 4.47.1 which is incompatible.\n", + "tensorflow 2.17.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install --upgrade pip\n", + "!pip install -q --upgrade transformers[onnx] optimum openvino==2024.1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 313, + "referenced_widgets": [ + "1685190d1c0b40dc9a04bf961ed1fe9a", + "8faf63a59311431e8a4f8149b3d09be3", + "04caa1515af146588adb1552ece36dcd", + "0a41311d87454a49b28792762d1ad390", + "df80408d1d9f4995bce81c9657a16902", + "d09eb9036f7042979767eee8b8adafdd", + "e45f9f7671f84e5783ba53cdafbf0eee", + "3787315381de4f4a8f383dffe4dd1318", + "92efd8cc36b24ba7827eb8ee8d5ddd76", + "586875b1181840658f82ed3791819ea9", + "e5927cae79b2469a8c72dc30591f895c", + "cd2acd5f7e7f4811906e9efdf91cf2f3", + "497ac3bcc25b4710994554005e19ca36", + "7154097716564b4db750626f7efdffc7", + "103d60fc85664ea6852c6cf65a838f08", + "94034333eb9f4399ae56edfa1b9d4b9b", + "153a98ebce99415bbaa8f3bcf0e003a3", + "56b0be1d416b4e968a9b5380367dcd77", + "48d8e523726b426ab553e158e1738d35", + "6758258d40b34089bb9a2ff3d3008c98", + "abde155ca5604caea5996b7eb88e80a3", + "f2e9969ee9364f8b96c0ccbe17b04dd3", + "c77c1345e62441ec9f0cccefd74680ae", + "bda09c689b7e451b9f205bcd89e2892c", + "4d1ce1e3122d4685895e531179f1a6f8", + "ea563bffe9a9402797647d6e57726b07", + "11e9347df720416ab68f7254699e4bfd", + "d74e96dcba28405bb0b90863c372b9a3", + "4143428097184e239be3aa15a709a1bd", + "ebb172d5ac1c4e2381a5e02b3d009d71", + "3935021ef68448f394a93f6a6215724b", + "f73b5c70ce424707951099c81a54ec0f", + "545fcedce22f463e9745cf2785c6260f", + "5acb405e46c7407aa3d38e49d4087a0b", + "e27b262eb28f48a8b39c32a4f65bf74f", + "1da886fbe0a341059120356697806776", + "16da8069c945418f8bb156a7f6d25748", + "795ff4db8fae4852a055586f00364be4", + "1ce12659ed8748239d6755f6da042a98", + "2cba33d72df64451856a650f45f1f765", + "8774027c38a1482895e6c21e6d0e662b", + "c2a9db5e3ef54f7fb704f599e65af401", + "4f3818960d504431a7eb18cd7488e2ba", + "77b7b75457ef4c38ac2634eff41a29d4", + "ea3db61523174ad08842b86d8645fe36", + "7adfc44756d64dbd9552b3e4c7315424", + "9280dc685751491d96554819dce773f2", + "cbc7ecd928d841ee9c448fbb1643612c", + "9a6dd0ec39954a46982e711ece20d8c8", + "83f34c62dcc64d6abb92b880877c3b31", + "466e13e3b2864431962a8a0e29473063", + "dc1dc8a06d7f4b949a50e388e84f690f", + "3e93119eb321493281c6a8cbdf5bb226", + "7061bf9eac4d4f4b8f7250fc8043d912", + "783ee37e722c49d69bc856466340d9eb", + "3a4eaa1f68fd4e9cbd36d29815918c6a", + "4e57df313f4a43fab68b281d93893952", + "a089f60df2b9404290a4634deaf0654a", + "b763f3ad27f044139de5d4eb656153a1", + "ecf9f67c924e4ea69fd2a6b9fb9ef2a6", + "203467747a404ad9b34c642b1a05b711", + "620132d19704448386b7aced66a92ab6", + "c1dde60b58ce4a449f0663cd8a6040b6", + "146d64045c554e66afa4dbd4797971e8", + "26c383992f4742ca92e8132e9aee73b9", + "35c6f2532f6f44a6a90ce6c3a508791c" + ] + }, + "id": "_b89GvQKosA0", + "outputId": "afab9178-e2fd-4979-dbe5-5dcee3c3036f" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1685190d1c0b40dc9a04bf961ed1fe9a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/574 [00:00 0, chunk -> 0, score -> 0.60487646}, []}]|\n", + "|[{chunk, 0, 6, Germany, {sentence -> 0, chunk -> 0, score -> 0.39134768}, []}] |\n", + "|[{chunk, 0, 5, Tiger, {sentence -> 0, chunk -> 0, score -> 0.28978878}, []}] |\n", + "|[{chunk, 0, 3, 90°C, {sentence -> 0, chunk -> 0, score -> 0.35916173}, []}] |\n", + "|[{chunk, 0, 6, Jupiter, {sentence -> 0, chunk -> 0, score -> 0.35947314}, []}] |\n", + "|[{chunk, 0, 7, English, {sentence -> 0, chunk -> 0, score -> 0.36399662}, []}] |\n", + "|[{chunk, 0, 11, The Mongols, {sentence -> 0, chunk -> 0, score -> 0.29171973}, []}] |\n", + "|[{chunk, 0, 6, Osmium, {sentence -> 0, chunk -> 0, score -> 0.40618128}, []}] |\n", + "|[{chunk, 0, 13, South America, {sentence -> 0, chunk -> 0, score -> 0.39206758}, []}] |\n", + "|[{chunk, 0, 13, Pablo Picasso, {sentence -> 0, chunk -> 0, score -> 0.4128621}, []}] |\n", + "+-----------------------------------------------------------------------------------------------------------+\n", + "\n" + ] + } + ], + "source": [ + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from pyspark.ml import Pipeline, PipelineModel\n", + "\n", + "document_assembler = MultiDocumentAssembler() \\\n", + " .setInputCols([\"question\", \"choices\"]) \\\n", + " .setOutputCols([\"document_question\", \"document_choices\"])\n", + "\n", + "distilbert_for_multiple_choice = DistilBertForMultipleChoice() \\\n", + " .load(f\"{MODEL_NAME}_spark_nlp_openvino\") \\\n", + " .setInputCols([\"document_question\", \"document_choices\"])\\\n", + " .setOutputCol(\"answer\") \\\n", + " .setBatchSize(4)\n", + "\n", + "pipeline = Pipeline(stages=[document_assembler, distilbert_for_multiple_choice])\n", + "pipeline_model = pipeline.fit(testing_df)\n", + "\n", + "pipeline_df = pipeline_model.transform(testing_df)\n", + "\n", + "pipeline_df.select(\"answer\").show(truncate=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lpxiq1igoj6c" + }, + "source": [ + "That's it! You can now go wild and use hundreds of `DistilBertForMultipleChoice` models from HuggingFace 🤗 in Spark NLP 🚀\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "04caa1515af146588adb1552ece36dcd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3787315381de4f4a8f383dffe4dd1318", + "max": 574, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_92efd8cc36b24ba7827eb8ee8d5ddd76", + "value": 574 + } + }, + "0a41311d87454a49b28792762d1ad390": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_586875b1181840658f82ed3791819ea9", + "placeholder": "​", + "style": "IPY_MODEL_e5927cae79b2469a8c72dc30591f895c", + "value": " 574/574 [00:00<00:00, 44.8kB/s]" + } + }, + "103d60fc85664ea6852c6cf65a838f08": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_abde155ca5604caea5996b7eb88e80a3", + "placeholder": "​", + "style": "IPY_MODEL_f2e9969ee9364f8b96c0ccbe17b04dd3", + "value": " 268M/268M [00:06<00:00, 42.4MB/s]" + } + }, + "11e9347df720416ab68f7254699e4bfd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "146d64045c554e66afa4dbd4797971e8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "153a98ebce99415bbaa8f3bcf0e003a3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1685190d1c0b40dc9a04bf961ed1fe9a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8faf63a59311431e8a4f8149b3d09be3", + "IPY_MODEL_04caa1515af146588adb1552ece36dcd", + "IPY_MODEL_0a41311d87454a49b28792762d1ad390" + ], + "layout": "IPY_MODEL_df80408d1d9f4995bce81c9657a16902" + } + }, + "16da8069c945418f8bb156a7f6d25748": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4f3818960d504431a7eb18cd7488e2ba", + "placeholder": "​", + "style": "IPY_MODEL_77b7b75457ef4c38ac2634eff41a29d4", + "value": " 232k/232k [00:00<00:00, 2.76MB/s]" + } + }, + "1ce12659ed8748239d6755f6da042a98": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1da886fbe0a341059120356697806776": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8774027c38a1482895e6c21e6d0e662b", + "max": 231508, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c2a9db5e3ef54f7fb704f599e65af401", + "value": 231508 + } + }, + "203467747a404ad9b34c642b1a05b711": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "26c383992f4742ca92e8132e9aee73b9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2cba33d72df64451856a650f45f1f765": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "35c6f2532f6f44a6a90ce6c3a508791c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3787315381de4f4a8f383dffe4dd1318": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3935021ef68448f394a93f6a6215724b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "3a4eaa1f68fd4e9cbd36d29815918c6a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4e57df313f4a43fab68b281d93893952", + "IPY_MODEL_a089f60df2b9404290a4634deaf0654a", + "IPY_MODEL_b763f3ad27f044139de5d4eb656153a1" + ], + "layout": "IPY_MODEL_ecf9f67c924e4ea69fd2a6b9fb9ef2a6" + } + }, + "3e93119eb321493281c6a8cbdf5bb226": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4143428097184e239be3aa15a709a1bd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "466e13e3b2864431962a8a0e29473063": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "48d8e523726b426ab553e158e1738d35": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "497ac3bcc25b4710994554005e19ca36": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_153a98ebce99415bbaa8f3bcf0e003a3", + "placeholder": "​", + "style": "IPY_MODEL_56b0be1d416b4e968a9b5380367dcd77", + "value": "model.safetensors: 100%" + } + }, + "4d1ce1e3122d4685895e531179f1a6f8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ebb172d5ac1c4e2381a5e02b3d009d71", + "max": 1224, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3935021ef68448f394a93f6a6215724b", + "value": 1224 + } + }, + "4e57df313f4a43fab68b281d93893952": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_203467747a404ad9b34c642b1a05b711", + "placeholder": "​", + "style": "IPY_MODEL_620132d19704448386b7aced66a92ab6", + "value": "special_tokens_map.json: 100%" + } + }, + "4f3818960d504431a7eb18cd7488e2ba": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "545fcedce22f463e9745cf2785c6260f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "56b0be1d416b4e968a9b5380367dcd77": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "586875b1181840658f82ed3791819ea9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5acb405e46c7407aa3d38e49d4087a0b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e27b262eb28f48a8b39c32a4f65bf74f", + "IPY_MODEL_1da886fbe0a341059120356697806776", + "IPY_MODEL_16da8069c945418f8bb156a7f6d25748" + ], + "layout": "IPY_MODEL_795ff4db8fae4852a055586f00364be4" + } + }, + "620132d19704448386b7aced66a92ab6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6758258d40b34089bb9a2ff3d3008c98": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "7061bf9eac4d4f4b8f7250fc8043d912": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7154097716564b4db750626f7efdffc7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_48d8e523726b426ab553e158e1738d35", + "max": 267829484, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_6758258d40b34089bb9a2ff3d3008c98", + "value": 267829484 + } + }, + "77b7b75457ef4c38ac2634eff41a29d4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "783ee37e722c49d69bc856466340d9eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "795ff4db8fae4852a055586f00364be4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7adfc44756d64dbd9552b3e4c7315424": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_83f34c62dcc64d6abb92b880877c3b31", + "placeholder": "​", + "style": "IPY_MODEL_466e13e3b2864431962a8a0e29473063", + "value": "tokenizer.json: 100%" + } + }, + "83f34c62dcc64d6abb92b880877c3b31": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8774027c38a1482895e6c21e6d0e662b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8faf63a59311431e8a4f8149b3d09be3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d09eb9036f7042979767eee8b8adafdd", + "placeholder": "​", + "style": "IPY_MODEL_e45f9f7671f84e5783ba53cdafbf0eee", + "value": "config.json: 100%" + } + }, + "9280dc685751491d96554819dce773f2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dc1dc8a06d7f4b949a50e388e84f690f", + "max": 711396, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3e93119eb321493281c6a8cbdf5bb226", + "value": 711396 + } + }, + "92efd8cc36b24ba7827eb8ee8d5ddd76": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "94034333eb9f4399ae56edfa1b9d4b9b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9a6dd0ec39954a46982e711ece20d8c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a089f60df2b9404290a4634deaf0654a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c1dde60b58ce4a449f0663cd8a6040b6", + "max": 125, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_146d64045c554e66afa4dbd4797971e8", + "value": 125 + } + }, + "abde155ca5604caea5996b7eb88e80a3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b763f3ad27f044139de5d4eb656153a1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_26c383992f4742ca92e8132e9aee73b9", + "placeholder": "​", + "style": "IPY_MODEL_35c6f2532f6f44a6a90ce6c3a508791c", + "value": " 125/125 [00:00<00:00, 8.24kB/s]" + } + }, + "bda09c689b7e451b9f205bcd89e2892c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d74e96dcba28405bb0b90863c372b9a3", + "placeholder": "​", + "style": "IPY_MODEL_4143428097184e239be3aa15a709a1bd", + "value": "tokenizer_config.json: 100%" + } + }, + "c1dde60b58ce4a449f0663cd8a6040b6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "c2a9db5e3ef54f7fb704f599e65af401": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c77c1345e62441ec9f0cccefd74680ae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_bda09c689b7e451b9f205bcd89e2892c", + "IPY_MODEL_4d1ce1e3122d4685895e531179f1a6f8", + "IPY_MODEL_ea563bffe9a9402797647d6e57726b07" + ], + "layout": "IPY_MODEL_11e9347df720416ab68f7254699e4bfd" + } + }, + "cbc7ecd928d841ee9c448fbb1643612c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7061bf9eac4d4f4b8f7250fc8043d912", + "placeholder": "​", + "style": "IPY_MODEL_783ee37e722c49d69bc856466340d9eb", + "value": " 711k/711k [00:00<00:00, 5.21MB/s]" + } + }, + "cd2acd5f7e7f4811906e9efdf91cf2f3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_497ac3bcc25b4710994554005e19ca36", + "IPY_MODEL_7154097716564b4db750626f7efdffc7", + "IPY_MODEL_103d60fc85664ea6852c6cf65a838f08" + ], + "layout": "IPY_MODEL_94034333eb9f4399ae56edfa1b9d4b9b" + } + }, + "d09eb9036f7042979767eee8b8adafdd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d74e96dcba28405bb0b90863c372b9a3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dc1dc8a06d7f4b949a50e388e84f690f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "df80408d1d9f4995bce81c9657a16902": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e27b262eb28f48a8b39c32a4f65bf74f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1ce12659ed8748239d6755f6da042a98", + "placeholder": "​", + "style": "IPY_MODEL_2cba33d72df64451856a650f45f1f765", + "value": "vocab.txt: 100%" + } + }, + "e45f9f7671f84e5783ba53cdafbf0eee": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e5927cae79b2469a8c72dc30591f895c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ea3db61523174ad08842b86d8645fe36": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_7adfc44756d64dbd9552b3e4c7315424", + "IPY_MODEL_9280dc685751491d96554819dce773f2", + "IPY_MODEL_cbc7ecd928d841ee9c448fbb1643612c" + ], + "layout": "IPY_MODEL_9a6dd0ec39954a46982e711ece20d8c8" + } + }, + "ea563bffe9a9402797647d6e57726b07": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f73b5c70ce424707951099c81a54ec0f", + "placeholder": "​", + "style": "IPY_MODEL_545fcedce22f463e9745cf2785c6260f", + "value": " 1.22k/1.22k [00:00<00:00, 107kB/s]" + } + }, + "ebb172d5ac1c4e2381a5e02b3d009d71": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ecf9f67c924e4ea69fd2a6b9fb9ef2a6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f2e9969ee9364f8b96c0ccbe17b04dd3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "f73b5c70ce424707951099c81a54ec0f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}