Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import static org.apache.commons.lang3.BooleanUtils.isFalse;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -24,26 +26,35 @@
import org.hibernate.query.criteria.HibernateCriteriaBuilder;
import org.hibernate.relational.SchemaManager;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationTarget.Kind;
import org.jboss.jandex.AnnotationTransformation;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.CompositeIndex;
import org.jboss.jandex.DotName;
import org.jboss.jandex.FieldInfo;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.ParameterizedType;
import org.jboss.jandex.Type;
import org.objectweb.asm.ClassVisitor;

import io.agroal.api.AgroalDataSource;
import io.quarkus.agroal.spi.JdbcDataSourceBuildItem;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
import io.quarkus.arc.deployment.BeanDefiningAnnotationBuildItem;
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem.ExtendedBeanConfigurator;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem;
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.arc.processor.Transformation;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand All @@ -52,7 +63,10 @@
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.BytecodeTransformerBuildItem;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.gizmo.ClassTransformer;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.hibernate.orm.PersistenceUnit;
import io.quarkus.hibernate.orm.runtime.HibernateOrmRecorder;
import io.quarkus.hibernate.orm.runtime.HibernateOrmRuntimeConfig;
Expand Down Expand Up @@ -276,6 +290,62 @@ void registerBeans(HibernateOrmConfig hibernateOrmConfig,
unremovableBeans.produce(UnremovableBeanBuildItem.beanTypes(jpaModel.getPotentialCdiBeanClassNames()));
}

@BuildStep
void transformBeans(JpaModelBuildItem jpaModel, JpaModelIndexBuildItem indexBuildItem,
BeanDiscoveryFinishedBuildItem beans,
BuildProducer<BytecodeTransformerBuildItem> producer) {
if (!HibernateOrmProcessor.hasEntities(jpaModel)) {
return;
}

// the idea here is to remove the 'private' modifier from all methods that are annotated with JPA Listener methods
// and don't belong to entities
CompositeIndex index = indexBuildItem.getIndex();
for (DotName dotName : jpaModel.getPotentialCdiBeanClassNames()) {
if (jpaModel.getManagedClassNames().contains(dotName.toString())) {
continue;
}
ClassInfo classInfo = index.getClassByName(dotName);
List<BeanInfo> matchingBeans = beans.getBeans().stream().filter(bi -> bi.getBeanClass().equals(dotName)).toList();
if (matchingBeans.size() == 1) {
ScopeInfo beanScope = matchingBeans.get(0).getScope();
for (DotName jpaListenerDotName : ClassNames.JPA_LISTENER_ANNOTATIONS) {
for (AnnotationInstance annotationInstance : classInfo.annotations(jpaListenerDotName)) {
AnnotationTarget target = annotationInstance.target();
if (target.kind() != AnnotationTarget.Kind.METHOD) {
continue;
}
MethodInfo method = target.asMethod();
if (Modifier.isPrivate(method.flags())) {
if (beanScope.getDotName().equals(BuiltinScope.SINGLETON.getName())) {
// we can safely transform in this case
producer.produce(new BytecodeTransformerBuildItem(method.declaringClass().name().toString(),
new BiFunction<>() {
@Override
public ClassVisitor apply(String cls, ClassVisitor clsVisitor) {
var classTransformer = new ClassTransformer(cls);
classTransformer.modifyMethod(MethodDescriptor.of(method))
.removeModifiers(Modifier.PRIVATE);
return classTransformer.applyTo(clsVisitor);
}
}));
} else {
// we can't transform because the client proxy does not know about the transformation and
// will therefore simply copy the private method which will then likely fail because it does
// not contain the injected fields
throw new IllegalArgumentException(
"Methods that are annotated with JPA Listener annotations should not be private. Offending method is '"
+ method.declaringClass().name() + "#" + method.name() + "'");
}
}
}
}
} else {
// we don't really know what to do here, just bail and CDI will figure it out
}
}
}

@BuildStep
void registerAnnotations(BuildProducer<AdditionalBeanBuildItem> additionalBeans,
BuildProducer<BeanDefiningAnnotationBuildItem> beanDefiningAnnotations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,24 +429,13 @@ private void enlistPotentialCdiBeanClasses(Collector collector, DotName dotName)

for (AnnotationInstance annotation : jpaAnnotations) {
AnnotationTarget target = annotation.target();
ClassInfo beanType;
switch (target.kind()) {
case CLASS:
beanType = target.asClass();
break;
case FIELD:
beanType = target.asField().declaringClass();
break;
case METHOD:
beanType = target.asMethod().declaringClass();
break;
case METHOD_PARAMETER:
case TYPE:
case RECORD_COMPONENT:
default:
throw new IllegalArgumentException(
"Annotation " + dotName + " was not expected on a target of kind " + target.kind());
}
ClassInfo beanType = switch (target.kind()) {
case CLASS -> target.asClass();
case FIELD -> target.asField().declaringClass();
case METHOD -> target.asMethod().declaringClass();
default -> throw new IllegalArgumentException(
"Annotation " + dotName + " was not expected on a target of kind " + target.kind());
};
DotName beanTypeDotName = beanType.name();
collector.potentialCdiBeanTypes.add(beanTypeDotName);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package io.quarkus.hibernate.orm;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityListeners;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.PostPersist;
import jakarta.transaction.Transactional;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;

public class JpaListenerOnPrivateMethodOfApplicationScopedCdiBeanTest {

@RegisterExtension
static QuarkusUnitTest runner = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addAsResource("application.properties"))
.assertException(e -> {
assertThat(e).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("SomeEntityListener#postPersist");
});

@Test
@Transactional
public void test() {
fail("should never be called");
}

@Entity
@EntityListeners(SomeEntityListener.class)
public static class SomeEntity {
private long id;
private String name;

public SomeEntity() {
}

public SomeEntity(String name) {
this.name = name;
}

@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "myEntitySeq")
public long getId() {
return id;
}

public void setId(long id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

@Override
public String toString() {
return "SomeEntity:" + name;
}
}

@ApplicationScoped
public static class SomeEntityListener {

@PostPersist
private void postPersist(SomeEntity someEntity) {
fail("should not reach here");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package io.quarkus.hibernate.orm;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityListeners;
import jakarta.persistence.EntityManager;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.PostPersist;
import jakarta.persistence.PrePersist;
import jakarta.transaction.Transactional;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;

public class JpaListenerOnPrivateMethodOfSingletonCdiBeanTest {

@RegisterExtension
static QuarkusUnitTest runner = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addAsResource("application.properties"));

@Inject
EventStore eventStore;

@Inject
EntityManager em;

@Test
@Transactional
public void test() {
Arc.container().requestContext().activate();
try {
SomeEntity entity = new SomeEntity("test");
em.persist(entity);
em.flush();
} finally {
Arc.container().requestContext().terminate();
}

assertThat(eventStore.getEvents()).containsExactly("prePersist", "postPersist");
}

@Entity
@EntityListeners(SomeEntityListener.class)
public static class SomeEntity {
private long id;
private String name;

public SomeEntity() {
}

public SomeEntity(String name) {
this.name = name;
}

@Id
@GeneratedValue(strategy = GenerationType.SEQUENCE, generator = "myEntitySeq")
public long getId() {
return id;
}

public void setId(long id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

@Override
public String toString() {
return "SomeEntity:" + name;
}
}

@Singleton
public static class SomeEntityListener {

private final EventStore eventStore;

public SomeEntityListener(EventStore eventStore) {
this.eventStore = eventStore;
}

@PrePersist
void prePersist(SomeEntity someEntity) {
eventStore.addEvent("prePersist");
}

@PostPersist
private void postPersist(SomeEntity someEntity) {
eventStore.addEvent("postPersist");
}
}

@ApplicationScoped
public static class EventStore {
private final List<String> events = new CopyOnWriteArrayList<>();

public void addEvent(String event) {
events.add(event);
}

public List<String> getEvents() {
return events;
}
}
}
Loading