
Commit 848217c

Merge branch 'release/2.2' into develop
2 parents: cfe5975 + 2cb1d10

3 files changed: 86 additions, 27 deletions

src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkCodeExecutor.java

Lines changed: 26 additions & 20 deletions
@@ -170,31 +170,37 @@ public void onEvent(SparkListenerEvent event) {
    */
   public Object execute(SparkExecutionPluginContext context,
                         JavaRDD<StructuredRecord> javaRDD) throws InvocationTargetException, IllegalAccessException {
-    // RDD case
-    if (!isDataFrame) {
-      if (takeContext) {
-        //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd(), context)).toJavaRDD();
-      } else {
-        //noinspection unchecked
-        return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd())).toJavaRDD();
+    ClassLoader oldCL = Thread.currentThread().getContextClassLoader();
+    Thread.currentThread().setContextClassLoader(interpreter.getClassLoader());
+    try {
+      // RDD case
+      if (!isDataFrame) {
+        if (takeContext) {
+          //noinspection unchecked
+          return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd(), context)).toJavaRDD();
+        } else {
+          //noinspection unchecked
+          return ((RDD<StructuredRecord>) method.invoke(null, javaRDD.rdd())).toJavaRDD();
+        }
       }
-    }
 
-    // DataFrame case
-    Schema inputSchema = context.getInputSchema();
-    if (inputSchema == null) {
-      // Should already been checked in initialize. This is to safeguard in case the call sequence changed in future.
-      throw new IllegalArgumentException("Input schema must be provided for using DataFrame in Spark Compute");
-    }
+      // DataFrame case
+      Schema inputSchema = context.getInputSchema();
+      if (inputSchema == null) {
+        // Should already been checked in initialize. This is to safeguard in case the call sequence changed in future.
+        throw new IllegalArgumentException("Input schema must be provided for using DataFrame in Spark Compute");
+      }
 
-    SQLContext sqlContext = getSQLContext(context.getSparkContext().sc());
+      SQLContext sqlContext = getSQLContext(context.getSparkContext().sc());
 
-    StructType rowType = DataFrames.toDataType(inputSchema);
-    JavaRDD<Row> rowRDD = javaRDD.map(new RecordToRow(rowType));
+      StructType rowType = DataFrames.toDataType(inputSchema);
+      JavaRDD<Row> rowRDD = javaRDD.map(new RecordToRow(rowType));
 
-    Object dataFrame = createDataFrame(sqlContext, rowRDD, rowType);
-    return takeContext ? method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame);
+      Object dataFrame = createDataFrame(sqlContext, rowRDD, rowType);
+      return takeContext ? method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame);
+    } finally {
+      Thread.currentThread().setContextClassLoader(oldCL);
+    }
   }
 
   private String generateSourceClass(String className) {
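
The change above is the standard save-and-restore pattern for the thread context class loader: install the interpreter's class loader before invoking the user-compiled method, and restore the previous loader in a finally block even if the invocation throws. A minimal standalone sketch of that pattern follows; the ClassLoaderScope class and its runWithClassLoader helper are illustrative only and are not part of the plugin's API.

import java.util.concurrent.Callable;

public final class ClassLoaderScope {

  // Runs the given action with the supplied class loader installed as the
  // thread context class loader, restoring the previous loader afterwards.
  // Illustrative helper; not part of the dynamic-spark plugin.
  public static <T> T runWithClassLoader(ClassLoader cl, Callable<T> action) throws Exception {
    ClassLoader oldCL = Thread.currentThread().getContextClassLoader();
    Thread.currentThread().setContextClassLoader(cl);
    try {
      return action.call();
    } finally {
      Thread.currentThread().setContextClassLoader(oldCL);
    }
  }
}

With a helper of this shape, the body of execute() could be expressed as a single runWithClassLoader(interpreter.getClassLoader(), () -> ...) call instead of the explicit try/finally shown in the diff.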

src/main/java/io/cdap/plugin/spark/dynamic/ScalaSparkProgram.java

Lines changed: 6 additions & 7 deletions
@@ -28,8 +28,6 @@
 import io.cdap.cdap.api.spark.SparkMain;
 import io.cdap.cdap.api.spark.dynamic.CompilationFailureException;
 import io.cdap.cdap.api.spark.dynamic.SparkInterpreter;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import java.io.File;
 import java.io.IOException;
@@ -47,8 +45,6 @@
 @Description("Executes user-provided Spark program")
 public class ScalaSparkProgram implements JavaSparkMain {
 
-  private static final Logger LOG = LoggerFactory.getLogger(ScalaSparkProgram.class);
-
   private final Config config;
 
   public ScalaSparkProgram(Config config) throws CompilationFailureException, IOException {
@@ -144,15 +140,18 @@ private Callable<Void> getMethodCallable(ClassLoader classLoader, String mainCla
       arg = sec == null ? null : RuntimeArguments.toPosixArray(sec.getRuntimeArguments());
     }
 
-    return new Callable<Void>() {
-      @Override
-      public Void call() throws Exception {
+    return () -> {
+      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
+      Thread.currentThread().setContextClassLoader(classLoader);
+      try {
         Object instance = null;
         if (!Modifier.isStatic(method.getModifiers())) {
           instance = cls.newInstance();
         }
         method.invoke(instance, arg);
         return null;
+      } finally {
+        Thread.currentThread().setContextClassLoader(oldCl);
       }
     };
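
The Callable returned by getMethodCallable may execute on a different thread than the one that built it, so the caller cannot simply set the context class loader up front; it has to be installed inside the callable on whichever thread eventually runs it, which is exactly what the diff does. The sketch below illustrates that situation with an ExecutorService; the ContextClassLoaderDemo class, the wrap helper, and the pluginClassLoader stand-in are hypothetical names introduced only for this example.

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ContextClassLoaderDemo {

  // Builds a callable that installs the given class loader on whichever thread
  // runs it, mirroring the pattern used in getMethodCallable.
  static Callable<Void> wrap(ClassLoader classLoader, Runnable userCode) {
    return () -> {
      ClassLoader oldCl = Thread.currentThread().getContextClassLoader();
      Thread.currentThread().setContextClassLoader(classLoader);
      try {
        userCode.run();
        return null;
      } finally {
        Thread.currentThread().setContextClassLoader(oldCl);
      }
    };
  }

  public static void main(String[] args) throws Exception {
    // Stand-in for the interpreter's class loader in the real plugin.
    ClassLoader pluginClassLoader = ContextClassLoaderDemo.class.getClassLoader();
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      // The worker thread, not the submitting thread, sees pluginClassLoader
      // as its context class loader while the user code runs.
      executor.submit(wrap(pluginClassLoader,
          () -> System.out.println(Thread.currentThread().getContextClassLoader()))).get();
    } finally {
      executor.shutdown();
    }
  }
}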

src/test/java/io/cdap/plugin/spark/dynamic/ScalaSparkTest.java

Lines changed: 54 additions & 0 deletions
@@ -214,6 +214,60 @@ public void testScalaProgramDependency() throws Exception {
     workflowManager.waitForRun(ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
   }
 
+  @Test
+  public void testScalaSparkProgramClosure() throws Exception {
+    StringWriter codeWriter = new StringWriter();
+    try (PrintWriter printer = new PrintWriter(codeWriter, true)) {
+      printer.println("import io.cdap.cdap.api.spark._");
+      printer.println("import org.apache.spark._");
+      printer.println("import org.apache.spark.rdd.RDD");
+      printer.println("import org.slf4j._");
+
+      printer.println("class SparkProgram extends SparkMain {");
+      printer.println("  import SparkProgram._");
+
+      printer.println("  override def run(implicit sec: SparkExecutionContext): Unit = {");
+      printer.println("    LOG.info(\"Spark Program Started\")");
+
+      printer.println("    val sc = new SparkContext");
+      printer.println("    val points = sc.parallelize(Seq((\"a\", Array(1, 2)), (\"a\", Array(3, 4))))");
+
+      printer.println("    val sq = points.mapValues(t => Array(t.apply(0) * t.apply(0), t.apply(1) * t.apply(1)))");
+      printer.println("    LOG.info(\"squared = {}\", sq.collect)");
+
+      printer.println("    val squaredNested = points.mapValues(t => t.map(x => x * x))");
+      printer.println("    LOG.info(\"squaredNested = {}\", squaredNested.collect)");
+
+      printer.println("    LOG.info(\"Spark Program Completed\")");
+      printer.println("  }");
+      printer.println("}");
+
+      printer.println("object SparkProgram {");
+      printer.println("  val LOG = LoggerFactory.getLogger(getClass())");
+      printer.println("}");
+    }
+
+    // Pipeline configuration
+    ETLBatchConfig etlConfig = ETLBatchConfig.builder()
+      .addStage(new ETLStage("action", new ETLPlugin("ScalaSparkProgram", "sparkprogram", ImmutableMap.of(
+        "scalaCode", codeWriter.toString(),
+        "mainClass", "SparkProgram"
+      ))))
+      .build();
+
+    // Deploy the pipeline
+    ArtifactSummary artifactSummary = new ArtifactSummary(DATAPIPELINE_ARTIFACT_ID.getArtifact(),
+                                                          DATAPIPELINE_ARTIFACT_ID.getVersion());
+    AppRequest<ETLBatchConfig> appRequest = new AppRequest<>(artifactSummary, etlConfig);
+    ApplicationId appId = NamespaceId.DEFAULT.app("ScalaSparkProgramApp");
+    ApplicationManager appManager = deployApplication(appId, appRequest);
+
+    // Run the pipeline
+    WorkflowManager workflowManager = appManager.getWorkflowManager(SmartWorkflow.NAME);
+    workflowManager.start();
+    workflowManager.waitForRun(ProgramRunStatus.COMPLETED, 5, TimeUnit.MINUTES);
+  }
+
   @Test
   public void testScalaSparkCompute() throws Exception {
     Schema inputSchema = Schema.recordOf(
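
The new test exercises Scala closures (the mapValues lambdas in the generated program) compiled by the embedded interpreter, which is the case that depends on the context class loader being set correctly. The arithmetic the program performs on its sample data is trivial; the plain-Java rendering below, using hypothetical local names, only makes the expected "squaredNested" log output explicit and is not part of the test code.

import java.util.Arrays;

public class ClosureTestExpectation {
  public static void main(String[] args) {
    // Same sample values as the generated Scala program: ("a", [1, 2]) and ("a", [3, 4]).
    int[][] points = { {1, 2}, {3, 4} };

    // Element-wise squaring, equivalent to mapValues(t => t.map(x => x * x)) in the test program.
    for (int[] point : points) {
      int[] squared = Arrays.stream(point).map(x -> x * x).toArray();
      System.out.println(Arrays.toString(squared));
    }
    // Prints: [1, 4] then [9, 16]
  }
}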
