One to one clustering #2578

Open · wants to merge 11 commits into master

Conversation

aymonwuolanne
Contributor

Type of PR

  • BUG
  • FEAT
  • MAINT
  • DOC

Is your Pull Request linked to an existing Issue or Pull Request?

Related to #2562 and #251.

Give a brief description for the solution you have provided

There is some discussion of the method in #2562. This PR implements an alternative to cluster_pairwise_predictions_at_threshold, called cluster_using_single_best_links (more exciting name suggestions welcome).

The goal of this clustering method is to produce clusters in which, for each dataset in the source_datasets list, at most one record from that dataset appears per cluster. To do this, at each iteration it only accepts a link if it is the single best link for both the left id and the right id, and if accepting that link will not create any duplicates.

To deal with ties (e.g. where A1 links to B1 and to B2 with the same match probability) this implementation uses row_number rather than rank, which arbitrarily (note: not randomly) picks one of the edges tied for first. A good extension would be to implement some other options for dealing with ties, but I think in most Splink applications there are very few ties, especially when term frequency adjustments are being used.
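
For illustration, here's a minimal pandas sketch of the row_number-vs-rank distinction (made-up data; the PR itself presumably does this with SQL window functions):

import pandas as pd

# Two edges from A1 tied at 0.9: with rank, both are joint-first, so neither
# is "the single best"; row_number deterministically picks exactly one.
edges = pd.DataFrame(
    {"id_l": ["A1", "A1"], "id_r": ["B1", "B2"], "match_probability": [0.9, 0.9]}
)

edges["rank"] = edges.groupby("id_l")["match_probability"].rank(
    method="min", ascending=False
)
edges = edges.sort_values(
    ["id_l", "match_probability", "id_r"], ascending=[True, False, True]
)
edges["row_number"] = edges.groupby("id_l").cumcount() + 1

print(edges)
# rank is 1.0 for both rows; row_number is 1 for A1-B1 and 2 for A1-B2,
# i.e. the tie is broken arbitrarily (by sort order) but deterministically.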

PR Checklist

  • Added documentation for changes
  • Added feature to example notebooks or tutorial (if appropriate)
  • Added tests (if appropriate)
  • Updated CHANGELOG.md (if appropriate)
  • Made changes based off the latest version of Splink
  • Run the linter
  • Run the spellchecker (if appropriate)

@RobinL
Member

RobinL commented Jan 15, 2025

Just to say sorry for the delay on this. I'm starting to look at this now!

I've got as far as reading the messages and looking at the tests. I'm mainly uploading the following to help me remember if I look back at this in future:

So in your example we have:
[image: the example's input links]

And the solution is:
[image: the resulting clusters]

I was first confused that in the solution we didn't have 6a -> 7b, but then I read your message:

I like the simplicity of only accepting the best links and not moving on to the next best ones if the best option is already "taken". I think this corresponds to the Unique Mapping Clustering in the paper I mentioned.

So, am I right in saying the logic goes something like:

The best link from 6a is to 5c at 0.8, hence we ignore 6a to 7b. But 6a to 5c is eliminated by the stronger 0.85 link from 3a to 5c, and hence we end up with 6a and 7b separate.

Or, as Andy on the team suggested:

First the picking of the best links, and then forcing the dataset uniqueness.

I like how you've named the function (cluster_using_single_best_links), which I think reflects this well.

@aymonwuolanne
Contributor Author

Thanks @RobinL, hope you had a well-deserved break!

And yes, that is the expected behaviour. The best link for 6a is to 5c, but it cannot link to that cluster because of the duplicate criteria, so it will remain unlinked. Andy's comment sounds exactly right.

In my opinion this approach made the most sense, but a very easy change to the code would let 6a link to its next best record after all the edges violating the duplicate criteria are removed (in other words, first forcing dataset uniqueness, then picking the best links).
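
To make the contrast concrete, here's a hypothetical toy sketch in plain Python (not the PR's actual SQL; the 6a-7b score is made up, and each node is treated as linkable only once, which is enough to show the difference on this example):

edges = [("3a", "5c", 0.85), ("6a", "5c", 0.80), ("6a", "7b", 0.75)]

def best_links_then_uniqueness(edges):
    # Current behaviour: keep only edges that are the single best link for
    # BOTH endpoints, then accept the survivors (the real implementation's
    # iterative duplicate checks are omitted in this toy).
    best = {}
    for l, r, p in edges:
        for node in (l, r):
            if node not in best or p > best[node][2]:
                best[node] = (l, r, p)
    return [e for e in set(best.values()) if all(best[n] == e for n in e[:2])]

def uniqueness_then_best_links(edges):
    # Alternative: greedily accept the highest-scoring edge whose endpoints
    # are both still free, so a node whose best edge loses falls through
    # to its next best option.
    taken, accepted = set(), []
    for l, r, p in sorted(edges, key=lambda e: -e[2]):
        if l not in taken and r not in taken:
            accepted.append((l, r, p))
            taken |= {l, r}
    return accepted

print(best_links_then_uniqueness(edges))  # [('3a', '5c', 0.85)] -> 6a unlinked
print(uniqueness_then_best_links(edges))  # also accepts ('6a', '7b', 0.75)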

@RobinL
Member

RobinL commented Jan 22, 2025

I've stepped through the code this morning and it looks good - particularly nice to see how you've kept it stylistically consistent with the other clustering code, and remembered to handle edge cases like absence of match_probability - thanks.

I think we'll be able to merge a version of this, but before we do I want to dig a bit further into the 6a -> 7b issue.

Having spoken with the team, I think we have a reservation around the 6a -> 7b link. It comes down to the initial surprise that the link isn't made, and then the difficulty of explaining convincingly at a 'high level' why, in a linkage context, we would drop it (i.e. what is the argument for concluding it's invalid?).

The best I can come up with is:

  1. We have prior knowledge of the 'one to one' condition
  2. From this knowledge, we conclude 6a to 5c is a false positive
  3. If 6a to 5c is a false positive, and 6a to 7b has a lower score, it must be a false positive too (despite it scoring above the threshold)

But I think we're not convinced by (3). We're not convinced that the information that 6a to 5c is a FP somehow 'correlates' with, or otherwise tells us something is wrong with, 6a to 7b. Are we saying that the 'problem' or 'edge case' that causes the model to incorrectly score 6a to 5c is also likely to apply to 6a to 7b, i.e. that there's some sort of shared error mechanism?

Do you think we're missing something here? Or am I thinking along the wrong lines entirely? If the code change is relatively straightforward, one option could be to allow the user to choose which version they want.

@RobinL
Member

RobinL commented Jan 22, 2025

I've also been experimenting with some additional examples. To this end I've built an interactive example builder. Drag between nodes to create links, hover and drag up and down on links to change the probability. Click nodes to change the source dataset (represented by colour). Click to create new nodes.

You can then paste the results into the script below to run the clustering and see the results (open vega_output.html after running).

import json

import duckdb
import pandas as pd

from splink import DuckDBAPI, Linker, SettingsCreator

# --------------------------------------------------------------------------------------
# PASTE DATA HERE
graph_data = {
    "nodes": [
        {"unique_id": "1", "source_dataset": "a"},
        {"unique_id": "2", "source_dataset": "b"},
        {"unique_id": "3", "source_dataset": "a"},
        {"unique_id": "4", "source_dataset": "b"},
        {"unique_id": "5", "source_dataset": "a"},
        {"unique_id": "6", "source_dataset": "b"},
        {"unique_id": "7", "source_dataset": "d"},
    ],
    "links": [
        {
            "unique_id_l": "1",
            "source_dataset_l": "a",
            "unique_id_r": "2",
            "source_dataset_r": "b",
            "match_probability": 0.92,
        },
        {
            "unique_id_l": "2",
            "source_dataset_l": "b",
            "unique_id_r": "3",
            "source_dataset_r": "a",
            "match_probability": 0.9,
        },
        {
            "unique_id_l": "3",
            "source_dataset_l": "a",
            "unique_id_r": "4",
            "source_dataset_r": "b",
            "match_probability": 0.99,
        },
        {
            "unique_id_l": "4",
            "source_dataset_l": "b",
            "unique_id_r": "5",
            "source_dataset_r": "a",
            "match_probability": 0.9,
        },
        {
            "unique_id_l": "5",
            "source_dataset_l": "a",
            "unique_id_r": "6",
            "source_dataset_r": "b",
            "match_probability": 0.9,
        },
        {
            "unique_id_l": "6",
            "source_dataset_l": "b",
            "unique_id_r": "1",
            "source_dataset_r": "a",
            "match_probability": 0.96,
        },
        {
            "unique_id_l": "4",
            "source_dataset_l": "b",
            "unique_id_r": "7",
            "source_dataset_r": "d",
            "match_probability": 0.9,
        },
    ],
}

nodes_df = pd.DataFrame(graph_data["nodes"])
links_df = pd.DataFrame(graph_data["links"])

# --------------------------------------------------------------------------------------

# Reuse the frames above rather than rebuilding them from the raw dicts
df = nodes_df
predictions = links_df


# Define settings
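# Comparisons and blocking rules are left empty: we don't score pairs here,
# we register pre-computed predictions directly below.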
settings = SettingsCreator(
    link_type="link_only",
    comparisons=[],
    blocking_rules_to_generate_predictions=[],
)

# Initialize linker
con = duckdb.connect(":memory:")
db_api = DuckDBAPI(con)
linker = Linker(df, settings, db_api=db_api)
# linker._debug_mode = True
# Register predictions
df_predict = linker.table_management.register_table_predict(predictions, overwrite=True)

df_clusters = linker.clustering.cluster_using_single_best_links(
    df_predict, source_datasets=["a", "b", "d"], threshold_match_probability=0.5
)
df_clusters.as_duckdbpyrelation().show(max_width=1000)

# After getting df_clusters, convert to pandas and prepare data for visualization
df_clusters_pd = df_clusters.as_pandas_dataframe()

# Create nodes with cluster information
nodes_with_clusters = df_clusters_pd.merge(nodes_df, on=["unique_id", "source_dataset"])
nodes_for_vega = [
    {
        "id": row["unique_id"],
        "name": f"Node {row['unique_id']}",
        "group": row["cluster_id"],
        "source_dataset": row["source_dataset"],
    }
    for _, row in nodes_with_clusters.iterrows()
]

# Create a mapping of IDs to indices
id_to_index = {node["id"]: i for i, node in enumerate(nodes_for_vega)}

# Create original probability-based links
original_links = []
for _, row in links_df.iterrows():
    source_id = row["unique_id_l"]
    target_id = row["unique_id_r"]
    original_links.append(
        {
            "source": source_id,
            "target": target_id,
            "value": row["match_probability"],
        }
    )

# Create cluster-based links (renamed from links_for_vega)
cluster_links = []
clusters = df_clusters_pd.groupby("cluster_id")["unique_id"].apply(list)
for cluster_nodes in clusters:
    if len(cluster_nodes) > 1:
        for i in range(len(cluster_nodes)):
            for j in range(i + 1, len(cluster_nodes)):
                cluster_links.append(
                    {
                        "source": cluster_nodes[i],
                        "target": cluster_nodes[j],
                        "value": 1.0,
                    }
                )


# Create base visualization spec as a function
def create_vega_spec(nodes, links, title):
    return {
        "$schema": "https://vega.github.io/schema/vega/v5.json",
        "description": title,
        "width": 700,
        "height": 500,
        "padding": 0,
        "autosize": "none",
        "signals": [
            {"name": "static", "value": False, "bind": {"input": "checkbox"}},
            {
                "name": "fix",
                "value": False,
                "on": [
                    {
                        "events": "symbol:mouseout[!event.buttons], window:mouseup",
                        "update": "false",
                    },
                    {"events": "symbol:mouseover", "update": "fix || true"},
                    {
                        "events": "[symbol:mousedown, window:mouseup] > window:mousemove!",
                        "update": "xy()",
                        "force": True,
                    },
                ],
            },
            {
                "name": "node",
                "value": None,
                "on": [
                    {
                        "events": "symbol:mouseover",
                        "update": "fix === true ? item() : node",
                    }
                ],
            },
            {
                "name": "restart",
                "value": False,
                "on": [{"events": {"signal": "fix"}, "update": "fix && fix.length"}],
            },
            {"name": "cx", "update": "width / 2"},
            {"name": "cy", "update": "height / 2"},
            {
                "name": "nodeRadius",
                "value": 15,
                "bind": {"input": "range", "min": 1, "max": 50, "step": 1},
            },
            {
                "name": "nodeCharge",
                "value": -30,
                "bind": {"input": "range", "min": -2000, "max": 500, "step": 1},
            },
            {
                "name": "linkDistance",
                "value": 100,
                "bind": {"input": "range", "min": 5, "max": 200, "step": 1},
            },
            {
                "name": "linkStrength",
                "value": 0.5,
                "bind": {"input": "range", "min": 0.0, "max": 2.0, "step": 0.01},
            },
        ],
        "data": [
            {"name": "node-data", "values": nodes},
            {"name": "link-data", "values": links},
        ],
        "scales": [
            {
                "name": "color",
                "type": "ordinal",
                "domain": {"data": "node-data", "field": "source_dataset"},
                "range": {"scheme": "category10"},
            },
            {
                "name": "linkStrength",
                "type": "linear",
                "domain": [0, 1],
                "range": [0.5, 3],
            },
        ],
        "marks": [
            {
                "name": "links",
                "type": "path",
                "from": {"data": "link-data"},
                "interactive": False,
                "encode": {
                    "update": {
                        "stroke": {"value": "#ccc"},
                        "strokeWidth": {"scale": "linkStrength", "field": "value"},
                    }
                },
                "transform": [
                    {
                        "type": "linkpath",
                        "require": {"signal": "force"},
                        "shape": "line",
                        "sourceX": "datum.source.x",
                        "sourceY": "datum.source.y",
                        "targetX": "datum.target.x",
                        "targetY": "datum.target.y",
                    }
                ],
            },
            {
                "name": "nodes",
                "type": "symbol",
                "zindex": 1,
                "from": {"data": "node-data"},
                "on": [
                    {
                        "trigger": "fix",
                        "modify": "node",
                        "values": "fix === true ? {fx: node.x, fy: node.y} : {fx: fix[0], fy: fix[1]}",
                    }
                ],
                "encode": {
                    "enter": {
                        "fill": {"scale": "color", "field": "source_dataset"},
                        "stroke": {"value": "white"},
                    },
                    "update": {
                        "size": {"signal": "2 * nodeRadius * nodeRadius"},
                        "cursor": {"value": "pointer"},
                    },
                },
                "transform": [
                    {
                        "type": "force",
                        "iterations": 300,
                        "restart": {"signal": "restart"},
                        "static": {"signal": "static"},
                        "signal": "force",
                        "forces": [
                            {
                                "force": "center",
                                "x": {"signal": "cx"},
                                "y": {"signal": "cy"},
                            },
                            {
                                "force": "collide",
                                "radius": {"signal": "nodeRadius"},
                                "strength": 1,
                            },
                            {
                                "force": "nbody",
                                "strength": {"signal": "nodeCharge"},
                            },
                            {
                                "force": "link",
                                "links": "link-data",
                                "id": "datum.id",
                                "distance": {"signal": "linkDistance"},
                                "strength": {"signal": "linkStrength"},
                            },
                        ],
                    }
                ],
            },
            {
                "type": "text",
                "from": {"data": "link-data"},
                "interactive": False,
                "encode": {
                    "enter": {
                        "align": {"value": "center"},
                        "baseline": {"value": "middle"},
                        "fontSize": {"value": 10},
                        "fill": {"value": "#666"},
                    },
                    "update": {
                        "x": {"signal": "(datum.source.x + datum.target.x) / 2"},
                        "y": {"signal": "(datum.source.y + datum.target.y) / 2"},
                        "text": {"signal": "format(datum.value, '.2f')"},
                    },
                },
                "transform": [
                    {
                        "type": "linkpath",
                        "require": {"signal": "force"},
                        "shape": "line",
                        "sourceX": "datum.source.x",
                        "sourceY": "datum.source.y",
                        "targetX": "datum.target.x",
                        "targetY": "datum.target.y",
                    }
                ],
            },
            {
                "type": "text",
                "from": {"data": "nodes"},
                "interactive": False,
                "zindex": 2,
                "encode": {
                    "enter": {
                        "align": {"value": "center"},
                        "baseline": {"value": "middle"},
                        "fontSize": {"value": 15},
                        "fontWeight": {"value": "bold"},
                        "text": {
                            "signal": "datum.datum.source_dataset + '-' + datum.datum.id"
                        },
                    },
                    "update": {
                        "x": {"field": "x"},
                        "y": {"field": "y"},
                    },
                },
            },
        ],
    }


# Write both visualizations to HTML
with open("vega_output.html", "w") as f:
    f.write(
        """
<!DOCTYPE html>
<html>
<head>
    <script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
    <script src="https://cdn.jsdelivr.net/npm/vega-lite@5"></script>
    <script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script>
    <style>
        .vis-container {
            margin: 20px;
            padding: 20px;
            border: 1px solid #ccc;
        }
        h2 {
            text-align: center;
        }
    </style>
</head>
<body>
    <div class="vis-container">
        <h2>Original Links with Probabilities</h2>
        <div id="view1"></div>
    </div>
    <div class="vis-container">
        <h2>Clustered Results</h2>
        <div id="view2"></div>
    </div>
    <script>
        const spec1 = """
        + json.dumps(create_vega_spec(nodes_for_vega, original_links, "Original Links"))
        + """;
        const spec2 = """
        + json.dumps(
            create_vega_spec(nodes_for_vega, cluster_links, "Clustered Results")
        )
        + """;
        vegaEmbed('#view1', spec1).catch(console.error);
        vegaEmbed('#view2', spec2).catch(console.error);
    </script>
</body>
</html>
"""
    )

So this input

[image: the input graph built in the example builder]

Turns into (note the colours are not consistent between the tools)
Before:
[image: original links with probabilities]

After:

[image: clustered results]

@RobinL
Member

RobinL commented Jan 22, 2025

I've shot myself in the foot a couple of times experimenting with ☝️ and failing to update the source_datasets=["a", "b", "d"] bit of this:

df_clusters = linker.clustering.cluster_using_single_best_links(
    df_predict, source_datasets=["a", "b", "d"], threshold_match_probability=0.5
)

Is there a reason we can't do that automatically?

@aymonwuolanne
Contributor Author

But I think we're not convinced by (3). We're not convinced that the information that 6a to 5c is a FP somehow 'correlates' or otherwise tells us something is wrong with 6a to 7b. Are we saying that we think the 'problem' or 'edge case' that causes the model to incorrectly score 6a to 5c also is likely to apply to 6a to 7b? i.e. there's some sort of shared error mechanism.

This is a good point. I've asked around, and most people arrive at the conclusion that 6a and 7b should be linked; I think I agree that's more sensible. I had in mind some scenarios where the assumption of no duplicates is not 100% guaranteed, in which case it's possible that 6a really belongs to the other cluster and 7b should be left by itself. But that's a stretch: if we're assuming there are no duplicates, then we might as well believe the assumption fully.

I'm happy to change the method around to that approach (first forcing dataset uniqueness, then ranking links). I'll leave it up to you whether the old behaviour should be kept as an option. In general I feel this is quite nuanced for users to understand, so perhaps it's simpler with no option.

Is there a reason we can't do that automatically?

Yes, actually! If you have three datasets a, b, c, and you want to assume that only a and b have no duplicates, then you can use source_datasets = ["a", "b"]. Maybe that needs to be clearer from the variable name?
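
To illustrate (a hypothetical variant of the call from earlier): omitting c tells the method that duplicates within c are acceptable, which isn't something that can be inferred from the predictions alone.

df_clusters = linker.clustering.cluster_using_single_best_links(
    df_predict,
    source_datasets=["a", "b"],  # "c" deliberately omitted: duplicates allowed there
    threshold_match_probability=0.5,
)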

@RobinL
Member

RobinL commented Jan 23, 2025

OK, sounds good. Yes please, that'd be great if you could change the method accordingly.

Also - agreed, it's simpler with no option, so if you're happy with the new approach, let's just do that.

Yes, actually! If you have three datasets a, b, c, and you want to assume that a and b have no duplicates, then you can use source_datasets = ["a", "b"]. Maybe that needs to be more clear from the variable name?

Aha - I understand now. Yeah, I think the arg name could be clearer, though I can't think of a good name off the top of my head.
