
Merge pull request #38 from ChenghaoMou/near-dedup-improve
Near dedup improve
ChenghaoMou authored May 11, 2023
2 parents c306a59 + 9866ca3 commit 966a101
Showing 3 changed files with 55 additions and 12 deletions.
27 changes: 22 additions & 5 deletions near_deduplication/README.md
@@ -19,12 +19,14 @@ And make sure you have git-lfs installed.
### Usage

```bash
# For details on the arguments, see the help message
python minhash_deduplication.py --help
# Quick example
python minhash_deduplication.py --dataset codeparrot/codeparrot-clean-valid \
python minhash_deduplication.py --dataset codeparrot/codeparrot-clean-valid \
--split train \
--column content \
--cache-dir .cache \
--verbose
--min-ngram-size 5
# For details on the arguments, see the help message
python minhash_deduplication.py --help
```
@@ -58,10 +60,25 @@ gcloud dataproc jobs submit pyspark --cluster ${CLUSTER_NAME} \
near_deduplication/minhash_deduplication_spark.py \
-- \
--table "huggingface-science-codeparrot.the_stack_java.java" \
--output "gs://chenghao-data/dataproc_output/deduplicated"
--output "gs://chenghao-data/dataproc_output/deduplicated" \
--min_ngram_size 5 \
--ngram_size 5 \
--threshold 0.7
```

With above settings, it took about 40 minutes to deduplicate the Java subset (42 million docs, 319GB), 15x faster than the following python implementation in a comparable single-machine environment.
With the above settings, it took about 40 minutes to deduplicate the Java subset (42 million docs, 319 GB), 15x faster than the following Python implementation in a comparable single-machine environment. In terms of scaling, it took about 5 hours to deduplicate the 1.3 TB JSON subset of The Stack with a 15-machine cluster.

Warning: BigQuery might change your list schema in the output! You can use the following code to restore the format (credit to [@RaymondLi0](https://github.com/RaymondLi0)):

```python
LIST_COLUMNS = ['max_stars_repo_licenses', 'max_issues_repo_licenses', 'max_forks_repo_licenses']

def fix_license_cols(example):
    for col in LIST_COLUMNS:
        example[col] = [x["item"] for x in example[col]["list"]]
    return example

...
ds = ds.map(fix_license_cols)
```
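
For reference, here is a minimal end-to-end sketch of applying this fix to the exported files with the `datasets` library; the parquet loader call and the output path are illustrative assumptions, not part of this diff:

```python
# Hypothetical usage: load the Spark job's output (path is a placeholder) and
# unwrap BigQuery's {"list": [{"item": ...}]} encoding on the license columns.
from datasets import load_dataset

LIST_COLUMNS = ["max_stars_repo_licenses", "max_issues_repo_licenses", "max_forks_repo_licenses"]

def fix_license_cols(example):
    for col in LIST_COLUMNS:
        example[col] = [x["item"] for x in example[col]["list"]]
    return example

ds = load_dataset("parquet", data_files="deduplicated/*.parquet", split="train")
ds = ds.map(fix_license_cols)
print(ds[0]["max_stars_repo_licenses"])  # back to a plain list of license strings
```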

#### Python Implementation Analysis

@@ -79,7 +96,7 @@ To understand the limitation of current deduplication implementation, it is impo

We report here some stats on the experiments we did along the way with an 80-core machine on GCP (M1):

For SantaCoder, our results can be replicated by the following commands:
For SantaCoder, our results can be replicated by the following commands using an [older version of the code](https://github.com/bigcode-project/bigcode-analysis/tree/fdd70bffb9cceab031f6d682edf83b5af49e8aaf):

```bash
python minhash_deduplication.py --dataset bigcode/the-stack-dedup-pjj --data-dir data/java --revision v1.1.a1 --cache-dir cache2 --ngram-size 5 --threshold 0.7 --min-token-length 10 --fast
15 changes: 11 additions & 4 deletions near_deduplication/minhash_deduplication.py
@@ -43,7 +43,7 @@
datasets.logging.set_verbosity_error()


def ngrams(sequence: List[str], n: int) -> Iterable:
def ngrams(sequence: List[str], n: int, min_ngram_size: int) -> Iterable:
"""
Directly taken from nltk package to avoid dependency.
@@ -53,14 +53,16 @@ def ngrams(sequence: List[str], n: int) -> Iterable:
The sequence of items to be n-grammed.
n : int
The order of the n-grams to be extracted.
min_ngram_size : int
The minimum number of items in the sequence to generate n-grams.
Returns
-------
Iterable
The n-grams generated from the sequence.
"""
if len(sequence) < n:
return sequence
if len(sequence) < min_ngram_size:
return []
iterables = tee(sequence, n)
for i, sub_iterable in enumerate(iterables):
for _ in range(i):
@@ -91,6 +93,7 @@ def embed_func(
ngram_size: int,
hashranges: List[Tuple[int, int]],
permutations: np.ndarray,
min_ngram_size: int = 5,
) -> Dict[str, Any]:
"""
Combined with some datasketch code to better parallelize computation.
@@ -109,14 +112,16 @@ def embed_func(
The ranges of hash values.
permutations : np.ndarray
The permutations for the minhash.
min_ngram_size : int
The minimum number of items in the sequence to generate n-grams.
Returns
-------
Dict[str, Any]
The hash values in each range and the index.
"""
hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH
tokens = {" ".join(t) for t in ngrams(NON_ALPHA.split(content), ngram_size)}
tokens = {" ".join(t) for t in ngrams(NON_ALPHA.split(content), ngram_size, min_ngram_size)}
hv = np.array([sha1_hash32(token.encode("utf-8")) for token in tokens], dtype=np.uint64) # noqa: E501
a, b = permutations
phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME, MAX_HASH) # noqa: E501
@@ -215,6 +220,7 @@ def run(
ngram_size: int = typer.Option(5, help="The ngram size to use for MinHash"),
num_perm: int = typer.Option(256, help="Number of permutations"),
threshold: float = typer.Option(0.7, help="Minhash threshold"),
min_ngram_size: int = typer.Option(5, help="Shorter documents will be removed"),
output: str = typer.Option(None, help="Store the deduplicated dataset"),
):
global uf
@@ -263,6 +269,7 @@
"hashranges": HASH_RANGES,
"ngram_size": ngram_size,
"permutations": PERMUTATIONS,
"min_ngram_size": min_ngram_size,
},
input_columns=[column],
remove_columns=ds.column_names,
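The main behavioral change in this file is the new guard in `ngrams`: a sequence shorter than `min_ngram_size` now yields no n-grams at all (the old code returned the sequence itself), which is what the new `--min-ngram-size` option ("Shorter documents will be removed") hooks into. A small standalone sketch of that guard, with an illustrative function name:

```python
# Standalone sketch mirroring the patched, tee-based ngrams() above.
from itertools import tee
from typing import Iterable, List


def ngrams_sketch(sequence: List[str], n: int, min_ngram_size: int) -> Iterable:
    # Documents with fewer than `min_ngram_size` tokens produce no n-grams at all.
    if len(sequence) < min_ngram_size:
        return []
    iterables = tee(sequence, n)
    for i, sub_iterable in enumerate(iterables):
        for _ in range(i):
            next(sub_iterable, None)
    return zip(*iterables)


print(list(ngrams_sketch(["a", "b", "c"], 3, 5)))            # [] -> too short, skipped
print(list(ngrams_sketch(["a", "b", "c", "d", "e"], 3, 5)))  # [('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e')]
```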
25 changes: 22 additions & 3 deletions near_deduplication/minhash_deduplication_spark.py
@@ -48,7 +48,7 @@ def small_star_reduce(group):
return [(n, minimum) for n in nodes if n != minimum]


def ngrams(sequence: List[str], n: int) -> Iterable:
def ngrams(sequence: List[str], n: int, min_ngram_size: int = 5) -> Iterable:
"""
Code taken from NLTK, without padding.
@@ -58,6 +58,8 @@ def ngrams(sequence: List[str], n: int) -> Iterable:
The sequence of items to be converted into n-grams.
n : int
The order of the n-grams to be extracted.
min_ngram_size : int
The minimum number of items in the sequence to generate n-grams.
Returns
-------
@@ -71,6 +73,9 @@ def ngrams(sequence: List[str], n: int) -> Iterable:
>>> list(ngrams(['a', 'b', 'c', 'd'], 3))
[('a', 'b', 'c'), ('b', 'c', 'd')]
"""
if len(sequence) < min_ngram_size:
return []

iterables = tee(sequence, n)
for i, sub_iterable in enumerate(iterables):
for _ in range(i):
@@ -110,6 +115,7 @@ def generate_hash_values(
ngram_size: int,
hashranges: List[Tuple[int, int]],
permutations: np.ndarray,
min_ngram_size: int,
) -> List[Tuple[int, bytes, int]]:
"""
Generate the MinHashLSH values for a given document.
@@ -128,14 +134,16 @@ def generate_hash_values(
The ranges of offsets for each hash value.
permutations : np.ndarray
The permutations for the hash values.
min_ngram_size : int
The minimum number of items in the sequence to generate n-grams.
Returns
-------
List[Tuple[int, bytes, int]]
The list of (band_idx, hash value, idx) for the document.
"""
hashvalues = np.ones(num_perm, dtype=np.uint64) * MAX_HASH
tokens = {" ".join(t) for t in ngrams(NON_ALPHA.split(content), ngram_size)}
tokens = {" ".join(t) for t in ngrams(NON_ALPHA.split(content), ngram_size, min_ngram_size)}
hv = np.array([sha1_hash32(token.encode("utf-8")) for token in tokens], dtype=np.uint64)
a, b = permutations
phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME, MAX_HASH)
@@ -238,6 +246,7 @@ def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]:
parser = argparse.ArgumentParser(description="Near-deduplicating BigQuery Table with PySpark")
parser.add_argument("--table", type=str, required=True, help="BigQuery table to deduplicate")
parser.add_argument("--threshold", type=float, default=0.7, help="Similarity threshold")
parser.add_argument("--min_ngram_size", type=int, default=5, help="Shorter docs will be removed")
parser.add_argument("--ngram_size", type=int, default=5, help="N-gram size")
parser.add_argument("--num_perm", type=int, default=256, help="Number of permutations")
parser.add_argument("--b", type=int, default=None, help="Number of bands")
@@ -284,6 +293,7 @@ def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]:
ngram_size=args.ngram_size,
hashranges=HASH_RANGES,
permutations=PERMUTATIONS,
min_ngram_size=args.min_ngram_size,
)
)
.groupBy(lambda x: (x[0], x[1]))
@@ -303,10 +313,19 @@ def generate_edges(nodes: List[int]) -> List[Tuple[int, int]]:
results = a.collect()
if len(results) == 0:
log.info("No components found.")
df.write.option(
"maxRecordsPerFile", 300_000
).option(
"intermediateFormat", "orc"
).parquet(args.output, mode="overwrite")
sys.exit(0)

components = spark.createDataFrame(results, schema=["__id__", "component"]).sort(["component", "__id__"])
components.show()
df = df.join(components, on="__id__", how="left")
df = df.filter(F.col("component").isNull()).drop("__id__", "component").cache()
df.write.json(args.output, mode="overwrite")
df.write.option(
"maxRecordsPerFile", 300_000
).option(
"intermediateFormat", "orc"
).parquet(args.output, mode="overwrite")
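
Both `embed_func` in the single-machine script and `generate_hash_values` here build the signature with the same permuted-hash line. Below is a toy sketch of that step with tiny sizes; the constants are the usual datasketch values and are assumed to match the scripts, while the seed, shapes, and stand-in hashes are illustrative only:

```python
# Toy illustration of the MinHash permutation step; not part of the diff.
import numpy as np

MERSENNE_PRIME = np.uint64((1 << 61) - 1)  # assumed, as in datasketch
MAX_HASH = np.uint64((1 << 32) - 1)        # assumed, as in datasketch
num_perm = 4                               # kept tiny here; the scripts default to 256

rng = np.random.RandomState(42)
a = rng.randint(1, MERSENNE_PRIME, size=num_perm, dtype=np.uint64)
b = rng.randint(0, MERSENNE_PRIME, size=num_perm, dtype=np.uint64)

# Stand-ins for sha1_hash32 values of a document's n-gram shingles.
hv = np.array([123456789, 987654321, 192837465], dtype=np.uint64)

# (hv * a + b) mod p, truncated to 32 bits: one permuted hash per (shingle, permutation) pair.
phv = np.bitwise_and(((hv * np.tile(a, (len(hv), 1)).T).T + b) % MERSENNE_PRIME, MAX_HASH)

# The signature keeps the minimum over shingles for each permutation,
# starting from a vector filled with MAX_HASH.
hashvalues = np.minimum(phv.min(axis=0), np.ones(num_perm, dtype=np.uint64) * MAX_HASH)
print(hashvalues)  # shape (num_perm,); later sliced into bands according to `hashranges`
```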
