Skip to content

Commit e46ba99

Browse files
Merge pull request #96 from fredsamhaak/master
Add a function of cell types annotation with local LLMs
2 parents e610cf9 + d1421ba commit e46ba99

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

omicverse/single/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
from ._mdic3 import pyMDIC3
2727
from ._cnmf import *
2828
from ._gptcelltype import gptcelltype
29-
from ._cytotrace2 import cytotrace2
29+
from ._cytotrace2 import cytotrace2
30+
from ._gptcelltype_local import gptcelltype_local
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
2+
3+
def gptcelltype_local(input, tissuename=None, speciename='human',
                      model_name='Qwen/Qwen2-7B-Instruct', topgenenumber=10):
    """
    Annotate cell types from cluster marker genes using a local LLM.

    The markers are sent to a Hugging Face text-generation pipeline in
    batches of at most 30 clusters, and the cell type names are parsed out
    of the generated text.

    Arguments:
        input: dict or pandas.DataFrame. A dictionary with clusters as keys
            and lists of marker genes as values,
            e.g. {'cluster1': ['gene1', 'gene2'], 'cluster2': ['gene3']},
            or a DataFrame with 'cluster', 'names' and 'logfoldchanges'
            columns (only rows with a positive log fold change are used).
        tissuename: str, tissue name. Omitted from the prompt when None.
        speciename: str, species name. Default: 'human'.
        model_name: str, the name or path of the local model to be used.
        topgenenumber: int, the number of top genes to consider for each cluster. Default: 10.

    Returns:
        dict mapping each cluster to the cell type parsed from the model
        output. Clusters given with an empty marker list are reported as
        'unknown'. Clusters for which no answer could be parsed are absent.

    Raises:
        ValueError: if `input` is neither a dict nor a pandas DataFrame.
    """
    import re
    import numpy as np
    import pandas as pd

    # Validate and normalise the input BEFORE loading the model, so that a
    # wrong input type fails immediately instead of after a multi-GB load.
    if isinstance(input, dict):
        markers = {
            k: 'unknown' if not v else ','.join(v[:topgenenumber])
            for k, v in input.items()
        }
    elif isinstance(input, pd.DataFrame):
        # Filter genes with positive log fold change and group by cluster, selecting top genes
        positive = input[input['logfoldchanges'] > 0]
        markers = positive.groupby('cluster')['names'].apply(
            lambda x: ','.join(x.iloc[:topgenenumber])
        ).to_dict()
    else:
        raise ValueError("Input must be either a dictionary of lists or a pandas DataFrame.")

    # Nothing to annotate: skip the expensive model load entirely.
    if not markers:
        return {}

    # Imported late so that input validation does not depend on transformers.
    from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

    # Load the model and tokenizer from Hugging Face
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map='cuda',
        torch_dtype='auto',
        trust_remote_code=True
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )

    # Avoid interpolating a literal "None" into the prompt when no tissue is given.
    tissue_phrase = f"{tissuename} cells" if tissuename is not None else "cells"
    message_template = (
        f"Identify cell types of {tissue_phrase} in {speciename} using the above markers separately for each row.\n"
        "Provide the cell type name, followed by the reason, which should be enclosed in square brackets.\n"
        "Some can be a mixture of multiple cell types. If so, seperate them with semicolon.\n\n"
        "Output format:\n"
        "cluster: cell type [marker(s)]\n"
        "Output example:\n"
        "0: T cells [CD3D, IL7R]\n"
        "1: Cytotoxic T cells [CCL5, NKG7, GZMA]; Natural Killer (NK) cells [GNLY, KLRD1]"
    )

    system_prompt = (
        "You are an experienced biologist with particular expertise in "
        "molecular biology, cell biology, and bioinformatics."
    )

    # `temperature` is intentionally not passed: with do_sample=False decoding
    # is greedy, so a temperature has no effect and only triggers a warning.
    generation_args = {
        "max_new_tokens": 5000,
        "return_full_text": False,
        "do_sample": False,
    }

    # Hoisted out of the loop: compiled once, reused for every batch.
    pattern = re.compile(r'\d+:\s+(.+?)\s+\[.*?\]')

    keys = list(markers)
    # Split the clusters into the fewest balanced batches of at most 30.
    # np.array_split covers every index, including the last one (the previous
    # digitize/linspace binning put the final cluster in an unvisited bin).
    cutnum = int(np.ceil(len(keys) / 30))
    batches = np.array_split(np.arange(len(keys)), cutnum)

    allres = {}
    for batch in batches:
        batch_keys = [keys[j] for j in batch]
        message = '\n'.join(
            f"{k}: {markers[k]}" for k in batch_keys
        ) + '\n\n' + message_template

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]

        generated = pipe(messages, **generation_args)
        answer = generated[0]['generated_text']
        print(answer)

        # NOTE: answers are mapped to clusters by position, and the pattern
        # only recognises lines whose cluster label ends in digits; if the
        # model skips or reorders rows the assignment can be shifted.
        res = pattern.findall(answer)
        for key, cell_type in zip(batch_keys, res):
            allres[key] = 'unknown' if markers[key] == 'unknown' else cell_type

    print('Note: It is always recommended to check the results returned by the LLM in case of AI hallucination, before going to downstream analysis.')
    return allres

0 commit comments

Comments
 (0)