From e9c05d8c7cc32c1df19a2777c389752ebcf44a60 Mon Sep 17 00:00:00 2001 From: bread Date: Sat, 26 Apr 2025 00:40:18 +0200 Subject: [PATCH] fix multi name support, create if not existing, update if existing, more type fixing --- src/client.rs | 225 +++++++++++++++++++++++++------------------------- src/main.rs | 154 +++++++++++++++++++++++++--------- src/models.rs | 98 ++++++++++++++++++++-- 3 files changed, 314 insertions(+), 163 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7c52deb..60c5cde 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,75 +1,13 @@ use crate::*; -use reqwest::{ - Client, Method, Request, Url, - header::HeaderValue, -}; -use std::borrow::Borrow; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum RecordType { - A, - AAAA, - NS, - MX, - CNAME, - RP, - TXT, - SOA, - HINFO, - SRV, - DANE, - TLSA, - DS, - CAA, -} - -impl ToString for RecordType { - fn to_string(&self) -> String { - match self { - RecordType::A => "A".to_string(), - RecordType::AAAA => "AAAA".to_string(), - RecordType::NS => "NS".to_string(), - RecordType::MX => "MX".to_string(), - RecordType::CNAME => "CNAME".to_string(), - RecordType::RP => "RP".to_string(), - RecordType::TXT => "TXT".to_string(), - RecordType::SOA => "SOA".to_string(), - RecordType::HINFO => "HINFO".to_string(), - RecordType::SRV => "SRV".to_string(), - RecordType::DANE => "DANE".to_string(), - RecordType::TLSA => "TLSA".to_string(), - RecordType::DS => "DS".to_string(), - RecordType::CAA => "CAA".to_string(), - } - } -} - -impl TryFrom<&str> for RecordType { - type Error = &'static str; - fn try_from(value: &str) -> Result { - Ok(match value { - "A" => RecordType::A, - "AAAA" => RecordType::AAAA, - "NS" => RecordType::NS, - "MX" => RecordType::MX, - "CNAME" => RecordType::CNAME, - "RP" => RecordType::RP, - "TXT" => RecordType::TXT, - "SOA" => RecordType::SOA, - "HINFO" => RecordType::HINFO, - "SRV" => RecordType::SRV, - "DANE" => RecordType::DANE, - "TLSA" => RecordType::TLSA, - "DS" => RecordType::DS, - "CAA" => RecordType::CAA, - _ => return Err(""), - }) - } -} +use reqwest::{Client, Method, Request, StatusCode, Url, header::HeaderValue}; +use std::{borrow::Borrow, fmt}; #[derive(Serialize)] -struct _RecordQuery { - records: Vec, +struct _RecordQuery +where + T: Serialize, +{ + records: Vec, } pub struct HetznerDNSAPIClient { @@ -78,6 +16,48 @@ pub struct HetznerDNSAPIClient { client: Client, } +#[derive(Debug)] +pub enum HetznerDNSAPIError { + Unauthorized, + Forbidden, + NotFound, + NotAcceptable, + Conflict, + UnprocessableEntity, +} + +impl fmt::Display for HetznerDNSAPIError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + Self::Unauthorized => "Missing authorization, did you set a token?", + Self::Forbidden => "Forbidden action!", + Self::NotFound => "Entity not found!", + Self::NotAcceptable => "Request failed, possibly bad input!", + Self::Conflict => "Conflict encountered", + Self::UnprocessableEntity => + "Request not procesasble, input is invalid or might already exist!", + } + ) + } +} +impl TryFrom for HetznerDNSAPIError { + type Error = String; + fn try_from(value: StatusCode) -> Result { + Ok(match value { + StatusCode::UNAUTHORIZED => Self::Unauthorized, + StatusCode::FORBIDDEN => Self::Forbidden, + StatusCode::NOT_FOUND => Self::NotFound, + StatusCode::NOT_ACCEPTABLE => Self::NotAcceptable, + StatusCode::CONFLICT => Self::Conflict, + StatusCode::UNPROCESSABLE_ENTITY => Self::UnprocessableEntity, + _ => return Err(format!("Invalid error code: {}", value)), + }) + } +} + impl HetznerDNSAPIClient { pub fn new(token: String) -> Self { HetznerDNSAPIClient { @@ -87,13 +67,13 @@ impl HetznerDNSAPIClient { } } - pub async fn api_call<'a, T, U, I, K, V>( + async fn api_call<'a, T, U, I, K, V>( &self, url: &'a str, method: Method, query: Option, payload: Option, - ) -> Result + ) -> Result where T: for<'de> Deserialize<'de>, U: Serialize, @@ -106,6 +86,7 @@ impl HetznerDNSAPIClient { method, self.host.join(url).map_err(|e| { println!("url formatting error: {}", e); + HetznerDNSAPIError::UnprocessableEntity })?, ); req.headers_mut().append( @@ -117,9 +98,16 @@ impl HetznerDNSAPIClient { serde_json::to_string(&payload) .map_err(|e| { println!("body encoding error: {}", e); + HetznerDNSAPIError::UnprocessableEntity })? .into(), ); + //println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + } + if let Some(query) = query { + req.url_mut() + .query_pairs_mut() + .extend_pairs(query.into_iter()); } let t = self .client @@ -127,18 +115,19 @@ impl HetznerDNSAPIClient { .await .map_err(|e| { println!("request execution error: {}", e); + HetznerDNSAPIError::UnprocessableEntity })? .error_for_status() - .map_err(|e| { - println!("request error: {}", e); - })? + .map_err(|e| HetznerDNSAPIError::try_from(e.status().unwrap()).unwrap())? .text() .await .map_err(|e| { println!("request decoding error: {}", e); + HetznerDNSAPIError::UnprocessableEntity })?; serde_json::from_str::(&t).map_err(|e| { println!("json response parsing error: {}", e); + HetznerDNSAPIError::UnprocessableEntity }) } @@ -148,7 +137,7 @@ impl HetznerDNSAPIClient { page: Option, per_page: Option, search_name: Option<&'a str>, - ) -> Result, ()> { + ) -> Result, HetznerDNSAPIError> { let result: ZoneResult = self .api_call( "zones", @@ -165,7 +154,11 @@ impl HetznerDNSAPIClient { Ok(result.zones) } - pub async fn create_zone(&self, name: String, ttl: Option) -> Result { + pub async fn create_zone( + &self, + name: String, + ttl: Option, + ) -> Result { self.api_call( "zones", Method::POST, @@ -175,7 +168,7 @@ impl HetznerDNSAPIClient { .await } - pub async fn get_zone(&self, id: String) -> Result { + pub async fn get_zone(&self, id: String) -> Result { self.api_call( format!("zones/{}", id).as_str(), Method::GET, @@ -190,7 +183,7 @@ impl HetznerDNSAPIClient { id: String, name: String, ttl: Option, - ) -> Result { + ) -> Result { self.api_call( format!("zones/{}", id).as_str(), Method::PUT, @@ -200,14 +193,16 @@ impl HetznerDNSAPIClient { .await } - pub async fn delete_zone(&self, id: String) -> Result<(), ()> { - self.api_call( - format!("zones/{}", id).as_str(), - Method::DELETE, - None::<&[(&str, &str); 0]>, - None::<&str>, - ) - .await? + pub async fn delete_zone(&self, id: String) -> Result<(), HetznerDNSAPIError> { + let result: () = self + .api_call( + format!("zones/{}", id).as_str(), + Method::DELETE, + None::<&[(&str, &str); 0]>, + None::<&str>, + ) + .await?; + Ok(result) } pub async fn import_zone() { @@ -225,7 +220,7 @@ impl HetznerDNSAPIClient { page: Option, per_page: Option, zone_id: Option, - ) -> Result, ()> { + ) -> Result, HetznerDNSAPIError> { let result: RecordsResult = self .api_call( "records", @@ -237,12 +232,14 @@ impl HetznerDNSAPIClient { ]), None::, ) - .await - .map_err(|_| ())?; + .await?; Ok(result.records) } - pub async fn create_record(&self, payload: RecordPayload) -> Result { + pub async fn create_record( + &self, + payload: RecordPayload, + ) -> Result { let result: RecordResult = self .api_call( "records", @@ -250,12 +247,11 @@ impl HetznerDNSAPIClient { None::<[(&str, &str); 0]>, Some(payload), ) - .await - .map_err(|_| ())?; + .await?; Ok(result.record) } - pub async fn get_record(&self, record_id: String) -> Result { + pub async fn get_record(&self, record_id: String) -> Result { let result: RecordResult = self .api_call( format!("records/{}", record_id).as_str(), @@ -263,8 +259,7 @@ impl HetznerDNSAPIClient { None::<[(&str, &str); 0]>, None::, ) - .await - .map_err(|_| ())?; + .await?; Ok(result.record) } @@ -272,7 +267,7 @@ impl HetznerDNSAPIClient { &self, record_id: String, payload: RecordPayload, - ) -> Result { + ) -> Result { let result: RecordResult = self .api_call( format!("records/{}", record_id).as_str(), @@ -280,22 +275,26 @@ impl HetznerDNSAPIClient { None::<[(&str, &str); 0]>, Some(payload), ) - .await - .map_err(|_| ())?; + .await?; Ok(result.record) } - pub async fn delete_record(&self, record_id: String) -> Result<(), ()> { - self.api_call( - format!("records/{}", record_id).as_str(), - Method::DELETE, - None::<[(&str, &str); 0]>, - None::, - ) - .await? + pub async fn delete_record(&self, record_id: String) -> Result<(), HetznerDNSAPIError> { + let result: () = self + .api_call( + format!("records/{}", record_id).as_str(), + Method::DELETE, + None::<[(&str, &str); 0]>, + None::, + ) + .await?; + Ok(result) } - pub async fn create_records(&self, payloads: Vec) -> Result, ()> { + pub async fn create_records( + &self, + payloads: Vec, + ) -> Result, HetznerDNSAPIError> { let result: RecordsResult = self .api_call( "records/bulk", @@ -303,24 +302,22 @@ impl HetznerDNSAPIClient { None::<[(&str, &str); 0]>, Some(_RecordQuery { records: payloads }), ) - .await - .map_err(|_| ())?; + .await?; Ok(result.records) } pub async fn update_records( &self, - payloads: Vec<(String, RecordPayload)>, - ) -> Result, ()> { + payloads: Vec, + ) -> Result, HetznerDNSAPIError> { let result: RecordsResult = self .api_call( "records/bulk", Method::PUT, None::<[(&str, &str); 0]>, - Some(payloads), + Some(_RecordQuery { records: payloads }), ) - .await - .map_err(|_| ())?; + .await?; Ok(result.records) } } diff --git a/src/main.rs b/src/main.rs index ffb5876..391ef9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,21 +24,18 @@ enum SubMode { Delete, } -#[derive(Debug)] struct ZoneContext { all: bool, zone: String, name: String, - ttl: u64, + ttl: Option, } -#[derive(Debug)] struct RecordContext { all: bool, records: Vec, } -#[derive(Debug)] struct Context { mode: Mode, submode: SubMode, @@ -57,16 +54,17 @@ async fn main() { all: false, zone: String::new(), name: String::new(), - ttl: 86400, + ttl: None, }, record_context: RecordContext { all: false, records: vec![RecordPayload { + id: None, zone_id: String::new(), r#type: RecordType::A, name: String::new(), value: String::new(), - ttl: 0, + ttl: None, }], }, }; @@ -110,11 +108,8 @@ async fn main() { ctx.zone_context.name = std::env::args().nth(idx + 1).unwrap() } "--ttl" => { - ctx.zone_context.ttl = std::env::args() - .nth(idx + 1) - .unwrap() - .parse() - .unwrap_or(86400) + ctx.zone_context.ttl = + std::env::args().nth(idx + 1).unwrap().parse().ok() } "--zone" => { ctx.zone_context.zone = std::env::args().nth(idx + 1).unwrap() @@ -148,11 +143,8 @@ async fn main() { .unwrap() } "--ttl" => { - ctx.record_context.records[0].ttl = std::env::args() - .nth(idx + 1) - .unwrap() - .parse() - .unwrap_or(86400) + ctx.record_context.records[0].ttl = + std::env::args().nth(idx + 1).unwrap().parse().ok() } _ => todo!(), } @@ -192,7 +184,7 @@ async fn main() { SubMode::Create => { if !ctx.zone_context.name.is_empty() { if let Ok(zone) = client - .create_zone(ctx.zone_context.name, Some(ctx.zone_context.ttl)) + .create_zone(ctx.zone_context.name, ctx.zone_context.ttl) .await { println!("{:#?}", zone); @@ -214,7 +206,7 @@ async fn main() { .update_zone( zones.into_iter().next().unwrap().id, ctx.zone_context.name, - Some(ctx.zone_context.ttl), + ctx.zone_context.ttl, ) .await { @@ -285,15 +277,17 @@ async fn main() { .unwrap()[0] .id; ctx.record_context.records[0].zone_id = zone.to_string(); - if let Ok(record) = client + match client .create_records( ctx.record_context .records + .clone() .into_iter() .flat_map(|r| { r.name .split(",") .map(|s| RecordPayload { + id: None, zone_id: r.zone_id.clone(), r#type: r.r#type.clone(), name: String::from(s), @@ -306,7 +300,61 @@ async fn main() { ) .await { - println!("{:#?}", record); + Ok(records) => records.into_iter().for_each(|r| println!("{}", r)), + Err(e) => match e { + HetznerDNSAPIError::UnprocessableEntity => { + eprintln!("Records already exist, updating instead..."); + if let Ok(existing) = client + .get_records(None, None, Some(zone.to_string())) + .await + { + existing + .clone() + .into_iter() + .for_each(|r| println!("{}", r)); + let records = client + .update_records( + existing + .into_iter() + .filter(|r| { + ctx.record_context + .records + .clone() + .into_iter() + .any(|o| { + o.name + .split(",") + .any(|s| s == r.name) + }) + }) + .map(|r| RecordPayload { + id: Some(r.id.clone()), + zone_id: r.zone_id, + r#type: r.r#type, + name: r.name.clone(), + value: ctx + .record_context + .records + .clone() + .into_iter() + .find(|o| { + o.name + .split(",") + .any(|s| s == r.name) + }) + .unwrap() + .value, + ttl: r.ttl, + }) + .collect::>(), + ) + .await + .unwrap(); + records.into_iter().for_each(|r| println!("{}", r)); + } + } + _ => eprintln!("{}", e), + }, } } } @@ -314,7 +362,7 @@ async fn main() { if !ctx.record_context.records.is_empty() { let mut records = vec![]; let mut updated_records = vec![]; - let mut records_iter = ctx.record_context.records.into_iter(); + let records_iter = ctx.record_context.records.into_iter(); for zone in records_iter .clone() .map(|r| r.zone_id) @@ -331,29 +379,55 @@ async fn main() { client.get_records(None, None, Some(zone.id)).await.unwrap(), ); } - for old_record in records { - if let Some(new_record) = records_iter.find(|r| { - r.name == old_record.name - || r.name - .split(",") - .find(|s| *s == old_record.name) - .is_some() - }) { - updated_records.push(( - old_record.id, - RecordPayload { - zone_id: old_record.zone_id, - r#type: old_record.r#type, - name: old_record.name, - value: new_record.value, + for old_record in &records { + if let Some(new_record) = + records_iter.clone().find(|r| r.name == old_record.name) + { + updated_records.push(RecordPayload { + id: Some(old_record.id.clone()), + zone_id: old_record.zone_id.clone(), + r#type: old_record.r#type.clone(), + name: old_record.name.clone(), + value: new_record.value, + ttl: new_record.ttl, + }); + } else if let Some(new_record) = records_iter + .clone() + .find(|r| r.name.split(",").any(|s| s == old_record.name)) + { + for name in new_record.name.split(",") { + updated_records.push(RecordPayload { + id: Some(old_record.id.clone()), + zone_id: old_record.zone_id.clone(), + r#type: old_record.r#type.clone(), + name: name.to_string(), + value: new_record.value.clone(), ttl: new_record.ttl, - }, - )); + }); + } } } if !updated_records.is_empty() { - let records = client.update_records(updated_records).await.unwrap(); - eprintln!("Updated {} records", records.len()); + let new_records = + client.update_records(updated_records.clone()).await; + match new_records { + Ok(records) => { + records.into_iter().for_each(|r| println!("{}", r)) + } + Err(e) => match e { + HetznerDNSAPIError::UnprocessableEntity => { + eprintln!( + "Updating failed, trying to create them first." + ); + let records = client + .create_records(updated_records) + .await + .unwrap(); + records.into_iter().for_each(|r| println!("{}", r)); + } + _ => eprintln!("{}", e), + }, + } } else { eprintln!( "No records found that require updating. Did you mean to create/c the records instead?" diff --git a/src/models.rs b/src/models.rs index 2d50ab0..765132e 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,7 +1,7 @@ -use crate::client::RecordType; +#[warn(unused)] use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; -use std::borrow::Borrow; +use std::fmt; #[derive(Debug, Deserialize)] pub struct TxtVerification { @@ -9,7 +9,7 @@ pub struct TxtVerification { pub token: String, } -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] pub struct Pagination { pub page: u32, pub per_page: u32, @@ -17,7 +17,7 @@ pub struct Pagination { pub total_entries: u32, } -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] pub struct Meta { pub pagination: Pagination, } @@ -46,22 +46,89 @@ pub struct Zone { pub txt_verification: TxtVerification, } -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] pub struct ZoneResult { pub zones: Vec, pub meta: Meta, } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "UPPERCASE")] +pub enum RecordType { + A, + Aaaa, + Ns, + Mx, + Cname, + Rp, + Txt, + Soa, + Hinfo, + Srv, + Dane, + Tlsa, + Ds, + Caa, +} + +impl fmt::Display for RecordType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + Self::A => "A", + Self::Aaaa => "AAAA", + Self::Ns => "NS", + Self::Mx => "MX", + Self::Cname => "CNAME", + Self::Rp => "RP", + Self::Txt => "TXT", + Self::Soa => "SOA", + Self::Hinfo => "HINFO", + Self::Srv => "SRV", + Self::Dane => "DANE", + Self::Tlsa => "TLSA", + Self::Ds => "DS", + Self::Caa => "CAA", + } + ) + } +} + +impl TryFrom<&str> for RecordType { + type Error = &'static str; + fn try_from(value: &str) -> Result { + Ok(match value { + "A" => Self::A, + "AAAA" => Self::Aaaa, + "NS" => Self::Ns, + "MX" => Self::Mx, + "CNAME" => Self::Cname, + "RP" => Self::Rp, + "TXT" => Self::Txt, + "SOA" => Self::Soa, + "HINFO" => Self::Hinfo, + "SRV" => Self::Srv, + "DANE" => Self::Dane, + "TLSA" => Self::Tlsa, + "DS" => Self::Ds, + "CAA" => Self::Caa, + _ => return Err(""), + }) + } +} +#[derive(Clone, Serialize, Deserialize)] pub struct RecordPayload { + pub id: Option, pub zone_id: String, pub r#type: RecordType, pub name: String, pub value: String, - pub ttl: u64, + pub ttl: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Record { pub id: String, #[serde(with = "hetzner_date")] @@ -75,12 +142,25 @@ pub struct Record { pub ttl: Option, } -#[derive(Debug, Serialize, Deserialize)] +impl fmt::Display for Record { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} | {} | {} | {}", + self.name, + self.r#type, + self.value, + self.ttl.unwrap_or_default() + ) + } +} + +#[derive(Serialize, Deserialize)] pub struct RecordResult { pub record: Record, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct RecordsResult { pub records: Vec, }