formatting

holgerroth committed Sep 27, 2024
1 parent 92a86c3 commit ae16e9d
Showing 4 changed files with 91 additions and 66 deletions.
11 changes: 3 additions & 8 deletions monai_vila2d/data_prepare/README.md
@@ -4,7 +4,7 @@
- PathVQA: Pathology-based VQA dataset with ~4,000 images and ~32,000 QA pairs, focusing on microscopic views of human tissue.
- RadVQA: Radiology VQA dataset containing ~7,000 images and ~25,000 QA pairs, covering various imaging modalities like X-rays and CT scans.
- SLAKE: Specialized medical VQA dataset with ~14,000 images and ~45,000 QA pairs, emphasizing anatomy, modality, and abnormality questions.
- MIMIC-VQA: Large-scale medical VQA dataset derived from MIMIC-CXR, featuring ~220,000 chest X-ray images and ~900,000 QA pairs.
- Medical-Diff-VQA: A derivative of the MIMIC-CXR dataset whose questions fall into seven categories: abnormality, location, type, level, view, presence, and difference. We currently exclude the difference category in our training preparation.

### Report Generation Datasets

@@ -14,16 +14,11 @@
experts
- CT
- CXR
- etc.

pathvqa
radvqa
slake
mimic_vqa
- MRI

| Dataset | QA Pairs | Images | Link |
|-----------|-----------|-----------|------|
| PathVQA | ~32,000 | ~4,000 | [PathVQA](https://github.com/UCSD-AI4H/PathVQA) |
| RadVQA | ~25,000 | ~7,000 | [RadVQA](https://github.com/abachaa/VQA-Med-2019) |
| SLAKE | ~45,000 | ~14,000 | [SLAKE](https://github.com/SLAKE-SLAKE/SLAKE) |
| MIMIC-VQA | ~900,000 | ~220,000 | [MIMIC-VQA](https://physionet.org/content/mimic-cxr-vqa/1.0.0/) |
| Medical-Diff-VQA | ~429,000 | 129,232 | [Medical-Diff-VQA](https://physionet.org/content/medical-diff-vqa/1.0.0) |
13 changes: 12 additions & 1 deletion monai_vila2d/data_prepare/experts/README.md
@@ -1 +1,12 @@
# VLM
# Expert data preparation

## 1. Prepare expert training data for VISTA3D

We can take existing CT datasets, run VISTA3D inference on them, and use the results to generate training conversations for M3.

```commandline
export PYTHONPATH=${PWD}/..
ROOT_DIR=../../data/experts/vista3d/inference_results
OUT_FILEPREFIX="../../data/experts/vista3d/llama_gen_expert_data_vista3d_what"
python expert_train_data_vista3d.py --root_dir ${ROOT_DIR} --out_fileprefix ${OUT_FILEPREFIX}
```
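
The script also exposes optional `--n_samples` (default: 100_000) and `--test_frac` (default: 0.5) flags via its argument parser; for a quick smoke test, a smaller sample count could be used, e.g.:

```commandline
python expert_train_data_vista3d.py --root_dir ${ROOT_DIR} --out_fileprefix ${OUT_FILEPREFIX} --n_samples 1000
```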
5 changes: 4 additions & 1 deletion monai_vila2d/data_prepare/experts/expert_train_data_cxr.py
@@ -13,7 +13,10 @@
import random

from data_utils import read_json, write_json
from expert_utils import add_expert_conversation, assert_image_placeholder, get_predictions, model_list
from expert_utils import (add_expert_conversation,
                          assert_image_placeholder,
                          get_predictions,
                          model_list)
from tqdm import tqdm

random.seed(0)
128 changes: 72 additions & 56 deletions monai_vila2d/data_prepare/experts/expert_train_data_vista3d.py
@@ -1,18 +1,18 @@
import argparse
import os
from data_utils import read_txt, read_json, write_json
import random
import uuid

from data_utils import read_json, read_txt, write_json
from tqdm import tqdm
import random

random.seed(0)

from expert_utils import model_list
assert isinstance(model_list, str)

root_dir = "./"
assert isinstance(model_list, str)

n = 100_000
test_frac = 0.5
# TODO: upsample tumors
# TODO: option to upsample tumors


def get_qa_pairs(qas):
@@ -21,12 +21,12 @@ def get_qa_pairs(qas):
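    # Locate the next "Q:"/"A:" pair in the raw LLM reply and slice out the question and answer text.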
    _start = qas.find("Q:")
    qas = qas[_start::]
    _end = qas.find("\n")
    question = qas[qas.find("Q:") + 3:_end]
    qas = qas[_end + 1::]
    question = qas[qas.find("Q:") + 3 : _end]
    qas = qas[_end + 1 : :]
    _end = qas.find("\n")
    answer = qas[qas.find("A:") + 3:_end]
    answer = qas[qas.find("A:") + 3 : _end]
    qas = qas[_end::]
    #print(question, answer)
    # print(question, answer)
    assert "Q:" not in answer
    assert "A:" not in question
    if len(question) == 0 or len(answer) == 0:
@@ -40,10 +40,10 @@ def get_qa_pairs(qas):
def get_questions(reply):
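    # Collect questions from a numbered list ("1. ...", "2. ...") in the LLM reply, one per line.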
    questions = []
    lines = reply.split("\n")
    for l in lines:
        if len(l) > 4:
            if "." == l[1]:  # "e.g., '1.'"
                question = l[3::]
    for line in lines:
        if len(line) > 4:
            if "." == line[1]:  # "e.g., '1.'"
                question = line[3::]
                questions.append(question)

    assert len(questions) > 0
@@ -59,31 +59,19 @@ def parse_qas(seg_qas_raw_file, lesions_q_raw_file, how_many_q_raw_file):
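    # Index the LLM-generated QA material by segmentation object type / tumor name for quick lookup.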
    seg_qas = {}
    for v in seg_qas_raw:
        qas = get_qa_pairs(v["reply"])
        seg_qas[v["object_type"]] = {
            "reply": v["reply"],
            "exp_model": v["exp_model"],
            "qas": qas
        }
        seg_qas[v["object_type"]] = {"reply": v["reply"], "exp_model": v["exp_model"], "qas": qas}
        print(f"Added {len(qas)} QA pairs for {v['object_type']}")

    lesions_qs = {}
    for v in lesions_q_raw:
        qs = get_questions(v["reply"])
        lesions_qs[v["tumor"]] = {
            "reply": v["reply"],
            "tumor": v["tumor"],
            "questions": qs
        }
        lesions_qs[v["tumor"]] = {"reply": v["reply"], "tumor": v["tumor"], "questions": qs}
        print(f"Added {len(qs)} lesion questions for {v['tumor']}")

    how_many_qs = {}
    for v in how_many_q_raw:
        qs = get_questions(v["reply"])
        how_many_qs[v["tumor"]] = {
            "reply": v["reply"],
            "tumor": v["tumor"],
            "questions": qs
        }
        how_many_qs[v["tumor"]] = {"reply": v["reply"], "tumor": v["tumor"], "questions": qs}
        print(f"Added {len(qs)} how many questions for {v['tumor']}")

    return seg_qas, lesions_qs, how_many_qs
@@ -92,7 +80,7 @@ def parse_qas(seg_qas_raw_file, lesions_q_raw_file, how_many_q_raw_file):
def read_meta_files(root, datasets):
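    # Gather each dataset's extracted_slices_meta.json and record which dataset every entry came from.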
    assert isinstance(root, str)
    assert isinstance(datasets, list)
    meta_files = [os.path.join(root, ds, "extracted_slices_meta.json")for ds in datasets]
    meta_files = [os.path.join(root, ds, "extracted_slices_meta.json") for ds in datasets]

    meta = []
    out_datasets = []
@@ -119,57 +107,59 @@ def find_image(images, image, dataset):
    raise ValueError(f"Did not find a matching image for {image} and dataset {dataset}")


def main():

def main(args):
    images = read_txt("./vista3d/ct2D_vista3d_images.txt")
    assert n < len(images)
    assert args.n_samples < len(images)

    incl_ds = ["Task03", "Task07", "Task09", "TotalSegmentatorV2"]
    #incl_ds = ["Task03", "Task09"]
    # incl_ds = ["Task03", "Task09"]
    meta, datasets = read_meta_files(root="../experts/vista3d", datasets=incl_ds)

    out_fileprefix = "../../data/experts/vista3d/llama_gen_expert_data_vista3d_what"

    # TODO: add tumor questions

    what_questions = read_txt("./llama_output/llama_gen_expert_what.txt")

    # convert raw to dict
    seg_qas, lesions_qs, how_many_qs = parse_qas("./llama_output/llama_gen_expert_qa_vista3d.json",
                                                 "./llama_output/llama_gen_expert_qa_lesions.json",
                                                 "./llama_output/llama_gen_expert_qa_how_many.json")
    seg_qas, lesions_qs, how_many_qs = parse_qas(
        "./llama_output/llama_gen_expert_qa_vista3d.json",
        "./llama_output/llama_gen_expert_qa_lesions.json",
        "./llama_output/llama_gen_expert_qa_how_many.json",
    )

    meta_ds = [(m, d) for m, d in zip(meta, datasets)]
    meta_ds = random.sample(meta_ds, k=n)
    meta_ds = random.sample(meta_ds, k=args.n_samples)

    all_conversations = []
    n_neg_tumors, n_pos_tumors, n_seg, n_what = 0, 0, 0, 0
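    # Build one conversation entry per sampled slice, counting each task type as we go.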
    for md in tqdm(meta_ds, desc="creating train data..."):
        m, ds = md[0], md[1]
        image = find_image(images, m["image"], ds).replace(root_dir, "").replace("\n", "")
        image = find_image(images, m["image"], ds).replace(args.root_dir, "").replace("\n", "")
        label = image.replace("_img.png", "_label.png")
        group_name = m["group_name"]

        id = str(uuid.uuid4())

        entry = {
            "image": image,
            "id": id
        }
        entry = {"image": image, "id": id}

        if "tumor" in group_name or "lesion" in group_name:
            # tumor task
            if group_name in lesions_qs:
                les_qs = lesions_qs[group_name]
            else:
                les_qs = lesions_qs[group_name+"s"]
                les_qs = lesions_qs[group_name + "s"]

            question = random.choice(les_qs["questions"])

            conv = list()
            conv.append({"from": "human", "value": model_list + f" <image>This is a CT image.\n" + question})
            conv.append({"from": "gpt", "value": f"This looks like a CT image. Let me trigger <VISTA3D({group_name})>. "})
            conv.append({"from": "human", "value": f"The results are <segmentation>. The colors in this image describe {m['label_colors']}. Use this result to respond to this prompt:\n{question}."})
            conv.append(
                {
                    "from": "human",
                    "value": f"The results are <segmentation>. The colors in this image describe {m['label_colors']}. "
                    f"Use this result to respond to this prompt:\n{question}.",
                }
            )
            if len(m["num_tumors"]) > 0:
                n_pos_tumors += 1
                answer = "yes"
@@ -217,20 +207,36 @@ def main():
            question = random.choice(what_questions)
            conv = list()
            conv.append({"from": "human", "value": model_list + f" <image>This is a CT image.\n" + question})
            conv.append({"from": "gpt", "value": f"This looks like a CT image. Let me trigger <VISTA3D({group_name})>. "})
            conv.append({"from": "human", "value": f"The results are <segmentation>. The colors in this image describe {m['label_colors']}. Use this result to respond to this prompt:\n{question}."})
            conv.append({"from": "gpt", "value": f"This a CT image. It contains several anatomical structures such as identified by VISTA3D: {m['label_colors']}."})
            conv.append(
                {"from": "gpt", "value": f"This looks like a CT image. Let me trigger <VISTA3D({group_name})>. "}
            )
            conv.append(
                {
                    "from": "human",
                    "value": f"The results are <segmentation>. "
                    f"The colors in this image describe {m['label_colors']}. "
                    f"Use this result to respond to this prompt:\n{question}.",
                }
            )
            conv.append(
                {
                    "from": "gpt",
                    "value": f"This is a CT image. "
                    f"It contains several anatomical structures, as identified by VISTA3D: "
                    f"{m['label_colors']}.",
                }
            )

        entry["conversations"] = conv

        all_conversations.append(entry)

    print(f"Converted {len(all_conversations)} conversations")

    out_train_file = out_fileprefix + "_train.json"
    out_test_file = out_fileprefix + "_test.json"
    out_train_file = args.out_fileprefix + "_train.json"
    out_test_file = args.out_fileprefix + "_test.json"

    split_idx = int(test_frac*len(all_conversations))
    split_idx = int(args.test_frac * len(all_conversations))

    random.shuffle(all_conversations)
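    # The first split_idx shuffled conversations become the test split; the rest are used for training.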
    test_conversations = all_conversations[0:split_idx]
@@ -239,8 +245,18 @@ def main():
    write_json(train_conversations, out_train_file)
    write_json(test_conversations, out_test_file)

    print(f"Saved neg tumors: {n_neg_tumors}, pos tumors: {n_pos_tumors}, seg: {n_seg}, what: {n_what}, total: {len(all_conversations)}")
    print(
        f"Saved neg tumors: {n_neg_tumors}, pos tumors: {n_pos_tumors}, "
        f"seg: {n_seg}, what: {n_what}, total: {len(all_conversations)}"
    )


if __name__ == "__main__":
    main()
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, required=True)
    parser.add_argument("--out_fileprefix", type=str, required=True)
    parser.add_argument("--n_samples", type=int, default=100_000)
    parser.add_argument("--test_frac", type=float, default=0.5)
    args = parser.parse_args()

    main(args)
