ExecutionService.java
/*
* Licensed to The Apereo Foundation under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
*
* The Apereo Foundation licenses this file to you under the Educational
* Community License, Version 2.0 (the "License"); you may not use this file
* except in compliance with the License. You may obtain a copy of the License
* at:
*
* http://opensource.org/licenses/ecl2.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
*/
package org.opencastproject.graphql.execution;
import org.opencastproject.graphql.exception.GraphQLNotFoundException;
import org.opencastproject.graphql.execution.context.OpencastContextManager;
import org.opencastproject.graphql.schema.SchemaService;
import org.opencastproject.security.api.SecurityService;
import org.osgi.framework.BundleContext;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Modified;
import org.osgi.service.component.annotations.Reference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import graphql.ExecutionInput;
import graphql.ExecutionResult;
import graphql.GraphQL;
import graphql.analysis.MaxQueryComplexityInstrumentation;
import graphql.analysis.MaxQueryDepthInstrumentation;
import graphql.execution.AsyncExecutionStrategy;
import graphql.execution.AsyncSerialExecutionStrategy;
import graphql.execution.instrumentation.ChainedInstrumentation;
import graphql.execution.instrumentation.Instrumentation;
import graphql.execution.instrumentation.tracing.TracingInstrumentation;
import graphql.schema.GraphQLSchema;
/**
 * Executes GraphQL requests against the schema of the current organization.
 *
 * <p>A {@link GraphQL} engine is built lazily per organization id and cached. A cached engine is
 * rebuilt when the schema returned by {@link SchemaService} no longer equals the schema it was
 * built with, and the whole cache is discarded when the component configuration is modified so
 * that changed query limits take effect.
 */
@Component(
    service = { ExecutionService.class }
)
public class ExecutionService {

  private static final Logger logger = LoggerFactory.getLogger(ExecutionService.class);

  private final SecurityService securityService;

  private final SchemaService schemaService;

  private final BundleContext bundleContext;

  /** GraphQL engines keyed by organization id. */
  private final Map<String, GraphQL> organizationGraphQL = new ConcurrentHashMap<>();

  /** Component configuration holding the query execution limits. */
  public @interface ExecutionConfiguration {
    /** Maximum query complexity accepted; must be greater than 0. */
    int execution_max_query_complexity() default 1000;

    /** Maximum query depth accepted; must be greater than 0. */
    int execution_max_query_depth() default 25;
  }

  // volatile: replaced by the @Modified callback while request threads read it in buildGraphQL.
  private volatile ExecutionConfiguration config;

  /**
   * Creates the service and applies the initial configuration.
   *
   * @param schemaService provides the GraphQL schema per organization
   * @param securityService provides the current organization and user
   * @param bundleContext bundle context handed to the Opencast execution context
   * @param config initial execution limits
   * @throws IllegalArgumentException if a configured limit is not greater than 0
   */
  @Activate
  public ExecutionService(
      @Reference SchemaService schemaService,
      @Reference SecurityService securityService,
      BundleContext bundleContext,
      ExecutionConfiguration config
  ) {
    this.schemaService = schemaService;
    this.securityService = securityService;
    this.bundleContext = bundleContext;
    updateConfiguration(config);
  }

  /**
   * Validates and applies a new configuration.
   *
   * <p>All cached GraphQL engines are discarded, because their depth/complexity instrumentation
   * was built from the previous limits. They are rebuilt on the next request per organization.
   *
   * @param config the new execution limits
   * @throws IllegalArgumentException if a configured limit is not greater than 0; the previous
   *     configuration stays in effect in that case
   */
  @Modified
  public void updateConfiguration(ExecutionConfiguration config) {
    if (config.execution_max_query_complexity() <= 0) {
      throw new IllegalArgumentException("execution_max_query_complexity must be greater than 0");
    }
    if (config.execution_max_query_depth() <= 0) {
      throw new IllegalArgumentException("execution_max_query_depth must be greater than 0");
    }
    this.config = config;
    // Without this, engines cached before the change would keep enforcing the old limits until
    // their organization's schema happened to change.
    organizationGraphQL.clear();
  }

  /**
   * Executes a query with the given operation name, variables and extensions.
   *
   * @param query the GraphQL query document
   * @param operationName the operation to run; may be {@code null}
   * @param variables query variables; {@code null} is treated as an empty map
   * @param extensions request extensions; {@code null} is treated as an empty map
   * @return the execution result
   */
  public ExecutionResult execute(
      String query,
      String operationName,
      Map<String, Object> variables,
      Map<String, Object> extensions) {
    Map<?, Object> context = new HashMap<>();
    return execute(ExecutionInput.newExecutionInput()
        .query(query)
        .operationName(operationName)
        .variables((variables != null) ? variables : Collections.emptyMap())
        .extensions((extensions != null) ? extensions : Collections.emptyMap())
        .graphQLContext(context)
        .build());
  }

  /**
   * Executes the given input against the current organization's schema.
   *
   * <p>An Opencast context carrying the current organization and user is initiated and stored in
   * the input's GraphQL context under {@code OpencastContextManager.CONTEXT}. It is always
   * cleared again before returning.
   *
   * @param executionInput the prepared execution input
   * @return the execution result, or an error result carrying a {@code GraphQLNotFoundException}
   *     when no schema exists for the current organization
   */
  public ExecutionResult execute(ExecutionInput executionInput) {
    try {
      var context = OpencastContextManager.initiateContext(bundleContext);
      var organization = securityService.getOrganization();
      context.setOrganization(organization);
      context.setUser(securityService.getUser());
      executionInput.getGraphQLContext().put(OpencastContextManager.CONTEXT, context);

      var organizationId = organization.getId();
      var graphQL = getGraphQL(organizationId);
      if (graphQL == null) {
        return ExecutionErrorResult.newExecutionResult()
            .addError(new GraphQLNotFoundException("No GraphQL schema found for organization `"
                + organizationId + "`"))
            .build();
      }
      return graphQL.execute(executionInput);
    } finally {
      OpencastContextManager.clearContext();
    }
  }

  /**
   * Returns the cached engine for an organization, rebuilding it when the schema changed.
   *
   * @param organizationId the organization id
   * @return the engine, or {@code null} if the schema service has no schema for the organization
   */
  private GraphQL getGraphQL(String organizationId) {
    GraphQLSchema schema = schemaService.get(organizationId);
    if (schema == null) {
      return null;
    }
    GraphQL cached = organizationGraphQL.get(organizationId);
    if (cached != null && schema.equals(cached.getGraphQLSchema())) {
      return cached;
    }
    GraphQL graphQL = buildGraphQL(organizationId, schema);
    organizationGraphQL.put(organizationId, graphQL);
    return graphQL;
  }

  /**
   * Builds a new engine for the schema, using the current depth/complexity limits.
   *
   * <p>Tracing instrumentation is added only if trace logging is enabled at build time.
   */
  private GraphQL buildGraphQL(String organizationId, GraphQLSchema schema) {
    // Read the volatile field once so both limits come from the same configuration snapshot.
    ExecutionConfiguration currentConfig = config;
    List<Instrumentation> chainedList = new ArrayList<>(
        List.of(new MaxQueryDepthInstrumentation(currentConfig.execution_max_query_depth()),
            new MaxQueryComplexityInstrumentation(currentConfig.execution_max_query_complexity())));
    if (logger.isTraceEnabled()) {
      logger.trace("Enabling tracing instrumentation for organization `{}`", organizationId);
      chainedList.add(new TracingInstrumentation());
    }
    var exceptionHandler = new OpencastDataFetcherExceptionHandler();
    return GraphQL.newGraphQL(schema)
        .queryExecutionStrategy(new AsyncExecutionStrategy(exceptionHandler))
        .mutationExecutionStrategy(new AsyncSerialExecutionStrategy(exceptionHandler))
        .preparsedDocumentProvider(new QueryCache())
        .executionIdProvider(new OrganizationExecutionIdProvider(organizationId))
        .defaultDataFetcherExceptionHandler(exceptionHandler)
        .instrumentation(new ChainedInstrumentation(chainedList))
        .build();
  }
}