15 | 15 | from __future__ import annotations |
16 | 16 |
17 | 17 | import os |
18 | | -from typing import Any, Dict, List, Tuple, cast |
| 18 | +from typing import Any, Dict, List, Tuple, Optional, cast |
19 | 19 |
20 | 20 | from .utils import split_into_batches, is_subsequence_in_list |
21 | 21 |
| 22 | +# Optional GPU (CuPy) support |
| 23 | +_gpu_available = False |
| 24 | +try: # pragma: no cover - optional dependency path |
| 25 | + import cupy as _cp_mod # type: ignore[import-not-found] |
| 26 | + |
| 27 | + cp = cast(Any, _cp_mod) |
| 28 | + |
| 29 | + try: |
| 30 | + _gpu_available = cp.cuda.runtime.getDeviceCount() > 0 # type: ignore[attr-defined] |
| 31 | + except Exception: |
| 32 | + _gpu_available = False |
| 33 | +except Exception: # pragma: no cover - optional dependency path |
| 34 | + cp = None # type: ignore[assignment] |
| 35 | + _gpu_available = False |
| 36 | + |
22 | 37 | # Simple per-process cache for encoded transactions keyed by the list object's id |
23 | 38 | _ENCODED_CACHE: Dict[int, Tuple[List[List[int]], Dict[int, str], Dict[str, int], int]] = {} |
24 | 39 |
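For context, a minimal sketch of how an `id()`-keyed cache like `_ENCODED_CACHE` is typically consulted. The real `_get_encoded_transactions` helper is not shown in this diff; the name `_get_encoded_transactions_sketch`, the use of the fourth tuple element as a `len(transactions)` staleness check, and the encoding scheme are all assumptions for illustration.

```python
# Hypothetical sketch only; the repo's _get_encoded_transactions may differ.
from typing import Dict, List, Tuple

_ENCODED_CACHE: Dict[int, Tuple[List[List[int]], Dict[int, str], Dict[str, int], int]] = {}

def _get_encoded_transactions_sketch(
    transactions: List[Tuple[str, ...]],
) -> Tuple[List[List[int]], Dict[int, str], Dict[str, int]]:
    key = id(transactions)  # per-process identity of the list object
    cached = _ENCODED_CACHE.get(key)
    # Assumed: the 4th element stores len(transactions) as a cheap staleness check
    if cached is not None and cached[3] == len(transactions):
        return cached[0], cached[1], cached[2]
    vocab: Dict[str, int] = {}
    for tx in transactions:
        for item in tx:
            vocab.setdefault(item, len(vocab))
    enc_tx = [[vocab[item] for item in tx] for tx in transactions]
    inv_vocab = {i: s for s, i in vocab.items()}
    _ENCODED_CACHE[key] = (enc_tx, inv_vocab, vocab, len(transactions))
    return enc_tx, inv_vocab, vocab
```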
@@ -89,6 +104,40 @@ def _encode_candidates(candidates: List[Tuple[str, ...]], vocab: Dict[str, int]) |
89 | 104 | return [[vocab[s] for s in cand] for cand in candidates] |
90 | 105 |
91 | 106 |
| 107 | +def _support_counts_gpu_singletons( |
| 108 | + enc_tx: List[List[int]], |
| 109 | + cand_ids: List[int], |
| 110 | + min_support_abs: int, |
| 111 | + vocab_size: int, |
| 112 | +) -> List[Tuple[List[int], int]]: |
| 113 | + """GPU-accelerated support counts for singleton candidates using CuPy. |
| 114 | +
| 115 | + This computes, for each candidate item ID, the number of transactions that contain it. |
| 116 | + Items are deduplicated per transaction on the CPU so each transaction contributes at most once, |
| 117 | + then a single bincount runs on the GPU. |
| 118 | + """ |
| 119 | + # Ensure one contribution per transaction |
| 120 | + unique_rows: List[List[int]] = [list(set(row)) for row in enc_tx] |
| 121 | + if not unique_rows: |
| 122 | + return [] |
| 123 | + |
| 124 | + # Flatten to a 1D list of item ids, then move to GPU |
| 125 | + flat: List[int] = [item for row in unique_rows for item in row] |
| 126 | + if not flat: |
| 127 | + return [] |
| 128 | + |
| 129 | + cp_flat = cp.asarray(flat, dtype=cp.int32) # type: ignore[name-defined] |
| 130 | + counts = cp.bincount(cp_flat, minlength=vocab_size) # type: ignore[attr-defined] |
| 131 | + counts_host: Any = counts.get() # back to host as a NumPy array |
| 132 | + |
| 133 | + out: List[Tuple[List[int], int]] = [] |
| 134 | + for cid in cand_ids: |
| 135 | + freq = int(counts_host[cid]) |
| 136 | + if freq >= min_support_abs: |
| 137 | + out.append(([cid], freq)) |
| 138 | + return out |
| 139 | + |
| 140 | + |
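A small CPU-only sketch of why the helper above deduplicates items per transaction: presence counting yields the number of transactions containing an item, not its total occurrences. Toy data; `collections.Counter` stands in for `cp.bincount`.

```python
from collections import Counter

enc_tx = [[0, 1, 1, 2], [1, 2], [0, 0]]  # toy encoded transactions

raw = Counter(i for row in enc_tx for i in row)          # total occurrences
presence = Counter(i for row in enc_tx for i in set(row))  # one count per transaction

print(raw[1])       # 3 -> item 1 occurs three times overall
print(presence[1])  # 2 -> but appears in only two transactions
```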
92 | 141 | def support_counts_python( |
93 | 142 | transactions: List[Tuple[str, ...]], |
94 | 143 | candidates: List[Tuple[str, ...]], |
@@ -118,30 +167,91 @@ def support_counts( |
118 | 167 | candidates: List[Tuple[str, ...]], |
119 | 168 | min_support_abs: int, |
120 | 169 | batch_size: int = 100, |
| 170 | + backend: Optional[str] = None, |
121 | 171 | ) -> Dict[Tuple[str, ...], int]: |
122 | 172 | """Choose the best available backend for support counting. |
123 | 173 |
124 | | - Backend selection is controlled by the env var GSPPY_BACKEND: |
| 174 | + Backend selection is controlled by the `backend` argument when provided, |
| 175 | + otherwise by the env var GSPPY_BACKEND: |
125 | 176 | - "rust": require Rust extension (raise if missing) |
| 177 | + - "gpu": use the GPU path where available (currently only singleton candidates |
| 178 | + are GPU-accelerated); fall back to the CPU path for the rest |
126 | 179 | - "python": force pure-Python fallback |
127 | 180 | - otherwise: try Rust first and fall back to Python |
128 | 181 | """ |
129 | | - backend = _env_backend() |
| 182 | + backend_sel = (backend or _env_backend()).lower() |
130 | 183 |
131 | | - if backend == "python": |
| 184 | + if backend_sel == "gpu": |
| 185 | + if not _gpu_available: |
| 186 | + raise RuntimeError("backend='gpu' was requested but no CuPy-visible GPU is available") |
| 187 | + # Encode once |
| 188 | + enc_tx, inv_vocab, vocab = _get_encoded_transactions(transactions) |
| 189 | + enc_cands = _encode_candidates(candidates, vocab) |
| 190 | + |
| 191 | + # Partition candidates into singletons and non-singletons |
| 192 | + singletons: List[Tuple[int, Tuple[str, ...]]] = [] |
| 193 | + others: List[Tuple[List[int], Tuple[str, ...]]] = [] |
| 194 | + # Pair original and encoded candidates; lengths should match |
| 195 | + assert len(candidates) == len(enc_cands), "Encoded candidates length mismatch" |
| 196 | + for orig, enc in zip(candidates, enc_cands): # noqa: B905 - lengths checked above |
| 197 | + if len(enc) == 1: |
| 198 | + singletons.append((enc[0], orig)) |
| 199 | + else: |
| 200 | + others.append((enc, orig)) |
| 201 | + |
| 202 | + out: Dict[Tuple[str, ...], int] = {} |
| 203 | + |
| 204 | + # GPU path for singletons |
| 205 | + if singletons: |
| 206 | + vocab_size = max(vocab.values()) + 1 if vocab else 0 |
| 207 | + gpu_res = _support_counts_gpu_singletons( |
| 208 | + enc_tx=enc_tx, |
| 209 | + cand_ids=[cid for cid, _ in singletons], |
| 210 | + min_support_abs=min_support_abs, |
| 211 | + vocab_size=vocab_size, |
| 212 | + ) |
| 213 | + # Map back to original strings |
| 214 | + cand_by_id: Dict[int, Tuple[str, ...]] = {cid: orig for cid, orig in singletons} |
| 215 | + for enc_cand, freq in gpu_res: |
| 216 | + cid = enc_cand[0] |
| 217 | + out[cand_by_id[cid]] = int(freq) |
| 218 | + |
| 219 | + # Fallback for others (prefer rust when available) |
| 220 | + if others: |
| 221 | + if _rust_available: |
| 222 | + try: |
| 223 | + other_enc = [enc for enc, _ in others] |
| 224 | + res = cast( |
| 225 | + List[Tuple[List[int], int]], _compute_supports_rust(enc_tx, other_enc, int(min_support_abs)) |
| 226 | + ) |
| 227 | + for enc_cand, freq in res: |
| 228 | + out[tuple(inv_vocab[i] for i in enc_cand)] = int(freq) |
| 229 | + except Exception: |
| 230 | + # Rust call failed at runtime; fall back to the pure-Python implementation |
| 231 | + out.update( |
| 232 | + support_counts_python(transactions, [orig for _, orig in others], min_support_abs, batch_size) |
| 233 | + ) |
| 234 | + else: |
| 235 | + out.update( |
| 236 | + support_counts_python(transactions, [orig for _, orig in others], min_support_abs, batch_size) |
| 237 | + ) |
| 238 | + |
| 239 | + return out |
| 240 | + |
| 241 | + if backend_sel == "python": |
132 | 242 | return support_counts_python(transactions, candidates, min_support_abs, batch_size) |
133 | 243 |
134 | | - if backend == "rust": |
| 244 | + if backend_sel == "rust": |
135 | 245 | if not _rust_available: |
136 | 246 | raise RuntimeError("GSPPY_BACKEND=rust but Rust extension _gsppy_rust is not available") |
137 | 247 | # use rust |
138 | 248 | enc_tx, inv_vocab, vocab = _get_encoded_transactions(transactions) |
139 | 249 | enc_cands = _encode_candidates(candidates, vocab) |
140 | 250 | result = cast(List[Tuple[List[int], int]], _compute_supports_rust(enc_tx, enc_cands, int(min_support_abs))) |
141 | | - out: Dict[Tuple[str, ...], int] = {} |
| 251 | + out_rust: Dict[Tuple[str, ...], int] = {} |
142 | 252 | for enc_cand, freq in result: |
143 | | - out[tuple(inv_vocab[i] for i in enc_cand)] = int(freq) |
144 | | - return out |
| 253 | + out_rust[tuple(inv_vocab[i] for i in enc_cand)] = int(freq) |
| 254 | + return out_rust |
145 | 255 |
146 | 256 | # auto: try rust then fallback |
147 | 257 | if _rust_available: |
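A hedged usage sketch of the new `backend` parameter, based on the `support_counts` signature in this diff; the import path `gsppy.backend` and the sample data are assumptions.

```python
# Hypothetical usage; the module path below is an assumption.
from gsppy.backend import support_counts

transactions = [("bread", "milk"), ("bread", "beer"), ("milk", "beer", "bread")]
candidates = [("bread",), ("milk",), ("bread", "milk")]

try:
    # Explicitly request the GPU path; raises RuntimeError when no GPU is visible.
    counts = support_counts(transactions, candidates, min_support_abs=2, backend="gpu")
except RuntimeError:
    counts = support_counts(transactions, candidates, min_support_abs=2, backend="python")

print(counts)  # e.g. {('bread',): 3, ('milk',): 2}
```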