Add beta support for jsd #290
Conversation
```diff
-    def forward(self, p, q):
-        return LigerJSDFunction.apply(p, q)
+    def forward(self, log_q, log_p):
```
This is the correct order of input and target (student and teacher, respectively). Would it be too confusing?
Yeah, the name is a bit confusing; or we could add some descriptions here to clarify.
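One way to make the order explicit is a docstring on the module's forward, e.g. the sketch below (illustrative wording only; the import path is an assumption about where this PR places `LigerJSDFunction`):

```python
import torch.nn as nn

from liger_kernel.ops.jsd import LigerJSDFunction  # assumed module path


class LigerJSD(nn.Module):
    def forward(self, log_q, log_p):
        """Compute the generalized JSD loss.

        Args:
            log_q: log-probabilities of the student model (the input/prediction).
            log_p: log-probabilities of the teacher model (the target).
        """
        return LigerJSDFunction.apply(log_q, log_p)
```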
@qingquansong @yundai424 ready for review!
LGTM in general! In case you're interested, I think one good piece of future work is to make these KL/JSD losses similar to the fused CE loss: feed the teacher and student models' last projection layers to the kernel and fuse them with the loss. Here the teacher weights don't need gradients, while the student's do.
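A rough sketch of that idea (not part of this PR; the name `fused_linear_jsd` and its signature are hypothetical, and a real kernel would fuse the projections with the loss rather than materialize full logits as done here):

```python
import torch
import torch.nn.functional as F


def fused_linear_jsd(h_student, h_teacher, w_student, w_teacher, beta=0.5):
    # Student projection stays in the autograd graph.
    log_q = F.log_softmax(h_student @ w_student.T, dim=-1)

    # Teacher projection needs no gradient.
    with torch.no_grad():
        log_p = F.log_softmax(h_teacher @ w_teacher.T, dim=-1)

    p, q = log_p.exp(), log_q.exp()
    m = beta * p + (1 - beta) * q
    log_m = m.log()

    # Generalized JSD: beta * KL(P || M) + (1 - beta) * KL(Q || M).
    loss = beta * (p * (log_p - log_m)).sum(-1) + (1 - beta) * (q * (log_q - log_m)).sum(-1)
    return loss.mean()
```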
Awesome work! Waiting for the final nit review.
@qingquansong sure, I'm in.
Forgot to add jsd to the README and liger_kernel.transformer.
Summary
Resolves #278.
Details
Forward:

$$\mathrm{JSD}_\beta(P \,\|\, Q) = \beta\,\mathrm{KL}(P \,\|\, M) + (1-\beta)\,\mathrm{KL}(Q \,\|\, M),$$

where $X = \log Q$, $Y = \log P$, and $M = \beta P + (1-\beta)Q$.
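For reference, a naive PyTorch version of this forward pass could look like the sketch below (illustrative only, not the Triton kernel; the per-row reduction and the `beta` default are assumptions):

```python
import torch


def naive_jsd(X: torch.Tensor, Y: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    """Generalized JSD from student log-probs X = log Q and teacher log-probs Y = log P."""
    Q, P = X.exp(), Y.exp()
    M = beta * P + (1 - beta) * Q
    log_M = M.log()
    # beta * KL(P || M) + (1 - beta) * KL(Q || M), summed over the last (vocab) dim.
    return (beta * P * (Y - log_M) + (1 - beta) * Q * (X - log_M)).sum(-1)
```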
Gradients:
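A derivation sketch from the forward definition above (any batch-size normalization applied in the kernel is omitted): the teacher input $Y$ needs no gradient, and differentiating with respect to the student input $X = \log Q$ gives, elementwise,

$$\frac{\partial\,\mathrm{JSD}_\beta}{\partial X} = (1-\beta)\, Q \odot \bigl(\log Q - \log M\bigr) = (1-\beta)\, e^{X} \odot \bigl(X - \log M\bigr).$$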
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence