Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/samples-rust-server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ jobs:
cargo build --bin ${package##*/} --features cli
target/debug/${package##*/} --help
fi
# Test the validate feature if it exists
if cargo read-manifest | grep -q '"validate"'; then
cargo build --features validate --all-targets
fi
cargo fmt
cargo test
cargo clippy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ public class RustServerCodegen extends AbstractRustCodegen implements CodegenCon
private static final String problemJsonMimeType = "application/problem+json";
private static final String problemXmlMimeType = "application/problem+xml";

// Track if we have models with conflicting names (Ok/Err) that conflict with serde_valid
private boolean hasConflictingModelNames = false;

public RustServerCodegen() {
super();

Expand Down Expand Up @@ -941,6 +944,11 @@ private void postProcessOperationWithModels(CodegenOperation op, List<ModelMap>
if (param.contentType != null && isMimetypeJson(param.contentType)) {
param.vendorExtensions.put("x-consumes-json", true);
}

// Add a vendor extension to flag if this can have validate() run on it.
if (!param.isUuid && !param.isPrimitiveType && !param.isEnum && (!param.isContainer || !languageSpecificPrimitives.contains(typeMapping.get(param.baseType)))) {
param.vendorExtensions.put("x-can-validate", true);
}
}

for (CodegenParameter param : op.formParams) {
Expand Down Expand Up @@ -1454,8 +1462,20 @@ public String toAllOfName(List<String> names, Schema composedSchema) {
public void postProcessModelProperty(CodegenModel model, CodegenProperty property) {
super.postProcessModelProperty(model, property);

// Check for reserved field names that conflict with serde_valid macro internals
if ("ok".equalsIgnoreCase(property.name) || "err".equalsIgnoreCase(property.name)) {
model.vendorExtensions.put("x-skip-serde-valid", true);
}

// Mark properties that reference complex types (models) for nested validation
// Only add nested validation for types that reference generated models (contain "models::")
if (property.dataType != null && property.dataType.contains("models::")) {
property.vendorExtensions.put("x-needs-nested-validation", true);
}

// TODO: We should avoid reverse engineering primitive type status from the data type
if (!languageSpecificPrimitives.contains(stripNullable(property.dataType))) {
String strippedType = stripNullable(property.dataType);
if (!languageSpecificPrimitives.contains(strippedType)) {
// If we use a more qualified model name, then only camelize the actual type, not the qualifier.
if (property.dataType.contains(":")) {
int position = property.dataType.lastIndexOf(":");
Expand Down Expand Up @@ -1528,7 +1548,32 @@ public void postProcessModelProperty(CodegenModel model, CodegenProperty propert

@Override
public ModelsMap postProcessModels(ModelsMap objs) {
return super.postProcessModelsEnum(objs);
ModelsMap result = super.postProcessModelsEnum(objs);

// Check for model names that conflict with serde_valid macro internals
// Once we find one, set a class-level flag that persists across all model batches
if (!hasConflictingModelNames) {
for (ModelMap modelMap : result.getModels()) {
CodegenModel model = modelMap.getModel();
if ("Ok".equalsIgnoreCase(model.classname) || "Err".equalsIgnoreCase(model.classname)) {
hasConflictingModelNames = true;
additionalProperties.put("hasConflictingModelNames", true);
break;
}
}
}

// If there are conflicting names (detected in any batch), skip serde_valid for ALL models
if (hasConflictingModelNames) {
for (ModelMap modelMap : result.getModels()) {
CodegenModel model = modelMap.getModel();
model.vendorExtensions.put("x-skip-serde-valid", true);
}
// Set the flag for this batch's template context
result.put("hasConflictingModelNames", true);
}

return result;
}

private void processParam(CodegenParameter param, CodegenOperation op) {
Expand Down Expand Up @@ -1613,6 +1658,11 @@ private void processParam(CodegenParameter param, CodegenOperation op) {
String exampleString = (example != null) ? "Some(" + example + ")" : "None";
param.vendorExtensions.put("x-example", exampleString);
}

// Add a vendor extension to flag if this can have validate() run on it.
if (!param.isUuid && !param.isPrimitiveType && !param.isEnum && (!param.isContainer || !languageSpecificPrimitives.contains(typeMapping.get(param.baseType)))) {
param.vendorExtensions.put("x-can-validate", true);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cli = [
conversion = ["frunk", "frunk_derives", "frunk_core", "frunk-enum-core", "frunk-enum-derive"]

mock = ["mockall"]
validate = [{{^apiUsesByteArray}}"regex",{{/apiUsesByteArray}} "serde_valid", "swagger/serdevalid"]

[target.'cfg(any(target_os = "macos", target_os = "windows", target_os = "ios"))'.dependencies]
native-tls = { version = "0.2", optional = true }
Expand Down Expand Up @@ -100,6 +101,8 @@ regex = "1.12"

serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_valid = { version = "0.16", optional = true }

validator = { version = "0.20", features = ["derive"] }

# Crates included if required by the API definition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ The generated library has a few optional features that can be activated through
* This defaults to disabled and creates extra derives on models to allow "transmogrification" between objects of structurally similar types.
* `cli`
* This defaults to disabled and is required for building the included CLI tool.
* `validate`
* This defaults to disabled and allows JSON Schema validation of received data using `MakeService::set_validation` or `Service::set_validation`.
* Note, enabling validation will have a performance penalty, especially if the API heavily uses regex based checks.

See https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section for how to use features in your `Cargo.toml`.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
#![allow(unused_qualifications)]
{{^hasConflictingModelNames}}
#[cfg(not(feature = "validate"))]
use validator::Validate;

use crate::models;
#[cfg(any(feature = "client", feature = "server"))]
use crate::header;
#[cfg(feature = "validate")]
use serde_valid::Validate;
{{/hasConflictingModelNames}}
{{#hasConflictingModelNames}}
use validator::Validate;

use crate::models;
#[cfg(any(feature = "client", feature = "server"))]
use crate::header;
{{/hasConflictingModelNames}}
{{! Don't "use" structs here - they can conflict with the names of models, and mean that the code won't compile }}
{{#models}}
{{#model}}
Expand All @@ -19,6 +30,7 @@ use crate::header;
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize, Hash)]
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
#[cfg_attr(feature = "conversion", derive(frunk_enum_derive::LabelledGenericEnum))]{{#xmlName}}
#[serde(rename = "{{{.}}}")]{{/xmlName}}
pub enum {{{classname}}} {
Expand Down Expand Up @@ -60,11 +72,14 @@ impl std::str::FromStr for {{{classname}}} {
{{^isEnum}}
{{#dataType}}
#[derive(Debug, Clone, PartialEq, {{#exts.x-partial-ord}}PartialOrd, {{/exts.x-partial-ord}}serde::Serialize, serde::Deserialize)]
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
{{#xmlName}}
#[serde(rename = "{{{.}}}")]
{{/xmlName}}
pub struct {{{classname}}}({{{dataType}}});
pub struct {{{classname}}}(
{{>validate}} {{{dataType}}}
);

impl std::convert::From<{{{dataType}}}> for {{{classname}}} {
fn from(x: {{{dataType}}}) -> Self {
Expand Down Expand Up @@ -176,6 +191,7 @@ where
{{/exts}}
{{! vec}}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
{{^hasConflictingModelNames}}{{^exts.x-skip-serde-valid}}#[cfg_attr(feature = "validate", derive(Validate))]{{/exts.x-skip-serde-valid}}{{/hasConflictingModelNames}}
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
pub struct {{{classname}}}(
{{#exts}}
Expand Down Expand Up @@ -272,7 +288,7 @@ impl std::str::FromStr for {{{classname}}} {
{{/arrayModelType}}
{{^arrayModelType}}
{{! general struct}}
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
#[derive(Debug, Clone, PartialEq, Validate, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "conversion", derive(frunk::LabelledGeneric))]
{{#xmlName}}
#[serde(rename = "{{{.}}}")]
Expand All @@ -288,7 +304,12 @@ pub struct {{{classname}}} {
{{/x-item-xml-name}}
{{/exts}}
{{#hasValidation}}
{{^hasConflictingModelNames}}
#[cfg_attr(not(feature = "validate"), validate(
{{/hasConflictingModelNames}}
{{#hasConflictingModelNames}}
#[validate(
{{/hasConflictingModelNames}}
{{#maxLength}}
{{#minLength}}
length(min = {{minLength}}, max = {{maxLength}}),
Expand Down Expand Up @@ -336,8 +357,19 @@ pub struct {{{classname}}} {
length(min = {{minItems}}),
{{/minItems}}
{{/maxItems}}
{{^hasConflictingModelNames}}
))]
{{/hasConflictingModelNames}}
{{#hasConflictingModelNames}}
)]
{{/hasConflictingModelNames}}
{{/hasValidation}}
{{^hasConflictingModelNames}}{{>validate}}{{/hasConflictingModelNames}}
{{^hasConflictingModelNames}}
{{#exts.x-needs-nested-validation}}
#[cfg_attr(feature = "validate", validate)]
{{/exts.x-needs-nested-validation}}
{{/hasConflictingModelNames}}
{{#required}}
pub {{{name}}}: {{{dataType}}},
{{/required}}
Expand All @@ -346,6 +378,11 @@ pub struct {{{classname}}} {
#[serde(deserialize_with = "swagger::nullable_format::deserialize_optional_nullable")]
#[serde(default = "swagger::nullable_format::default_optional_nullable")]
{{/isNullable}}
{{^hasConflictingModelNames}}
{{#exts.x-needs-nested-validation}}
#[cfg_attr(feature = "validate", validate)]
{{/exts.x-needs-nested-validation}}
{{/hasConflictingModelNames}}
#[serde(skip_serializing_if="Option::is_none")]
pub {{{name}}}: Option<{{{dataType}}}>,
{{/required}}
Expand All @@ -365,6 +402,7 @@ lazy_static::lazy_static! {
lazy_static::lazy_static! {
static ref RE_{{#lambda.uppercase}}{{{classname}}}_{{{name}}}{{/lambda.uppercase}}: regex::bytes::Regex = regex::bytes::Regex::new(r"{{ pattern }}").unwrap();
}
#[cfg(not(feature = "validate"))]
fn validate_byte_{{#lambda.lowercase}}{{{classname}}}_{{{name}}}{{/lambda.lowercase}}(
b: &swagger::ByteArray
) -> Result<(), validator::ValidationError> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use http_body_util::{combinators::BoxBody, Full};
use hyper::{body::{Body, Incoming}, HeaderMap, Request, Response, StatusCode};
use hyper::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use log::warn;
#[cfg(feature = "validate")]
use serde_valid::Validate;
#[allow(unused_imports)]
use std::convert::{TryFrom, TryInto};
use std::{convert::Infallible, error::Error};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ where
multipart_form_size_limit: Option<u64>,
{{/apiUsesMultipartFormData}}
marker: PhantomData<C>,
validation: bool
}

impl<T, C> MakeService<T, C>
Expand All @@ -22,7 +23,8 @@ where
{{#apiUsesMultipartFormData}}
multipart_form_size_limit: Some(8 * 1024 * 1024),
{{/apiUsesMultipartFormData}}
marker: PhantomData
marker: PhantomData,
validation: false
}
}
{{#apiUsesMultipartFormData}}
Expand All @@ -37,6 +39,12 @@ where
self
}
{{/apiUsesMultipartFormData}}

// Turn on/off validation for the service being made.
#[cfg(feature = "validate")]
pub fn set_validation(&mut self, validation: bool) {
self.validation = validation;
}
}

impl<T, C> Clone for MakeService<T, C>
Expand All @@ -51,6 +59,7 @@ where
multipart_form_size_limit: Some(8 * 1024 * 1024),
{{/apiUsesMultipartFormData}}
marker: PhantomData,
validation: self.validation
}
}
}
Expand All @@ -65,10 +74,8 @@ where
type Future = future::Ready<Result<Self::Response, Self::Error>>;

fn call(&self, target: Target) -> Self::Future {
let service = Service::new(self.api_impl.clone()){{^apiUsesMultipartFormData}};{{/apiUsesMultipartFormData}}
{{#apiUsesMultipartFormData}}
.multipart_form_size_limit(self.multipart_form_size_limit);
{{/apiUsesMultipartFormData}}
let service = Service::new(self.api_impl.clone(), self.validation){{^apiUsesMultipartFormData}};{{/apiUsesMultipartFormData}}{{#apiUsesMultipartFormData}}
.multipart_form_size_limit(self.multipart_form_size_limit);{{/apiUsesMultipartFormData}}

future::ok(service)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@
.expect("Unable to create Bad Request response for missing query parameter {{{baseName}}}")),
};
{{/required}}
{{#exts.x-can-validate}}
#[cfg(not(feature = "validate"))]
run_validation!(param_{{{paramName}}}, "{{{baseName}}}", validation);
{{/exts.x-can-validate}}
{{/isArray}}
{{#-last}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@
.expect("Unable to create Bad Request response for missing body parameter {{{baseName}}}")),
};
{{/required}}
{{#exts.x-can-validate}}
#[cfg(not(feature = "validate"))]
run_validation!(param_{{{paramName}}}, "{{{baseName}}}", validation);
{{/exts.x-can-validate}}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
Box::pin(run(
self.api_impl.clone(),
req,
{{#apiUsesMultipartFormData}}
self.multipart_form_size_limit,
{{/apiUsesMultipartFormData}}
self.validation{{#apiUsesMultipartFormData}},
self.multipart_form_size_limit{{/apiUsesMultipartFormData}}
))
}
}
Loading
Loading