Skip to content

Commit 3f5882f

Browse files
authored
fix/refactor-scan-part-1 (#661)
1 parent 9d3acde commit 3f5882f

File tree

1 file changed

+196
-77
lines changed

1 file changed

+196
-77
lines changed

safety/scan/command.py

Lines changed: 196 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838

3939

4040
cli_apps_opts = {"rich_markup_mode": "rich", "cls": SafetyCLISubGroup}
41-
4241
scan_project_app = typer.Typer(**cli_apps_opts)
4342
scan_system_app = typer.Typer(**cli_apps_opts)
4443

@@ -258,9 +257,6 @@ def generate_cve_details(files: List[FileModel]) -> List[Dict[str, Any]]:
258257
return sort_cve_data(cve_data)
259258

260259

261-
262-
263-
264260
def add_cve_details_to_report(report_to_output: str, files: List[FileModel]) -> str:
265261
"""
266262
Add CVE details to the JSON report output.
@@ -294,6 +290,185 @@ def generate_updates_arguments() -> List:
294290
return fixes
295291

296292

293+
def validate_authentication(ctx: typer.Context) -> None:
294+
"""
295+
Validates that the user is authenticated.
296+
297+
Args:
298+
ctx (typer.Context): The Typer context object.
299+
300+
Raises:
301+
SafetyError: If the user is not authenticated.
302+
"""
303+
if not ctx.obj.metadata.authenticated:
304+
raise SafetyError("Authentication required. Please run 'safety auth login' to authenticate before using this command.")
305+
306+
307+
def generate_fixes_target(apply_updates: bool) -> List:
308+
"""
309+
Generates a list of update targets if `apply_updates` is enabled.
310+
311+
Args:
312+
apply_updates (bool): Whether to generate fixes target.
313+
314+
Returns:
315+
List: A list of update targets if enabled, otherwise an empty list.
316+
"""
317+
return generate_updates_arguments() if apply_updates else []
318+
319+
320+
def validate_save_as(ctx: typer.Context, save_as: Optional[Tuple[ScanExport, Path]]) -> None:
321+
"""
322+
Ensures the `save_as` parameters are valid.
323+
324+
Args:
325+
ctx (typer.Context): The Typer context object.
326+
save_as (Optional[Tuple[ScanExport, Path]]): The save-as parameters.
327+
"""
328+
if not all(save_as):
329+
ctx.params["save_as"] = None
330+
331+
332+
def initialize_file_finder(ctx: typer.Context, target: Path, console: Console, ecosystems: List[Ecosystem]) -> FileFinder:
333+
"""
334+
Initializes the FileFinder for scanning files in the target directory.
335+
336+
Args:
337+
ctx (typer.Context): The Typer context object.
338+
target (Path): The target directory to scan.
339+
console (Console): The console object for logging.
340+
ecosystems (List[Ecosystem]): The list of scannable ecosystems.
341+
342+
Returns:
343+
FileFinder: An initialized FileFinder object.
344+
"""
345+
to_include = {
346+
file_type: paths
347+
for file_type, paths in ctx.obj.config.scan.include_files.items()
348+
if file_type.ecosystem in ecosystems
349+
}
350+
351+
file_finder = FileFinder(
352+
target=target,
353+
ecosystems=ecosystems,
354+
max_level=ctx.obj.config.scan.max_depth,
355+
exclude=ctx.obj.config.scan.ignore,
356+
include_files=to_include,
357+
console=console,
358+
)
359+
360+
# Download necessary assets for each handler
361+
for handler in file_finder.handlers:
362+
if handler.ecosystem:
363+
wait_msg = "Fetching Safety's vulnerability database..."
364+
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
365+
handler.download_required_assets(ctx.obj.auth.client)
366+
367+
return file_finder
368+
369+
370+
def scan_project_directory(file_finder: FileFinder, console: Console) -> Tuple[Path, Dict]:
371+
"""
372+
Scans the project directory and identifies relevant files for analysis.
373+
374+
Args:
375+
file_finder (FileFinder): Initialized file finder object.
376+
console (Console): Console for logging output.
377+
378+
Returns:
379+
Tuple[Path, Dict]: The base path of the project and a dictionary of file paths grouped by type.
380+
"""
381+
wait_msg = "Scanning project directory"
382+
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
383+
path, file_paths = file_finder.search()
384+
print_detected_ecosystems_section(console, file_paths, include_safety_prjs=True)
385+
return path, file_paths
386+
387+
388+
def detect_dependency_vulnerabilities(console, dependency_vuln_detected):
389+
"""
390+
Prints a message indicating that dependency vulnerabilities were detected.
391+
"""
392+
if not dependency_vuln_detected:
393+
console.print()
394+
console.print("Dependency vulnerabilities detected:")
395+
return True
396+
return dependency_vuln_detected
397+
398+
399+
def print_file_info(console, path, target):
400+
"""
401+
Prints the file information for vulnerabilities.
402+
"""
403+
console.print()
404+
msg = f":pencil: [file_title]{path.relative_to(target)}:[/file_title]"
405+
console.print(msg, emoji=True)
406+
407+
408+
def sort_and_filter_vulnerabilities(vulnerabilities, key_func, reverse=True):
409+
"""
410+
Sorts and filters vulnerabilities.
411+
"""
412+
return sorted(
413+
[vuln for vuln in vulnerabilities if not vuln.ignored],
414+
key=key_func,
415+
reverse=reverse
416+
)
417+
418+
419+
def count_critical_vulnerabilities(vulnerabilities: List[Vulnerability]) -> int:
420+
"""
421+
Count the number of critical vulnerabilities in a list of vulnerabilities.
422+
423+
Args:
424+
vulnerabilities (List[Vulnerability]): List of vulnerabilities to evaluate.
425+
426+
Returns:
427+
int: The number of vulnerabilities with a critical severity level.
428+
"""
429+
return sum(
430+
1 for vuln in vulnerabilities
431+
if vuln.severity
432+
and vuln.severity.cvssv3
433+
and vuln.severity.cvssv3.get("base_severity", "none").lower() == VulnerabilitySeverityLabels.CRITICAL.value.lower()
434+
)
435+
436+
437+
def generate_vulnerability_message(spec_name: str, spec_raw: str, vulns_found: int, critical_vulns_count: int, vuln_word: str) -> str:
438+
"""
439+
Generate a formatted message for vulnerabilities in a specific dependency.
440+
441+
Args:
442+
spec_name (str): Name of the dependency.
443+
spec_raw (str): Raw specification string of the dependency.
444+
vulns_found (int): Number of vulnerabilities found.
445+
critical_vulns_count (int): Number of critical vulnerabilities found.
446+
vuln_word (str): Pluralized form of the word "vulnerability."
447+
448+
Returns:
449+
str: Formatted vulnerability message.
450+
"""
451+
msg = f"[dep_name]{spec_name}[/dep_name][specifier]{spec_raw.replace(spec_name, '')}[/specifier] [{vulns_found} {vuln_word} found"
452+
if vulns_found > 3 and critical_vulns_count > 0:
453+
msg += f", [brief_severity]including {critical_vulns_count} critical severity {pluralize('vulnerability', critical_vulns_count)}[/brief_severity]"
454+
return msg
455+
456+
457+
def render_vulnerabilities(vulns_to_report: List[Vulnerability], console: Console, detailed_output: bool) -> None:
458+
"""
459+
Render vulnerabilities to the console.
460+
461+
Args:
462+
vulns_to_report (List[Vulnerability]): List of vulnerabilities to render.
463+
console (Console): Console object for printing.
464+
detailed_output (bool): Whether to display detailed output.
465+
"""
466+
for vuln in vulns_to_report:
467+
render_to_console(
468+
vuln, console, rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=detailed_output
469+
)
470+
471+
297472
@scan_project_app.command(
298473
cls=SafetyCLICommand,
299474
help=CLI_SCAN_COMMAND_HELP,
@@ -371,71 +546,35 @@ def scan(ctx: typer.Context,
371546
Scans a project (defaulted to the current directory) for supply-chain security and configuration issues
372547
"""
373548

374-
if not ctx.obj.metadata.authenticated:
375-
raise SafetyError("Authentication required. Please run 'safety auth login' to authenticate before using this command.")
376-
377-
# Generate update arguments if apply updates option is enabled
378-
fixes_target = []
379-
if apply_updates:
380-
fixes_target = generate_updates_arguments()
381-
382-
# Ensure save_as params are correctly set
383-
if not all(save_as):
384-
ctx.params["save_as"] = None
549+
validate_authentication(ctx)
550+
fixes_target = generate_fixes_target(apply_updates)
551+
validate_save_as(ctx, save_as)
385552

386553
console = ctx.obj.console
387554
ecosystems = [Ecosystem(member.value) for member in list(ScannableEcosystems)]
388-
to_include = {file_type: paths for file_type, paths in ctx.obj.config.scan.include_files.items() if file_type.ecosystem in ecosystems}
389-
390-
# Initialize file finder
391-
file_finder = FileFinder(target=target, ecosystems=ecosystems,
392-
max_level=ctx.obj.config.scan.max_depth,
393-
exclude=ctx.obj.config.scan.ignore,
394-
include_files=to_include,
395-
console=console)
396555

397-
# Download necessary assets for each handler
398-
for handler in file_finder.handlers:
399-
if handler.ecosystem:
400-
wait_msg = "Fetching Safety's vulnerability database..."
401-
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
402-
handler.download_required_assets(ctx.obj.auth.client)
403-
404-
# Start scanning the project directory
405-
wait_msg = "Scanning project directory"
406-
407-
path = None
408-
file_paths = {}
409-
410-
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
411-
path, file_paths = file_finder.search()
412-
print_detected_ecosystems_section(console, file_paths,
413-
include_safety_prjs=True)
556+
file_finder = initialize_file_finder(ctx, target, console, ecosystems)
557+
path, file_paths = scan_project_directory(file_finder, console)
414558

415559
target_ecosystems = ", ".join([member.value for member in ecosystems])
416560
wait_msg = f"Analyzing {target_ecosystems} files and environments for security findings"
417561

418562
files: List[FileModel] = []
419-
563+
to_fix_files = []
564+
ignored_vulns_data = iter([])
420565
config = ctx.obj.config
421-
422566
count = 0
423-
424567
affected_count = 0
425-
dependency_vuln_detected = False
426-
427-
ignored_vulns_data = iter([])
428-
429568
exit_code = 0
430569
fixes_count = 0
431570
total_resolved_vulns = 0
432-
to_fix_files = []
433571
fix_file_types = [fix_target[0] if isinstance(fix_target[0], str) else fix_target[0].value for fix_target in fixes_target]
572+
dependency_vuln_detected = False
434573
requirements_txt_found = False
435574
display_apply_fix_suggestion = False
436575

437576
# Process each file for dependencies and vulnerabilities
438-
with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status:
577+
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
439578
for path, analyzed_file in process_files(paths=file_paths,
440579
config=config, use_server_matching=use_server_matching, obj=ctx.obj, target=target):
441580
count += len(analyzed_file.dependency_results.dependencies)
@@ -450,49 +589,30 @@ def scan(ctx: typer.Context,
450589
def sort_vulns_by_score(vuln: Vulnerability) -> int:
451590
if vuln.severity and vuln.severity.cvssv3:
452591
return vuln.severity.cvssv3.get("base_score", 0)
453-
454592
return 0
455593

456594
to_fix_spec = []
457595
file_matched_for_fix = analyzed_file.file_type.value in fix_file_types
458596

459597
if any(affected_specifications):
460-
if not dependency_vuln_detected:
461-
console.print()
462-
console.print("Dependency vulnerabilities detected:")
463-
dependency_vuln_detected = True
598+
dependency_vuln_detected = detect_dependency_vulnerabilities(console, dependency_vuln_detected)
599+
print_file_info(console, path, target)
464600

465-
console.print()
466-
msg = f":pencil: [file_title]{path.relative_to(target)}:[/file_title]"
467-
console.print(msg, emoji=True)
468601
for spec in affected_specifications:
469602
if file_matched_for_fix:
470603
to_fix_spec.append(spec)
471604

472605
console.print()
473-
vulns_to_report = sorted(
474-
[vuln for vuln in spec.vulnerabilities if not vuln.ignored],
475-
key=sort_vulns_by_score,
476-
reverse=True)
477-
critical_vulns_count = sum(1 for vuln in vulns_to_report if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.get("base_severity", "none").lower() == VulnerabilitySeverityLabels.CRITICAL.value.lower())
478-
606+
vulns_to_report = sort_and_filter_vulnerabilities(spec.vulnerabilities, key_func=sort_vulns_by_score)
607+
critical_vulns_count = count_critical_vulnerabilities(vulns_to_report)
479608
vulns_found = len(vulns_to_report)
480609
vuln_word = pluralize("vulnerability", vulns_found)
481610

482-
msg = f"[dep_name]{spec.name}[/dep_name][specifier]{spec.raw.replace(spec.name, '')}[/specifier] [{vulns_found} {vuln_word} found"
483-
484-
if vulns_found > 3 and critical_vulns_count > 0:
485-
msg += f", [brief_severity]including {critical_vulns_count} critical severity {pluralize('vulnerability', critical_vulns_count)}[/brief_severity]"
486-
487-
console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True,
488-
overflow="crop")
611+
msg = generate_vulnerability_message(spec.name, spec.raw, vulns_found, critical_vulns_count, vuln_word)
612+
console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True, overflow="crop")
489613

490614
if detailed_output or vulns_found < 3:
491-
for vuln in vulns_to_report:
492-
render_to_console(vuln, console,
493-
rich_kwargs={"emoji": True,
494-
"overflow": "crop"},
495-
detailed_output=detailed_output)
615+
render_vulnerabilities(vulns_to_report, console, detailed_output)
496616

497617
lines = []
498618

@@ -591,7 +711,6 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int:
591711
**{k: v for k, v in ctx.params.items() if k not in {"detailed_output", "output", "save_as", "filter_keys"}}
592712
)
593713

594-
595714
project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}"
596715

597716
if apply_updates:
@@ -630,7 +749,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int:
630749

631750
if not no_output:
632751
console.print("-" * console.size.width)
633-
752+
634753
if output is ScanOutput.SCREEN:
635754
run_easter_egg(console, exit_code)
636755

0 commit comments

Comments
 (0)