Skip to content

Conversation

glenn-jocher
Copy link
Member

@glenn-jocher glenn-jocher commented Sep 19, 2022

May resolve threaded inference issue in #9425 (comment) by avoiding memory sharing on init.

Signed-off-by: Glenn Jocher [email protected]

🛠️ PR Summary

Made with ❤️ by Ultralytics Actions

🌟 Summary

Improved initialization of grid and anchor grid variables in YOLOv5 model.

📊 Key Changes

  • Changed the way grid variables are initialized from a list containing a single empty tensor to a list comprehension creating separate empty tensors for each detection layer.
  • Applied analogous changes to the anchor grid variable initialization.

🎯 Purpose & Impact

  • These changes prevent potential issues with multiple layers referencing the same tensor, enabling proper individual layer operations.
  • By ensuring each detection layer has its own unique grid and anchor grid, the update enhances the model's reliability and correctness, which might lead to subtle performance improvements for users of the YOLOv5 object detection model. 🚀

May resolve threaded inference issue in #9425 (comment) by avoiding memory sharing on init.


Signed-off-by: Glenn Jocher <[email protected]>
@glenn-jocher
Copy link
Member Author

Problem statement that this PR resolves:

import torch
x = [torch.tensor(1)] * 10
x
Out[8]: 
[tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1),
 tensor(1)]
x[3] *=2
x
Out[10]: 
[tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2),
 tensor(2)]

@glenn-jocher glenn-jocher merged commit 868c0e9 into master Sep 19, 2022
@glenn-jocher glenn-jocher deleted the update/detect_grid_init branch September 19, 2022 11:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant