Skip to content

Commit b383d83

Browse files
authored
Merge pull request eugenp#2972 from denis-zhdanov/BAEL-1290-javac-plugin
BAEL-1290 Creating a Java Compiler Plugin
2 parents 996bee2 + 8e5e0e3 commit b383d83

10 files changed

Lines changed: 342 additions & 0 deletions

File tree

core-java/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,13 @@
216216
<artifactId>spring-web</artifactId>
217217
<version>4.3.4.RELEASE</version>
218218
</dependency>
219+
<dependency>
220+
<groupId>com.sun</groupId>
221+
<artifactId>tools</artifactId>
222+
<version>1.8.0</version>
223+
<scope>system</scope>
224+
<systemPath>${java.home}/../lib/tools.jar</systemPath>
225+
</dependency>
219226
</dependencies>
220227

221228
<build>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package com.baeldung.javac;
2+
3+
import java.lang.annotation.*;
4+
5+
@Documented
6+
@Retention(RetentionPolicy.CLASS)
7+
@Target({ElementType.PARAMETER})
8+
public @interface Positive {
9+
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package com.baeldung.javac;
2+
3+
import com.sun.source.tree.MethodTree;
4+
import com.sun.source.tree.VariableTree;
5+
import com.sun.source.util.*;
6+
import com.sun.tools.javac.api.BasicJavacTask;
7+
import com.sun.tools.javac.code.TypeTag;
8+
import com.sun.tools.javac.tree.JCTree;
9+
import com.sun.tools.javac.tree.TreeMaker;
10+
import com.sun.tools.javac.util.Context;
11+
import com.sun.tools.javac.util.Name;
12+
import com.sun.tools.javac.util.Names;
13+
14+
import javax.tools.JavaCompiler;
15+
import java.util.*;
16+
import java.util.stream.Collectors;
17+
18+
import static com.sun.tools.javac.util.List.nil;
19+
20+
/**
21+
* A {@link JavaCompiler javac} plugin which inserts {@code >= 0} checks into resulting {@code *.class} files
22+
* for numeric method parameters marked by {@link Positive}
23+
*/
24+
public class SampleJavacPlugin implements Plugin {
25+
26+
public static final String NAME = "MyPlugin";
27+
28+
private static Set<String> TARGET_TYPES = new HashSet<>(Arrays.asList(
29+
// Use only primitive types for simplicity
30+
byte.class.getName(), short.class.getName(), char.class.getName(), int.class.getName(),
31+
long.class.getName(), float.class.getName(), double.class.getName()
32+
));
33+
34+
@Override
35+
public String getName() {
36+
return NAME;
37+
}
38+
39+
@Override
40+
public void init(JavacTask task, String... args) {
41+
Context context = ((BasicJavacTask) task).getContext();
42+
task.addTaskListener(new TaskListener() {
43+
@Override
44+
public void started(TaskEvent e) {
45+
}
46+
47+
@Override
48+
public void finished(TaskEvent e) {
49+
if (e.getKind() != TaskEvent.Kind.PARSE) {
50+
return;
51+
}
52+
e.getCompilationUnit().accept(new TreeScanner<Void, Void>() {
53+
@Override
54+
public Void visitMethod(MethodTree method, Void v) {
55+
List<VariableTree> parametersToInstrument = method.getParameters()
56+
.stream()
57+
.filter(SampleJavacPlugin.this::shouldInstrument)
58+
.collect(Collectors.toList());
59+
if (!parametersToInstrument.isEmpty()) {
60+
// There is a possible case that more than one argument is marked by @Positive,
61+
// as the checks are added to the method's body beginning, we process parameters RTL
62+
// to ensure correct order.
63+
Collections.reverse(parametersToInstrument);
64+
parametersToInstrument.forEach(p -> addCheck(method, p, context));
65+
}
66+
// There is a possible case that there is a nested class declared in a method's body,
67+
// hence, we want to proceed with method body AST as well.
68+
return super.visitMethod(method, v);
69+
}
70+
}, null);
71+
}
72+
});
73+
}
74+
75+
private boolean shouldInstrument(VariableTree parameter) {
76+
return TARGET_TYPES.contains(parameter.getType().toString())
77+
&& parameter.getModifiers().getAnnotations()
78+
.stream()
79+
.anyMatch(a -> Positive.class.getSimpleName().equals(a.getAnnotationType().toString()));
80+
}
81+
82+
private void addCheck(MethodTree method, VariableTree parameter, Context context) {
83+
JCTree.JCIf check = createCheck(parameter, context);
84+
JCTree.JCBlock body = (JCTree.JCBlock) method.getBody();
85+
body.stats = body.stats.prepend(check);
86+
}
87+
88+
private static JCTree.JCIf createCheck(VariableTree parameter, Context context) {
89+
TreeMaker factory = TreeMaker.instance(context);
90+
Names symbolsTable = Names.instance(context);
91+
String parameterName = parameter.getName().toString();
92+
String errorMessagePrefix = String.format("Argument '%s' of type %s is marked by @%s but got '",
93+
parameterName, parameter.getType(), Positive.class.getSimpleName());
94+
String errorMessageSuffix = "' for it";
95+
Name parameterId = symbolsTable.fromString(parameterName);
96+
return factory.at(((JCTree) parameter).pos).If(
97+
factory.Parens(
98+
factory.Binary(
99+
JCTree.Tag.LE,
100+
factory.Ident(parameterId),
101+
factory.Literal(TypeTag.INT, 0))
102+
),
103+
factory.Block(0, com.sun.tools.javac.util.List.of(
104+
factory.Throw(
105+
factory.NewClass(
106+
null,
107+
nil(),
108+
factory.Ident(
109+
symbolsTable.fromString(IllegalArgumentException.class.getSimpleName())
110+
),
111+
com.sun.tools.javac.util.List.of(
112+
factory.Binary(JCTree.Tag.PLUS,
113+
factory.Binary(JCTree.Tag.PLUS,
114+
factory.Literal(TypeTag.CLASS, errorMessagePrefix),
115+
factory.Ident(parameterId)),
116+
factory.Literal(TypeTag.CLASS, errorMessageSuffix))),
117+
null
118+
)
119+
)
120+
)),
121+
null
122+
);
123+
}
124+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
com.baeldung.javac.SampleJavacPlugin
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package com.baeldung.javac;
2+
3+
import org.junit.Test;
4+
5+
import static org.junit.Assert.assertEquals;
6+
7+
public class SampleJavacPluginIntegrationTest {
8+
9+
private static final String CLASS_TEMPLATE =
10+
"package com.baeldung.javac;\n" +
11+
"\n" +
12+
"public class Test {\n" +
13+
" public static %1$s service(@Positive %1$s i) {\n" +
14+
" return i;\n" +
15+
" }\n" +
16+
"}\n" +
17+
"";
18+
19+
private TestCompiler compiler = new TestCompiler();
20+
private TestRunner runner = new TestRunner();
21+
22+
@Test(expected = IllegalArgumentException.class)
23+
public void givenInt_whenNegative_thenThrowsException() throws Throwable {
24+
compileAndRun(double.class,-1);
25+
}
26+
27+
@Test(expected = IllegalArgumentException.class)
28+
public void givenInt_whenZero_thenThrowsException() throws Throwable {
29+
compileAndRun(int.class,0);
30+
}
31+
32+
@Test
33+
public void givenInt_whenPositive_thenSuccess() throws Throwable {
34+
assertEquals(1, compileAndRun(int.class, 1));
35+
}
36+
37+
private Object compileAndRun(Class<?> argumentType, Object argument) throws Throwable {
38+
String qualifiedClassName = "com.baeldung.javac.Test";
39+
byte[] byteCode = compiler.compile(qualifiedClassName, String.format(CLASS_TEMPLATE, argumentType.getName()));
40+
return runner.run(byteCode, qualifiedClassName, "service", new Class[] {argumentType}, argument);
41+
}
42+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.baeldung.javac;
2+
3+
import javax.tools.SimpleJavaFileObject;
4+
import java.io.ByteArrayOutputStream;
5+
import java.io.IOException;
6+
import java.io.OutputStream;
7+
import java.net.URI;
8+
9+
/** Holds compiled byte code in a byte array */
10+
public class SimpleClassFile extends SimpleJavaFileObject {
11+
12+
private ByteArrayOutputStream out;
13+
14+
public SimpleClassFile(URI uri) {
15+
super(uri, Kind.CLASS);
16+
}
17+
18+
@Override
19+
public OutputStream openOutputStream() throws IOException {
20+
return out = new ByteArrayOutputStream();
21+
}
22+
23+
public byte[] getCompiledBinaries() {
24+
return out.toByteArray();
25+
}
26+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package com.baeldung.javac;
2+
3+
import javax.tools.*;
4+
import java.net.URI;
5+
import java.util.ArrayList;
6+
import java.util.List;
7+
8+
/** Adapts {@link SimpleClassFile} to the {@link JavaCompiler} */
9+
public class SimpleFileManager extends ForwardingJavaFileManager<StandardJavaFileManager> {
10+
11+
private final List<SimpleClassFile> compiled = new ArrayList<>();
12+
13+
public SimpleFileManager(StandardJavaFileManager delegate) {
14+
super(delegate);
15+
}
16+
17+
@Override
18+
public JavaFileObject getJavaFileForOutput(Location location,
19+
String className,
20+
JavaFileObject.Kind kind,
21+
FileObject sibling)
22+
{
23+
SimpleClassFile result = new SimpleClassFile(URI.create("string://" + className));
24+
compiled.add(result);
25+
return result;
26+
}
27+
28+
/**
29+
* @return compiled binaries processed by the current class
30+
*/
31+
public List<SimpleClassFile> getCompiled() {
32+
return compiled;
33+
}
34+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.baeldung.javac;
2+
3+
import javax.tools.SimpleJavaFileObject;
4+
import java.net.URI;
5+
6+
/** Exposes given test source to the compiler. */
7+
public class SimpleSourceFile extends SimpleJavaFileObject {
8+
9+
private final String content;
10+
11+
public SimpleSourceFile(String qualifiedClassName, String testSource) {
12+
super(URI.create(String.format("file://%s%s",
13+
qualifiedClassName.replaceAll("\\.", "/"),
14+
Kind.SOURCE.extension)),
15+
Kind.SOURCE);
16+
content = testSource;
17+
}
18+
19+
@Override
20+
public CharSequence getCharContent(boolean ignoreEncodingErrors) {
21+
return content;
22+
}
23+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.baeldung.javac;
2+
3+
import javax.tools.JavaCompiler;
4+
import javax.tools.ToolProvider;
5+
import java.io.StringWriter;
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
9+
import static java.util.Arrays.asList;
10+
import static java.util.Collections.singletonList;
11+
12+
public class TestCompiler {
13+
public byte[] compile(String qualifiedClassName, String testSource) {
14+
StringWriter output = new StringWriter();
15+
16+
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
17+
SimpleFileManager fileManager = new SimpleFileManager(compiler.getStandardFileManager(
18+
null,
19+
null,
20+
null
21+
));
22+
List<SimpleSourceFile> compilationUnits = singletonList(new SimpleSourceFile(qualifiedClassName, testSource));
23+
List<String> arguments = new ArrayList<>();
24+
arguments.addAll(asList("-classpath", System.getProperty("java.class.path"),
25+
"-Xplugin:" + SampleJavacPlugin.NAME));
26+
JavaCompiler.CompilationTask task = compiler.getTask(output,
27+
fileManager,
28+
null,
29+
arguments,
30+
null,
31+
compilationUnits);
32+
task.call();
33+
return fileManager.getCompiled().iterator().next().getCompiledBinaries();
34+
}
35+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.baeldung.javac;
2+
3+
import java.lang.reflect.InvocationTargetException;
4+
import java.lang.reflect.Method;
5+
6+
public class TestRunner {
7+
8+
public Object run(byte[] byteCode,
9+
String qualifiedClassName,
10+
String methodName,
11+
Class<?>[] argumentTypes,
12+
Object... args)
13+
throws Throwable
14+
{
15+
ClassLoader classLoader = new ClassLoader() {
16+
@Override
17+
protected Class<?> findClass(String name) throws ClassNotFoundException {
18+
return defineClass(name, byteCode, 0, byteCode.length);
19+
}
20+
};
21+
Class<?> clazz;
22+
try {
23+
clazz = classLoader.loadClass(qualifiedClassName);
24+
} catch (ClassNotFoundException e) {
25+
throw new RuntimeException("Can't load compiled test class", e);
26+
}
27+
28+
Method method;
29+
try {
30+
method = clazz.getMethod(methodName, argumentTypes);
31+
} catch (NoSuchMethodException e) {
32+
throw new RuntimeException("Can't find the 'main()' method in the compiled test class", e);
33+
}
34+
35+
try {
36+
return method.invoke(null, args);
37+
} catch (InvocationTargetException e) {
38+
throw e.getCause();
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)